In reinforcement learning, improving the performance of a game-playing agent is a common and interesting task. This article looks at optimising an Atari Breakout agent in two ways: by discouraging unnecessary paddle jitter, and by running a simple grid search over hyperparameters.
The agent's behaviour in Breakout is not always smooth. For example, the paddle often jitters back and forth for no apparent reason. One way to reduce this unnecessary movement is to penalise the agent for it.
import gym
import ray
from gym import Wrapper
from ray import tune
from ray.rllib.agents.impala import ImpalaTrainer
from ray.tune.registry import register_env
# Restart Ray cleanly in case a previous session is still attached.
ray.shutdown()
ray.init(include_webui=False, ignore_reinit_error=True)
class PenaliseMovement(Wrapper):
    """Wrapper that applies a small penalty to paddle-movement actions."""

    def __init__(self, env):
        super().__init__(env)
        self._call_count = 0

    def step(self, action):
        observation, reward, done, info = super().step(action)

        # Cap the brick reward so a single hit is never worth more than 1.0.
        if reward > 1.0:
            reward = 1.0

        # Only start penalising movement once the agent has had time to learn
        # the basic game, then phase the penalty in linearly over the next
        # 100,000 steps.
        threshold = 375_000
        if self._call_count >= threshold and action not in (0, 3):
            multiplier = min((self._call_count - threshold) / 100_000, 1.0)
            reward -= 0.0001 * multiplier

        self._call_count += 1
        return observation, reward, done, info
def env_creator(env_config):
    env = gym.make('BreakoutNoFrameskip-v4')
    env = PenaliseMovement(env)
    return env

# Make the wrapped environment available to RLlib under a custom name.
register_env("penalise_movement_breakout", env_creator)
ENV = "penalise_movement_breakout"
TARGET_REWARD = 200
TRAINER = ImpalaTrainer
tune.run(
    TRAINER,
    stop={"episode_reward_mean": TARGET_REWARD},
    config={
        "env": ENV,
        "monitor": True,  # record episode videos/stats to the trial output directory
        "evaluation_num_episodes": 25,
        "rollout_fragment_length": 50,
        "train_batch_size": 500,
        "num_workers": 7,
        "num_envs_per_worker": 5,
        "clip_rewards": False,  # built-in clipping would erase the small movement penalties
        "lr_schedule": [
            [0, 0.0005],
            [20_000_000, 0.000000000001],  # anneal the learning rate towards zero over 20M timesteps
        ],
    }
)
The environment wrapper introduces a small negative reward that penalises the agent for taking any action other than NO-OP (0) and FIRE (3). This makes the learning challenge harder, because the negative reward for moving arrives immediately, whereas the positive reward for hitting a brick only arrives some time after the action that bounced the ball off the paddle.
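To double-check which action indices the wrapper should exempt, the action names can be read straight off the environment rather than hard-coded from memory. A quick sketch, assuming a standard gym Atari installation (the index-to-meaning mapping can differ between environment versions):

import gym

env = gym.make('BreakoutNoFrameskip-v4')
# get_action_meanings() returns the action names in index order, so the
# indices to exempt in PenaliseMovement can be read off directly.
print(env.unwrapped.get_action_meanings())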
Throughout this series, Ray's tune.run has been used as a convenient way to launch training runs, but it has not actually been used to tune anything. Below is an example of a simple grid search.
import ray
from ray import tune
from ray.rllib.agents.dqn import DQNTrainer
ray.shutdown()
ray.init(
include_webui=False,
ignore_reinit_error=True,
object_store_memory=8 * 1024 * 1024 * 1024 # 8GB
)
ENV = 'CartPole-v0'
TRAINER = DQNTrainer
analysis = tune.run(
    TRAINER,
    stop={"training_iteration": 5},
    config={
        "env": ENV,
        "num_workers": 0,
        "num_gpus": 0,
        "monitor": False,
        # Every combination of these values is trained: 3 x 3 = 9 configurations.
        "lr": tune.grid_search([0.001, 0.0003, 0.0001]),
        "hiddens": tune.grid_search([[256], [128], [200, 100]]),
    },
    num_samples=2  # run each configuration twice, for 18 trials in total
)
print("Best config: ", analysis.get_best_config(metric="episode_reward_mean"))
df = analysis.dataframe()
print(df)
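Tune flattens each trial's config into the results dataframe, so the trials can also be ranked by hand. A minimal sketch, assuming the episode_reward_mean metric and the config/lr and config/hiddens column names that Tune reports for these trials:

# Rank trials by mean episode reward, best first, alongside the searched parameters.
ranked = df.sort_values("episode_reward_mean", ascending=False)
print(ranked[["episode_reward_mean", "config/lr", "config/hiddens"]])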