探讨如何通过像素学习来训练强化学习代理。尽管取得了一定的成果,但当尝试使用128字节的RAM内容来训练代理时,结果并不尽如人意。本文将探讨一种不同的输入特征处理方法,以期获得更好的训练效果。
根据《Atari学习环境(ALE)》论文的附录A所述,Atari 2600游戏程序员经常将这些位作为4位或8位字的一部分使用。在单个位上进行线性函数近似可以捕捉到这些多位字的价值。因此,可以尝试从1024位而不是将它们视为128字节值来学习。
考虑到RAM中可能使用不同的方案编码各种数字(例如,前四位中的一个值,后四位中的另一个值),可以尝试从1024位而不是将它们视为128字节值来学习。以下是用于位学习的学习代码。它提供了一个自定义的观察包装器,使用NumPy的unpackbits函数将128字节的输入观察扩展到1024字节的观察(每个字节要么是0要么是1)。它还使用了一个大缓冲区和一个小学习率,遵循在更高学习率尝试中不太成功的经验。
import numpy as np
import gym
import ray
from gym.spaces import Box
from gym import ObservationWrapper
from ray import tune
from ray.rllib.agents.dqn import DQNTrainer
from ray.tune.registry import register_env
class BytesToBits(ObservationWrapper):
def __init__(self, env):
super().__init__(env)
self.observation_space = Box(low=0, high=1, shape=(1024,), dtype=np.uint8)
def observation(self, obs):
return np.unpackbits(obs)
def env_creator(env_config):
env = gym.make('Breakout-ramDeterministic-v4')
env = BytesToBits(env)
return env
register_env("ram_bits_breakout", env_creator)
ENV = "ram_bits_breakout"
TARGET_REWARD = 200
TRAINER = DQNTrainer
ray.shutdown()
ray.init(include_webui=False, ignore_reinit_error=True)
tune.run(
TRAINER,
stop={
"episode_reward_mean": TARGET_REWARD},
config={
env: ENV,
monitor: True,
evaluation_num_episodes: 25,
double_q: True,
hiddens: [1024],
num_workers: 0,
num_gpus: 1,
target_network_update_freq: 12_000,
lr: 5E-6,
adam_epsilon: 1E-5,
learning_starts: 150_000,
buffer_size: 1_500_000,
}
)