【TensorFlow 2】training=Falseにして推論すると精度がかなり落ちた時の原因メモ【BatchNormalization】

tensorflow

CNNを作って遊んでいた時に、modelのtraining=Falseして実行すると精度がめちゃくちゃ落ちる現象が起きて少し悩んだのでメモしておきます。

結論としては、BatchNormalization layerが原因である可能性が高そう。

BatchNormalization layerのtraining引数

BatchNormalization は学習時のBatch毎に平均・分散を計算して、それを使ってデータを正規化することで学習を効率化します。

training=True時にはシンプルにBatch毎に平均・分散を計算しています。

一方training=Falseの時には、trainingで使った平均・分散の移動平均を使って推論をします。
なので、training=Falseとした時は入力が学習データと同じような場合(batchサイズ等も含めて)には良い精度を出すかも知れませんが、それ以外の時には極端に精度が落ちるかもしれません。

自分の場合は、テスト時はテストデータセットでのbatch統計(平均・分散)を使って欲しかったので、そういう時にはtraining=Trueというふうにするべきだったようです。

pip2pixのチュートリアルにも同様のことが記述してありました。

Note: The training=True is intentional here since you want the batch statistics, while running the model on the test dataset. If you use training=False, you get the accumulated statistics learned from the training dataset (which you don't want).

簡単な訳

Note:testing=Trueは、テストデータセットでモデルを実行しているときにバッチ統計が必要なため、意図的にここで使っています。 training=Falseを使用すると、トレーニングデータセットから学習した累積統計を使用してしまうので(ここでは不要)

対策

なので、Batchnormalization Layerが入っているモデルで推論する時には、不自然ですが、training=Trueとするのがベターかもしれません。

参考文献

タイトルとURLをコピーしました