While experimenting with a CNN, I ran into a problem: accuracy dropped when I ran the model with training=False.
In short, the BatchNormalization layer appears to be the cause of this problem.
The training argument of the BatchNormalization layer
During training, BatchNormalization computes the mean and variance of each batch and uses them to normalize the data, which helps learning.
When training=True, the BatchNormalization layer normalizes each batch with that batch's own mean and variance.
When training=False, on the other hand, it normalizes with the moving averages of the mean and variance accumulated during training.
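To make the two modes concrete, here is a minimal pure-Python sketch of the normalization step for a single feature. It is an illustration under assumptions, not TensorFlow's implementation: the class name `SimpleBatchNorm` is hypothetical, it omits the learnable scale and offset (gamma/beta), and it borrows the Keras defaults momentum=0.99 and epsilon=0.001, with the moving mean starting at 0 and the moving variance at 1.

```python
import math


class SimpleBatchNorm:
    """Hypothetical 1-D batch normalization sketch (no gamma/beta)."""

    def __init__(self, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        self.moving_mean = 0.0  # assumed start value, like Keras' zeros initializer
        self.moving_var = 1.0   # assumed start value, like Keras' ones initializer

    def __call__(self, batch, training):
        if training:
            # training=True: use the statistics of the current batch...
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ...and update the moving averages as a side effect
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # training=False: use the moving averages accumulated during training
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / math.sqrt(var + self.epsilon) for x in batch]


bn = SimpleBatchNorm()
train_batch = [4.0, 5.0, 6.0]  # toy "training data": mean 5.0, variance 2/3
for _ in range(1000):
    bn(train_batch, training=True)

print(round(bn.moving_mean, 3), round(bn.moving_var, 3))  # prints: 5.0 0.667
```

After enough training steps the moving averages approach the training data's statistics, so training=False behaves well only for inputs that resemble that data.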
So with training=False, accuracy may be good when the test inputs have a distribution similar to the training data, but otherwise it can drop sharply.
In my case, I wanted the normalization at test time to use the batch statistics (mean and variance) of the test dataset, so I should have set training=True when testing.
The pix2pix tutorial contains the following note:
Note: The training=True is intentional here since you want the batch statistics, while running the model on the test dataset. If you use training=False, you get the accumulated statistics learned from the training dataset (which you don’t want).
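The same effect can be reproduced with the real layer. The sketch below assumes TensorFlow 2.x and uses a single `tf.keras.layers.BatchNormalization` layer as a hypothetical stand-in for a full CNN; the toy batches are made up for illustration. It trains the moving statistics on inputs centered at 5.0, then compares training=False and training=True on a test batch centered at 15.0.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a CNN: one BatchNormalization layer with default
# settings (momentum=0.99, epsilon=0.001), so its output is the normalized input.
bn = tf.keras.layers.BatchNormalization()

# Toy "training": training=True normalizes with batch statistics and
# updates moving_mean / moving_variance toward this data (mean 5.0).
train_batch = np.array([[4.0], [5.0], [6.0]], dtype=np.float32)
for _ in range(1000):
    bn(train_batch, training=True)

# Toy test batch whose distribution differs from the training data (mean 15.0).
test_batch = np.array([[14.0], [15.0], [16.0]], dtype=np.float32)

# training=False: normalizes with the moving averages learned in training,
# so the shifted batch comes out far from zero mean.
out_inference = bn(test_batch, training=False).numpy()

# training=True (the pix2pix-style call): normalizes with the test batch's
# own mean / variance, so the output is centered around zero.
out_batch_stats = bn(test_batch, training=True).numpy()

print(out_inference.ravel())
print(out_batch_stats.ravel())
```

Note that calling the layer with training=True also updates its moving averages as a side effect, which is worth keeping in mind when using it at test time.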
Conclusion
It feels unnatural at inference time, but with a model that contains a BatchNormalization layer, it may be better to set training=True when you want the test data's own batch statistics to be used.
References
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization?version=stable
- https://www.tensorflow.org/tutorials/generative/pix2pix#restore_the_latest_checkpoint_and_test_the_network
- https://stackoverflow.com/questions/58728086/passing-training-true-when-using-tensorflow-2s-keras-functional-api