keras-rlをcolab上で動かせるようにしたのでメモしておきます。
Colaboratory(colab)
Colaboratory(colab)はgoogleが提供してくれるJupyterノートブック環境です。
無料でGPUが使えるので、非常にありがたいサービスです。
最近ではvimでも操作ができるようになったので本当に感謝です。
https://colab.research.google.com/notebooks/welcome.ipynb?hl=ja
keras-rlとは
Keras-rlとは 深層学習用ライブラリであるkerasを用いて、深層強化学習のアルゴリズムを実装したライブラリです。
https://github.com/keras-rl/keras-rl
容易に深層強化学習を試すことができます。
とりあえずローカルでkeras-rlを試したい場合は以下の記事を参照していただければと思います。
https://www.tcom242242.net/entry/2017/09/05/061405
colabのgpuを用いてkeras-rlを動かしてみる(atari)のブロック崩し)
以下のatariのブロック崩しをkeras-rlのdqnで解いてみました。
出典: https://gym.openai.com/envs/Breakout-v0/
dqnに関しては以下の記事を参照してください。
【深層強化学習,入門】Deep Q Network(DQN)の解説とPythonで実装 〜図を使って説明〜
GPUの有効化
実行する前にcolabでGPUを使えるようにcolabの設定を変更する必要があります。
手順は以下のようになります。
- 「ランタイム」
- 「ランタイムのタイプの変更」
- ハードウェアアクセラレータでGPUに変更
ソースコード
以下にソースコードを示します。
from PIL import Image import numpy as np import gym from keras.models import Sequential from keras.layers import Dense, Activation, Flatten, Convolution2D, Permute from keras.optimizers import Adam import keras.backend as K from rl.agents.dqn import DQNAgent from rl.policy import LinearAnnealedPolicy, BoltzmannQPolicy, EpsGreedyQPolicy from rl.memory import SequentialMemory from rl.core import Processor from rl.callbacks import FileLogger, ModelIntervalCheckpoint INPUT_SHAPE = (84, 84) WINDOW_LENGTH = 4 class AtariProcessor(Processor): def process_observation(self, observation): assert observation.ndim == 3 # (height, width, channel) img = Image.fromarray(observation) img = img.resize(INPUT_SHAPE).convert('L') processed_observation = np.array(img) assert processed_observation.shape == INPUT_SHAPE return processed_observation.astype('uint8') def process_state_batch(self, batch): processed_batch = batch.astype('float32') / 255. return processed_batch def process_reward(self, reward): return np.clip(reward, -1., 1.) env = gym.make('BreakoutDeterministic-v4') np.random.seed(123) env.seed(123) nb_actions = env.action_space.n input_shape = (WINDOW_LENGTH,) + INPUT_SHAPE model = Sequential() if K.image_dim_ordering() == 'tf': # (width, height, channels) model.add(Permute((2, 3, 1), input_shape=input_shape)) elif K.image_dim_ordering() == 'th': # (channels, width, height) model.add(Permute((1, 2, 3), input_shape=input_shape)) else: raise RuntimeError('Unknown image_dim_ordering.') model.add(Convolution2D(32, (8, 8), strides=(4, 4))) model.add(Activation('relu')) model.add(Convolution2D(64, (4, 4), strides=(2, 2))) model.add(Activation('relu')) model.add(Convolution2D(64, (3, 3), strides=(1, 1))) model.add(Activation('relu')) model.add(Flatten()) model.add(Dense(512)) model.add(Activation('relu')) model.add(Dense(nb_actions)) model.add(Activation('linear')) print(model.summary()) memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH) processor = AtariProcessor() policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.05, nb_steps=1000000) dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory, processor=processor, nb_steps_warmup=50000, gamma=.99, target_model_update=10000, train_interval=4, delta_clip=1.) dqn.compile(Adam(lr=.00025), metrics=['mae']) # google drive上にモデルや途中経過等保存する weights_filename = '/content/drive/My Drive/BreakoutDeterministic.h5f' checkpoint_weights_filename = '/content/drive/My Drive/dqn__weights_{step}_2.h5f' log_filename = '/content/drive/My Drive/dqn_log.json' callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250000)] callbacks += [FileLogger(log_filename, interval=100)] dqn.fit(env, callbacks=callbacks, nb_steps=1750000, log_interval=10000) dqn.save_weights(weights_filename, overwrite=True) dqn.test(env, nb_episodes=10, visualize=False)
参考文献
つくりながら学ぶ! 深層強化学習 ~PyTorchによる実践プログラミング~
コメント