【深層強化学習】Deep Q Network をTensorFlowで実装

今回はDeep Q Networkについて紹介します。 実装にはTensorFlowを用いました。

Deep Q Network

Deep Q Network(DQN)はQ学習Qテーブルをニューラルネットワークで関数近似した強化学習アルゴリズムです。

f:id:ttt242242:20190812053308p:plain左がQ学習のQテーブル、右はQテーブルをニューラルネットワークで関数近似(Q Network)

厳密にはQテーブルをニューラルネットワークで関数近似しただけでなく、
Experience Replay、Target Network、Reward Clipping等の工夫を加えたものをDeep Q Networkと呼びます。

Deep Q Networkの構成

Deep Q Networkは以下の図のようになります。

f:id:ttt242242:20190811173656p:plain

エージェントはTarget NetworkとQ Networkというニューラルネットワークを保持しています。 この2つのネットワークの構造は同一です。パラメータは異なります。
このネットワークの入力は状態、出力はQ値になります。
得られた経験\(\langle s, a, s', r \rangle \)(ある状態\(s \)において行動\(a\)を行い、状態\(s'\)に遷移し報酬\(r\)を得た)を保存するためのExperience Bufferを保持しています。

Deep Q Networkの学習の流れ

図にある通り、以下を繰り返します。

  1. 現在の状態$s$をTarget Network\(Q(s', a|\theta^{-})\)に入力
  2. Target Networkから出力されたQ値を元に行動選択
  3. 行動したことによって変化した状態 \(s'\) 及び報酬\(r\)の観測
  4. 経験\(e=\langle s, a, s', r \rangle \)をExperience Bufferに保存
  5. Experience Bufferから任意の経験を取り出し、Q Networkをミニバッチ学習(Experience Replay)(詳細は後述
  6. Target Networkの更新(詳細は後述

Target Networkの更新

任意のインターバルで、 Q NetworkのパラメータをTarget Networkに反映していきます。
これは、Q Learningの過大評価という課題を緩和するために重要なります。

Q Networkのパラメータの反映方法には大きくわけて2つあって、
1つはHard Update もう1つはSoft Update になります。

Hard Updateでは、定期的にQ NetworkのパラメータをTarget Networkにコピーします。
Soft Update では、Q Networkを更新する度に少しずつずづQ Networkのパラメータを反映させていきます。

Experience Replay(Q Networkの学習)

任意のインターバルで、
バッチサイズ分Experience Bufferから 経験をサンプリング(\(B={e_0, e1, ..., e{|B|} } \)) し、
以下のTD誤差\(\mathcal{L}(\theta)\)を最小化するようにQ Networkのパラメータ\(\theta\)を更新します。

$$
\begin{eqnarray}
\mathcal{L}(\theta) = \frac{1}{|B|}\sum_{e \in B} (r + \gamma \max_a Q(s', a|\theta^{-}) - Q(s, a|\theta))^2
\end{eqnarray}
$$

Reward Clipping

Reward Clippingは外れ値等に過剰に反応しすぎないために、
報酬値を-1〜1の範囲にクリップすることです。

実装と実験

tensorflowを使って実装してみました。
cartpole問題を用いて、実験します。

Cartpole問題とは

CartPoleは、 棒が設置してある台車があり、
台車を棒が倒れないように
うまくコントロールする問題になります。

f:id:ttt242242:20190428190208p:plain

出典:Leaderboard · openai/gym Wiki · GitHub

制御値、観測、報酬等について

制御値(行動)

制御値は、台を左に押す(0)か 右に押す(1)の二択になります。

---
0 左に押す
1 右に押す

観測

観測値は、台車の位置、台車の速度、棒の角度、棒の先端の速度の4つになります。

---
台車の位置 -2.4 2.4
台車の速度 -inf inf
棒の角度 -41.8° 41.8°
棒の先端の速度 -inf inf

報酬

報酬としては1を与え続けます。

エピソードの終了判定

以下のどれかの条件を満たした場合に、
エピソードが終了したと判定されます。

  • ポールのアングルが±12°以内
  • 台車の位置が±2.4以内
  • エピソードの長さが200以上

ソースコードと解説

ソースコード

一部解説(注意すべき点)

Q Networkを学習する時に注意する必要があります。
経験\(\langle s, a_0, r, s' \rangle\)が与えられた時に\(Q(s, a_0)\)だけを学習させるために 少し工夫を加えます。 余計な学習をしないように、Q Networkの出力とtarget\(r+\gamma \max_a Q(s',a|\theta^{-})\)にマスク処理を行います。

図を用いて説明します。
単純に学習しようとすると、以下の図のようにすべての出力層を考慮してしまいます。 f:id:ttt242242:20190812053338p:plain

しかし、今回考慮したいのは行動\(a_0\)に対する評価のみです。
なので、余計な学習をしないように、\(a_1, a_2\)を無視するために、出力\(Q(s, a_1), Q(s, a_2)\)と\(a_1, a_2\)に対応したtargetを0にし、学習しないようにします。

f:id:ttt242242:20190812053353p:plain

ソースコード上では、120行〜122行の部分になります。
one hot encodingを用いて対処しています。

実験結果

f:id:ttt242242:20190811182301p:plain

横軸はepisode、縦軸はpoleが立っていられたステップ数を表します。
poleが立っていられる最大ステップは200ステップです。
エピソードが進むごとにpoleを200ステップ立たせられていることがわかります。

コメント

  1. […] Deep-Q-Networkの紹介と実装 […]

タイトルとURLをコピーしました