在本文中,将探讨如何使用强化学习(Reinforcement Learning, RL)来训练一个神经网络,使其能够掌握一款类似Breakout的游戏。强化学习是一种强大的工具,它可以帮助机器学习算法在自动驾驶汽车、股票交易等多个领域取得积极成果。
首先,需要定义一些基本概念:
有了这些词汇,可以这样描述强化学习:它涉及训练一个策略,使代理能够通过在其环境中采取最优行动来最大化奖励。注意,“不采取行动”通常是可用行动之一。
本文结束时,应该能够运行并完成第一块强化学习代码。运行代码有几种不同的选项:
将使用Ray项目的RLlib框架。要在没有GUI的系统上启用RLlib记录训练进度,可以安装一个虚拟显示。在远程Linux终端和托管笔记本中使用了以下命令(在后一种情况下,每行命令前都加上感叹号):
apt-get install -y xvfb x11-utils
pip install pyvirtualdisplay==0.2.* PyOpenGL==3.1.* PyOpenGL-accelerate==3.1.*
import pyvirtualdisplay
_display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))
_ = _display.start()
将使用OpenAI Gym来提供学习环境。第一个是cartpole。这个环境包含一个带有轮子的推车,推车上平衡着一个垂直的杆。杆是不稳定的,倾向于倒下。代理可以根据对杆状态的观察来移动推车,并根据杆平衡的时间长度获得奖励。
将使用的总体框架是Ray/RLlib。安装这个(例如,使用pip install ray[rllib]==0.8.5或通过Anaconda)将带来其依赖项,包括OpenAI Gym。
也许最好的做法是启动每个新环境并查看一下。如果正在运行一个带有图形显示的环境,可以直接“玩”这个环境:
import gym
from gym.utils.play import play
env = gym.make("CartPole-v0")
play(env, zoom=4)
将直接开始训练一个代理来解决这个环境。将在下一篇文章中深入探讨细节。
import ray
from ray import tune
from ray.rllib.agents.dqn import DQNTrainer
ray.shutdown()
ray.init(include_webui=False, ignore_reinit_error=True, object_store_memory=8*1024*1024*1024) # 8GB limit … feel free to increase this if you can
ENV = 'CartPole-v0'
TARGET_REWARD = 195
TRAINER = DQNTrainer
tune.run(
TRAINER,
stop={
"episode_reward_mean": TARGET_REWARD},
config={
"env": ENV,
"num_workers": 0,
"num_gpus": 0,
"monitor": True,
"evaluation_num_episodes": 25,
}
)
应该看到很多输出。每个批次的最后一行显示了状态,包括获得的平均奖励。继续训练,直到这个奖励达到环境的目标195。进度不是线性的,所以奖励可能会在接近目标后再次回落。要有耐心;它应该在不久的将来到达那里。在电脑上,这不到15分钟。
195?恭喜!现在已经训练了第一个强化学习模型!
告诉Ray存储进度的快照,它将把快照放在home目录下的ray_results文件夹中。应该看到很多mp4视频。如果在托管笔记本中运行,请参见下一节;否则,可以跳过它。
视频文件应该已经创建,但没有简单的方法来查看它们。编写了一些辅助代码来解决这个问题:
from base64 import b64encode
from pathlib import Path
from typing import List
OUT_PATH = Path('/root/ray_results/')
def latest_experiment() -> Path:
"""
Get the path of the results directory of the most recent training run.
"""
experiment_dirs = []
for algorithm in OUT_PATH.iterdir():
if not algorithm.is_dir():
continue
for experiment in algorithm.iterdir():
if not experiment.is_dir():
continue
experiment_dirs.append((experiment.stat().st_mtime, experiment))
return max(experiment_dirs)[1]
def latest_videos() -> List[Path]:
return list(sorted(latest_experiment().glob('*.mp4')))
def render_mp4(videopath: Path) -> str:
mp4 = open(videopath, 'rb').read()
base64_encoded_mp4 = b64encode(mp4).decode()
return f'{videopath.name}
'
from IPython.display import HTML
html = render_mp4(latest_videos()[-1])
HTML(html)
将此添加到笔记本单元格中应该会渲染最新的视频。