强化学习:优化Breakout游戏代理

强化学习领域,优化游戏代理的表现是一个常见且有趣的任务。本文将探讨如何通过减少不必要的动作抖动和进行参数网格搜索来优化Atari Breakout游戏代理。

减少动作抖动

Breakout游戏中,代理的表现并不总是平滑的。例如,球拍经常无缘无故地来回抖动。为了减少这种不必要的动作,可以通过惩罚代理来进行优化

import gym import ray from gym import Wrapper from ray import tune from ray.rllib.agents.impala import ImpalaTrainer from ray.tune.registry import register_env ray.shutdown() ray.init(include_webui=False, ignore_reinit_error=True) class PenaliseMovement(Wrapper): def __init__(self, env): super().__init__(env) self._call_count = 0 def step(self, action): observation, reward, done, info = super().step(action) if reward > 1.0: reward = 1.0 threshold = 375_000 if self._call_count >= threshold and action not in (0, 3): multiplier = min((self._call_count - threshold) / 100_000, 1.0) reward -= 0.0001 * multiplier self._call_count += 1 return observation, reward, done, info def env_creator(env_config): env = gym.make('BreakoutNoFrameskip-v4') env = PenaliseMovement(env) return env register_env("penalise_movement_breakout", env_creator) ENV = "penalise_movement_breakout" TARGET_REWARD = 200 TRAINER = ImpalaTrainer tune.run( TRAINER, stop={ "episode_reward_mean": TARGET_REWARD}, config={ "env": ENV, "monitor": True, "evaluation_num_episodes": 25, "rollout_fragment_length": 50, "train_batch_size": 500, "num_workers": 7, "num_envs_per_worker": 5, "clip_rewards": False, "lr_schedule": [ [0, 0.0005], [20_000_000, 0.000000000001], ], } )

通过环境包装器引入一个小的负奖励,用于惩罚代理执行非NO-OP(0)和FIRE(3)的动作。学习挑战变得更难,因为移动的负奖励是立即给出的,而击中砖块的正奖励则是在球弹回球拍的动作之后给出的。

参数网格搜索

在本系列中,一直在使用RLlib的tune函数作为方便的训练运行方式,但实际上并没有用它来调整参数。下面是一个进行简单网格搜索的例子。

import ray from ray import tune from ray.rllib.agents.dqn import DQNTrainer ray.shutdown() ray.init( include_webui=False, ignore_reinit_error=True, object_store_memory=8 * 1024 * 1024 * 1024 # 8GB ) ENV = 'CartPole-v0' TRAINER = DQNTrainer analysis = tune.run( TRAINER, stop={ "training_iteration": 5 }, config={ "env": ENV, "num_workers": 0, "num_gpus": 0, "monitor": False, "lr": tune.grid_search([0.001, 0.0003, 0.0001]), "hiddens": tune.grid_search([[256], [128], [200, 100]]), }, num_samples=2 ) print("Best config: ", analysis.get_best_config(metric="episode_reward_mean")) df = analysis.dataframe() print(df)
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485