強化学習、マルチエージェント強化学習、その他機械学習全般

機械学習関連のことをまとめていきます。強化学習関連が多いかもしれません

【深層強化学習、TensorFlow】Deep-Q-Networkの紹介と実装

Deep Q Network

Deep Q Networkの概要

Deep Q Network(DQN)は一言で言ってしまうとQテーブルをニューラルネットワーク関数近似したQ Learningです。
もしかしたら、Deep Reinforcement Learningと呼んだほうが良いかもしれません。
厳密にはQテーブルをニューラルネットワーク関数近似しただけでなく、
Experience Replay、Target Network、Reward Clippingの三つの工夫を加えたものをDeep Q Networkと呼びます。

Deep Q Networkの構成

最も基本的なDeep Q Networkは以下の図のようになります。

f:id:ttt242242:20190325202401j:plain:w400

エージェントはTarget NetworkとQ Networkというニューラルネットワークを保持しています。 この2つのネットワークの構造は同一です。パラメータは異なります。
このネットワークの入力は状態、出力はQ値になります。
得られた経験$\langle s, a, s', r \rangle$(ある状態$s$において行動$a$を行い、状態$s'$に遷移し報酬$r$を得た)を保存するためのExperience Bufferを保持しています。

Deep Q Networkの学習の流れ

図にある通り、以下を繰り返す。

  1. 現在の状態$s$をTarget Networkに入力
  2. Target Networkから出力されたQ値を元に行動選択(後述)
  3. 行動したことによって変化した状態 $s'$ 及び報酬$r$の観測
  4. $e=\langle s, a, s', r \rangle$をExperience Bufferに保存
  5. Experience Bufferから任意の経験を取り出し、Q Networkをミニバッチ学習(Experience Replay)(詳細は後述)
  6. あるインターバルでTarget Networkの更新(いくつか種類がある)(詳細は後述)

Experience Replay(Q Networkの学習)

基本的にはステップ毎もしくはエピソード毎に、
バッチサイズ分Experience Bufferから 経験をサンプリング($B$) し、
以下のTD誤差$\mathcal{L}(\theta)$を最小化するように更新します。

$$ \begin{eqnarray} \mathcal{L}(\theta) = \frac{1}{|B|}\sum_{e \in B} (r + \gamma \max_a Q(s', a;\theta^{-}) - Q(s, a;\theta))^2 \end{eqnarray} $$

ここで、 $\theta^{-}$ はTarget Networkのパラメータ、
$\theta$ はQ Networkのパラメータを表します。

Fixed Target Q-Networkの更新

Q Networkの学習結果を少しづつTarget Networkに反映していきます。
これは、Q Learningの過大評価という課題を緩和するために重要なります。
反映方法には大きくわけて2つあって、
1つはHard Update もう1つはSoft Update になります。
Hard Updateでは、定期的にQ NetworkのパラメータをTarget Networkにコピーします。
Soft Update では、Q Networkを更新する度に少しずつずづQ Networkのパラメータを反映させていきます。

Reward Clipping

Reward Clippingは外れ値等に過剰に反応しすぎないために、
報酬値を-1〜1の範囲にクリップすること。
最近は報酬値をクリップというよりhuber lossを用いることが主流になっているように思います。

プログラム

keras-rlを参考にして、tensorflowを使って実装してみました。
けっこうkeras-rlのコードを参考にしています。cartpole問題を用いて、実験してみました。

keras-rlとcartpoleについては以下参照

www.tcom242242.net

プログラムはgithubにあげました。

github.com

実行するためには、まずクローンして、

git clone https://github.com/tocom242242/deep_qlearning_sample.git
cd deep_qlearning_sample

以下のコマンドを実行します。

python run_dqn_sample.py

実験結果

f:id:ttt242242:20190401180947p:plain

横軸はepisode、縦軸はpoleが立っていられたステップ数を表します。
poleが立っていられる最大ステップは200ステップが最大になります。
エピソードが進むごとにpoleを200ステップ立たせられていることがわかります。