从游戏RAM学习:一种机器学习方法

机器学习领域,尤其是强化学习中,经常需要从环境中学习。通常,通过观察环境的像素来学习,但这种方法计算成本较高。本文将探讨一种不同的方法:直接从游戏的RAM中学习。

作为一名软件工程师,原本以为从RAM环境中学习会更容易。毕竟,内存中某个位置很可能存储着球拍的x坐标,另外两个位置则存储着球的位置。如果在编写代码来玩这个游戏,而且不使用机器学习,那这可能是想要开始的地方。如果被迫使用图形,可能会处理它们以提取这些信息,所以直接跳过这一步似乎更简单。

然而,错了!从图像中学习比从RAM中学习更容易。现代的卷积神经网络架构擅长从图像中提取有用的特征。相比之下,程序员习惯于使用尽可能少的内存,并想出各种“巧妙的技巧”来尽可能多地存储信息。一个字节可能代表一个数字,或者两个数字,每个数字4位,或者八个标志...

从RAM中学习

以下是使用的代码:

import ray from ray import tune from ray.rllib.agents.dqn import DQNTrainer ray.shutdown() ray.init(include_webui=False, ignore_reinit_error=True) ENV = "Breakout-ramDeterministic-v4" TARGET_REWARD = 200 TRAINER = DQNTrainer tune.run( TRAINER, stop={ "episode_reward_mean": TARGET_REWARD}, config={ "env": ENV, "monitor": True, "evaluation_num_episodes": 25, "double_q": True, "hiddens": [128], "num_workers": 0, "num_gpus": 1, "target_network_update_freq": 12_000, "lr": 5E-6, "adam_epsilon": 1E-5, "learning_starts": 150_000, "buffer_size": 1_500_000, } )

这是停止进程时的进展:这不是一个巨大的成功。让训练运行了54小时,才达到40分。所以它学到了一些东西,图表表明它还在继续改进,但进展非常缓慢。在下一篇文章中,将看到如何做得更好。

从RAM的一个子集中学习

人们很容易认为,尽管Atari只有128字节的内存,但许多存储的值只是噪音。例如,其中某处将是玩家当前的得分,使用这个作为输入特征对学习没有帮助。

所以尝试识别一个包含有用信息的位的子集。通过记录观察结果并查看哪些似乎在有意义地变化(即,在前一百个时间步中有大量的不同值),挑选出了以下列为“有趣的”:70、71、72、74、75、90、94、95、99、101、103、105和119。

以下是使用这些值训练模型的代码。转而使用PPO算法,因为它似乎比DQN表现得更好一点。

import pyvirtualdisplay _display = pyvirtualdisplay.Display(visible=False, size=(1400, 900)) _ = _display.start() import ray from ray import tune from ray.rllib.agents.ppo import PPOTrainer ray.shutdown() ray.init(include_webui=False, ignore_reinit_error=True) import numpy as np import gym from gym.wrappers import TransformObservation from gym.spaces import Box from ray.tune.registry import register_env from gym import ObservationWrapper class TruncateObservation(ObservationWrapper): interesting_columns = [70, 71, 72, 74, 75, 90, 94, 95, 99, 101, 103, 105, 119] def __init__(self, env): super().__init__(env) self.observation_space = Box(low=0, high=255, shape=(len(self.interesting_columns),), dtype=np.uint8) def observation(self, obs): print(obs.tolist()) # print full observation to find interesting columns return obs[self.interesting_columns] def env_creator(env_config): env = gym.make('Breakout-ramDeterministic-v4') env = TruncateObservation(env) return env register_env("simpler_breakout", env_creator) ENV = "simpler_breakout" TARGET_REWARD = 200 TRAINER = PPOTrainer tune.run( TRAINER, stop={ "episode_reward_mean": TARGET_REWARD}, config={ "env": ENV, "num_workers": 1, "num_gpus": 0, "monitor": True, "evaluation_num_episodes": 25 } )
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485