方策勾配定理の証明の個人的なメモです。まだミスや説明足らずのところがあると思うので適宜修正していきます。
方策勾配定理とは
まず方策勾配定理を復習しておきます。
方策勾配定理によって得られる勾配は以下のようなものでした。
\begin{aligned}
\nabla_\theta J(\theta) \propto \mathbb{E}_{\pi} [Q^\pi(s, a) \nabla_{\theta} \ln \pi_\theta(a \vert s)]
\end{aligned}
方策勾配定理の証明
では、証明していきます。
まずエピソード型の期待報酬値は以下のように表現できます。
\begin{aligned}
J(\theta) = V^{\pi}(s_0)
\end{aligned}
なので、この期待報酬値の勾配を求めるために、\(\nabla_\theta V^\pi(s) \)をひたすら計算していきます。
\begin{aligned}
\nabla_\theta V^\pi(s) &= \nabla_\theta \Big(\sum_{a} \pi_\theta(a \vert s)Q^\pi(s, a) \Big) \\
&=\sum_{a} \nabla_\theta \Big( \pi_\theta(a \vert s)Q^\pi(s, a) \Big) \\
& \color{blue}{※((f(x)g(x))’ = f(x)’g(x)+f(x)g'(x))} \\
& = \sum_a [\nabla_{\theta}\pi_{\theta}(a|s)Q^\pi(s,a)+\pi_{\theta}(a|s)\color{blue}{\nabla_{\theta}Q^\pi(s,a)}] \\
& = \sum_a [\nabla_{\theta}\pi_{\theta}(a|s)Q^\pi(s,a)+\pi_{\theta}(a|s)\color{blue}{\nabla_{\theta}\sum_{s’} P(s’|s,a)(r+V(s’))}] \\
& = \sum_a [\nabla_{\theta}\pi_{\theta}(a|s)Q^\pi(s,a)+\pi_{\theta}(a|s)\sum_{s’} P(s’|s,a)\color{blue}{\nabla_{\theta}V(s’)}] \\
&= \sum_a [\nabla_{\theta}\pi(a|s)Q_\pi(s,a) \\
&+\pi(a|s)\sum_{s’}P(s’|s,a)\color{blue}{\sum_{a’} [\nabla_{\theta}\pi_{\theta} (a’|s’)Q^\pi(s’,a’) }\color{blue}{+\pi_{\theta} (a’|s’)\sum_{s”}P(s”|s’,a’)\nabla_{\theta} V^\pi(s”)]}] \\
&= \sum_a \nabla_{\theta}\pi(a|s)Q_\pi(s,a) \\
&\color{blue}{{+\sum_a}\pi_{\theta}(a|s)\sum_{s’}Q^\pi(s’|s,a)\sum_{a’}\nabla_{\theta}\pi_{\theta}(a’|s’)Q^\pi(s’,a’)}\\
&\color{blue}{+ \sum_a\pi_{\theta}(a|s)\sum_{s’}P(s’|s,a)\sum_{a’}\pi_{\theta}(a’|s’)\sum_{s”}P(s”|s’,a’)\nabla_{\theta} V^\pi(s”)}
\end{aligned}
ここで、新しい遷移関数\(Pr(s\rightarrow x,k,\pi_{\theta}) \)というを定義します。これは方策\(\pi\)を使って状態\(s\)から状態\(x\)に\(k\)ステップで遷移する確率を表現した関数です。
\begin{aligned}
Pr(s\rightarrow x,0,\pi_{\theta}) &=\begin{cases}1 & (x=s)\\0 & (otherwise)\end{cases} \\
Pr(s\rightarrow x,1,\pi_\theta) &= \sum_a \pi_{\theta} (a|s)p(x | s,a) \\
Pr(s\rightarrow x,2,\pi_{\theta}) &=\sum_a\pi_{\theta}(a|s)\sum_{s’}P(s’|s,a)\sum_{a’}\pi_{\theta}(a’|s’)P(x|s’,a’)
\end{aligned}
最初の式は\(s\)に居る場合1、いない場合0となる関数です。2つめは1ステップで\(s\)から\(x\)に遷移する関数になります。
この遷移関数を使って先程の式をさらに展開していきます。
\begin{aligned}
&= \color{blue}{\sum_a \nabla_{\theta}\pi(a|s)Q_\pi(s,a)} \\
&\color{blue}{{+\sum_a}\pi_{\theta}(a|s)\sum_{s’}Q^\pi(s’|s,a)}\sum_{a’}\nabla_{\theta}\pi_{\theta}(a’|s’)Q^\pi(s’,a’)\\
&\color{blue}{+ \sum_a\pi_{\theta}(a|s)\sum_{s’}P(s’|s,a)}\sum_{a’}\pi_{\theta}(a’|s’)\sum_{s”}P(s”|s’,a’)\nabla_{\theta} V^\pi(s”) \\
&= \color{blue}{\sum_{s\in \mathcal{S}}Pr(s\rightarrow s,0,\pi_{\theta})}\sum_a \nabla_{\theta}\pi_{\theta}(a|s)Q^\pi(s,a) \\
&+\color{blue}{\sum_{s’\in \mathcal{S}}Pr(s\rightarrow s’,1,\pi_{\theta})}\sum_{a’}\nabla_{\theta}\pi_{\theta}(a’|s’)Q^\pi(s’,a’) \\
&+\color{blue}{\sum_{s”\in \mathcal{S}}Pr(s\rightarrow s”,2,\pi_{\theta})}\nabla_{\theta} V^\pi(s”) \\
&= \cdots (unrolling) \\
&= \sum_{x\in \mathcal{S}}\sum_{k=0}^\infty Pr(s\rightarrow x,k,\pi_{\theta})\sum_a \nabla_{\theta} \pi_{\theta}(a|x)Q^\pi(x,a)
\end{aligned}
上の式展開で補足しておくと、
\begin{aligned}
※\sum_a\pi(a|s)\sum_{x \in S}Q^{\pi}(x|s,a) &= \sum_{x \in S} \sum_a \pi(a|s)Q^{\pi}(x|s,a) \\
&= \sum_{x \in S} Pr(s\rightarrow x,k,\pi_{\theta})
\end{aligned}
ではここまでで求めた\( \nabla_{\theta}V^{\pi}(s_0)\)を使って\(\nabla_{\theta}J(\theta) \)を求めていきます。スタートの状態\(s_0\)から展開していきます。ここ以降の展開は参考文献2を参考にしていただければと思います。
\begin{aligned}
\nabla_{\theta}J(\theta) & = \nabla_{\theta}V^{\pi}(s_0) \\
&= \sum_{x\in \mathcal{S}}\color{blue}{\sum_{k=0}^\infty Pr(s\rightarrow x,k,\pi_{\theta})} \sum_a \nabla_{\theta} \pi_{\theta}(a|x)Q^\pi(x,a) \\
&= \sum_s \eta (s) \sum_a \nabla_{\theta} \pi_{\theta}(a|x)Q^\pi(x,a) \\
&= \color{blue}{(\sum_s \eta (s))} \sum_s \frac{\eta(s)}{\color{blue}{\sum_s \eta (s)}}\sum_a \nabla_{\theta}\pi_{\theta}(a|x)Q^\pi(x,a)\\
& \propto \sum_s \color{blue}{\frac{\eta(s)}{\sum_s \eta (s)}} \sum_a \nabla_{\theta}\pi_{\theta}(a|x)Q^\pi(x,a) \\
&= \sum_s \color{blue}{d^{\pi}(s)} \sum_a \nabla_{\theta}\pi_{\theta}(a|x)Q^\pi(x,a) \\
\end{aligned}
さらに展開していきます。
\begin{aligned}
\nabla_\theta J(\theta)
&\propto \sum_{s \in \mathcal{S}} d^\pi(s) \sum_{a \in \mathcal{A}} Q^\pi(s, a) \nabla_\theta \pi_\theta(a \vert s) \\
&= \sum_{s \in \mathcal{S}} d^\pi(s) \sum_{a \in \mathcal{A}} \pi_\theta(a \vert s) Q^\pi(s, a) \frac{\nabla_\theta \pi_\theta(a \vert s)}{\pi_\theta(a \vert s)} \\
& ※\nabla_{\theta} log\pi_\theta = \frac{1}{\pi_\theta} \nabla_{\theta}\pi_\theta \\
&= \sum_{s \in \mathcal{S}} d^\pi(s) \sum_{a \in \mathcal{A}} \pi_\theta(a \vert s) Q^\pi(s, a) \nabla_{\theta} \ln \pi_\theta(a \vert s) \\
&= \mathbb{E}_{\pi} [Q^\pi(s, a) \nabla_{\theta} \ln \pi_\theta(a \vert s)]
\end{aligned}
以上です。お疲れ様でした。
参考文献
- https://stats.stackexchange.com/questions/325152/how-is-one-of-the-steps-in-the-policy-gradient-theorem-done
- https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html#reinforce