【強化学習入門】方策勾配定理の証明メモ 【Policy Gradient Theorem】

方策勾配定理の証明の個人的なメモです。まだミスや説明足らずのところがあると思うので適宜修正していきます。

方策勾配定理とは

まず方策勾配定理を復習しておきます。

方策勾配定理によって得られる勾配は以下のようなものでした。

\begin{aligned}
\nabla_\theta J(\theta) \propto \mathbb{E}_{\pi} [Q^\pi(s, a) \nabla_{\theta} \ln \pi_\theta(a \vert s)]
\end{aligned}

方策勾配定理の証明

では、証明していきます。

まずエピソード型の期待報酬値は以下のように表現できます。

\begin{aligned}
J(\theta) = V^{\pi}(s_0)
\end{aligned}

なので、この期待報酬値の勾配を求めるために、\(\nabla_\theta V^\pi(s) \)をひたすら計算していきます。


\begin{aligned}
\nabla_\theta V^\pi(s) &= \nabla_\theta \Big(\sum_{a} \pi_\theta(a \vert s)Q^\pi(s, a) \Big) \\
&=\sum_{a} \nabla_\theta \Big( \pi_\theta(a \vert s)Q^\pi(s, a) \Big) \\
& \color{blue}{※((f(x)g(x))' = f(x)'g(x)+f(x)g'(x))} \\
& = \sum_a [\nabla_{\theta}\pi_{\theta}(a|s)Q^\pi(s,a)+\pi_{\theta}(a|s)\color{blue}{\nabla_{\theta}Q^\pi(s,a)}] \\
& = \sum_a [\nabla_{\theta}\pi_{\theta}(a|s)Q^\pi(s,a)+\pi_{\theta}(a|s)\color{blue}{\nabla_{\theta}\sum_{s'} P(s'|s,a)(r+V(s'))}] \\
& = \sum_a [\nabla_{\theta}\pi_{\theta}(a|s)Q^\pi(s,a)+\pi_{\theta}(a|s)\sum_{s'} P(s'|s,a)\color{blue}{\nabla_{\theta}V(s')}] \\
&= \sum_a [\nabla_{\theta}\pi(a|s)Q_\pi(s,a) \\
&+\pi(a|s)\sum_{s'}P(s'|s,a)\color{blue}{\sum_{a'} [\nabla_{\theta}\pi_{\theta} (a'|s')Q^\pi(s',a') }\color{blue}{+\pi_{\theta} (a'|s')\sum_{s''}P(s''|s',a')\nabla_{\theta} V^\pi(s'')]}] \\
&= \sum_a \nabla_{\theta}\pi(a|s)Q_\pi(s,a) \\
&\color{blue}{{+\sum_a}\pi_{\theta}(a|s)\sum_{s'}Q^\pi(s'|s,a)\sum_{a'}\nabla_{\theta}\pi_{\theta}(a'|s')Q^\pi(s',a')}\\
&\color{blue}{+ \sum_a\pi_{\theta}(a|s)\sum_{s'}P(s'|s,a)\sum_{a'}\pi_{\theta}(a'|s')\sum_{s''}P(s''|s',a')\nabla_{\theta} V^\pi(s'')}
\end{aligned}

ここで、新しい遷移関数\(Pr(s\rightarrow x,k,\pi_{\theta}) \)というを定義します。これは方策\(\pi\)を使って状態\(s\)から状態\(x\)に\(k\)ステップで遷移する確率を表現した関数です。

\begin{aligned}
Pr(s\rightarrow x,0,\pi_{\theta}) &=\begin{cases}1 & (x=s)\\0 & (otherwise)\end{cases} \\
Pr(s\rightarrow x,1,\pi_\theta) &= \sum_a \pi_{\theta} (a|s)p(x | s,a) \\
Pr(s\rightarrow x,2,\pi_{\theta}) &=\sum_a\pi_{\theta}(a|s)\sum_{s'}P(s'|s,a)\sum_{a'}\pi_{\theta}(a'|s')P(x|s',a')
\end{aligned}

最初の式は\(s\)に居る場合1、いない場合0となる関数です。2つめは1ステップで\(s\)から\(x\)に遷移する関数になります。

この遷移関数を使って先程の式をさらに展開していきます。

\begin{aligned}
&= \color{blue}{\sum_a \nabla_{\theta}\pi(a|s)Q_\pi(s,a)} \\
&\color{blue}{{+\sum_a}\pi_{\theta}(a|s)\sum_{s'}Q^\pi(s'|s,a)}\sum_{a'}\nabla_{\theta}\pi_{\theta}(a'|s')Q^\pi(s',a')\\
&\color{blue}{+ \sum_a\pi_{\theta}(a|s)\sum_{s'}P(s'|s,a)}\sum_{a'}\pi_{\theta}(a'|s')\sum_{s''}P(s''|s',a')\nabla_{\theta} V^\pi(s'') \\
&= \color{blue}{\sum_{s\in \mathcal{S}}Pr(s\rightarrow s,0,\pi_{\theta})}\sum_a \nabla_{\theta}\pi_{\theta}(a|s)Q^\pi(s,a) \\
&+\color{blue}{\sum_{s'\in \mathcal{S}}Pr(s\rightarrow s',1,\pi_{\theta})}\sum_{a'}\nabla_{\theta}\pi_{\theta}(a'|s')Q^\pi(s',a') \\
&+\color{blue}{\sum_{s''\in \mathcal{S}}Pr(s\rightarrow s'',2,\pi_{\theta})}\nabla_{\theta} V^\pi(s'') \\
&= \cdots (unrolling) \\
&= \sum_{x\in \mathcal{S}}\sum_{k=0}^\infty Pr(s\rightarrow x,k,\pi_{\theta})\sum_a \nabla_{\theta} \pi_{\theta}(a|x)Q^\pi(x,a)
\end{aligned}

上の式展開で補足しておくと、

\begin{aligned}
※\sum_a\pi(a|s)\sum_{x \in S}Q^{\pi}(x|s,a) &= \sum_{x \in S} \sum_a \pi(a|s)Q^{\pi}(x|s,a) \\
&= \sum_{x \in S} Pr(s\rightarrow x,k,\pi_{\theta})
\end{aligned}

ではここまでで求めた\( \nabla_{\theta}V^{\pi}(s_0)\)を使って\(\nabla_{\theta}J(\theta) \)を求めていきます。スタートの状態\(s_0\)から展開していきます。ここ以降の展開は参考文献2を参考にしていただければと思います。

\begin{aligned}
\nabla_{\theta}J(\theta) & = \nabla_{\theta}V^{\pi}(s_0) \\
&= \sum_{x\in \mathcal{S}}\color{blue}{\sum_{k=0}^\infty Pr(s\rightarrow x,k,\pi_{\theta})} \sum_a \nabla_{\theta} \pi_{\theta}(a|x)Q^\pi(x,a) \\
&= \sum_s \eta (s) \sum_a \nabla_{\theta} \pi_{\theta}(a|x)Q^\pi(x,a) \\
&= \color{blue}{(\sum_s \eta (s))} \sum_s \frac{\eta(s)}{\color{blue}{\sum_s \eta (s)}}\sum_a \nabla_{\theta}\pi_{\theta}(a|x)Q^\pi(x,a)\\
& \propto \sum_s \color{blue}{\frac{\eta(s)}{\sum_s \eta (s)}} \sum_a \nabla_{\theta}\pi_{\theta}(a|x)Q^\pi(x,a) \\
&= \sum_s \color{blue}{d^{\pi}(s)} \sum_a \nabla_{\theta}\pi_{\theta}(a|x)Q^\pi(x,a) \\
\end{aligned}


さらに展開していきます。

\begin{aligned}
\nabla_\theta J(\theta)
&\propto \sum_{s \in \mathcal{S}} d^\pi(s) \sum_{a \in \mathcal{A}} Q^\pi(s, a) \nabla_\theta \pi_\theta(a \vert s) \\
&= \sum_{s \in \mathcal{S}} d^\pi(s) \sum_{a \in \mathcal{A}} \pi_\theta(a \vert s) Q^\pi(s, a) \frac{\nabla_\theta \pi_\theta(a \vert s)}{\pi_\theta(a \vert s)} \\
& ※\nabla_{\theta} log\pi_\theta = \frac{1}{\pi_\theta} \nabla_{\theta}\pi_\theta \\
&= \sum_{s \in \mathcal{S}} d^\pi(s) \sum_{a \in \mathcal{A}} \pi_\theta(a \vert s) Q^\pi(s, a) \nabla_{\theta} \ln \pi_\theta(a \vert s) \\
&= \mathbb{E}_{\pi} [Q^\pi(s, a) \nabla_{\theta} \ln \pi_\theta(a \vert s)]
\end{aligned}

以上です。お疲れ様でした。


参考文献

タイトルとURLをコピーしました