[TensorFlow 2] Why accuracy drops sharply when inferring with training=False [BatchNormalization]


While experimenting with a CNN, I found that its accuracy dropped considerably when I ran the model with training=False.

To state the conclusion first: the BatchNormalization layer appears to be the cause of this problem.

The training argument of the BatchNormalization layer

During training, BatchNormalization computes the mean and variance of each batch and uses them to normalize the data, which helps learning.

With training=True, the layer normalizes using the mean and variance of the current batch (and updates its moving averages of those statistics).

With training=False, it instead normalizes using the moving averages of the mean and variance accumulated during training.
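To make the two modes concrete, here is a minimal NumPy sketch of the normalization step. It is not TensorFlow's implementation: the class name SimpleBatchNorm is hypothetical, the learned scale/offset (gamma/beta) is omitted, and momentum=0.99 / epsilon=1e-3 are assumed to mirror the Keras defaults.

```python
import numpy as np


class SimpleBatchNorm:
    """Hypothetical sketch of BatchNormalization's two modes (no gamma/beta)."""

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        # Assumed to mirror the Keras defaults (momentum=0.99, epsilon=1e-3).
        self.momentum = momentum
        self.epsilon = epsilon
        self.moving_mean = np.zeros(num_features)  # starts at 0
        self.moving_var = np.ones(num_features)    # starts at 1

    def __call__(self, x, training):
        if training:
            # training=True: normalize with the statistics of this batch...
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # ...and fold them into the moving averages used for inference.
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            # training=False: ignore the batch and use the accumulated averages.
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.epsilon)


# Example: a single training step on data whose true mean is 5.
rng = np.random.default_rng(0)
bn = SimpleBatchNorm(num_features=3)
batch = rng.normal(loc=5.0, scale=2.0, size=(64, 3))
out_train = bn(batch, training=True)   # per-feature mean ~0, std ~1
out_infer = bn(batch, training=False)  # far from normalized: moving stats barely moved
print(out_train.mean(axis=0).round(3), out_infer.mean(axis=0).round(3))
```

After one step the moving averages are still close to their initial values (0 and 1), so the same batch comes out very differently in the two modes. With momentum=0.99 they need on the order of hundreds of updates to settle, so a briefly trained model is especially prone to this gap.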

So with training=False, accuracy may be good if the test inputs resemble the training data (and the moving averages were estimated well during training, which depends on things like batch size and number of steps), but otherwise it can drop sharply.
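The same effect can be reproduced with the real Keras layer. This sketch assumes TF 2.x in eager mode; the "training" data N(0, 1), the shifted "test" data N(3, 1), and the 500 update steps are arbitrary choices for illustration.

```python
import tensorflow as tf

tf.random.set_seed(0)
bn = tf.keras.layers.BatchNormalization()  # Keras defaults: momentum=0.99, epsilon=1e-3

# Simulated training: 500 batches from N(0, 1). In eager mode, each
# training=True call normalizes with batch stats and updates the moving averages.
for _ in range(500):
    bn(tf.random.normal((32, 4), mean=0.0, stddev=1.0), training=True)

# Simulated test set from a shifted distribution, N(3, 1).
test_batch = tf.random.normal((32, 4), mean=3.0, stddev=1.0)

# training=False: normalizes with the moving averages learned on N(0, 1),
# so the shift passes straight through (output mean stays near 3).
out_moving = bn(test_batch, training=False).numpy()

# training=True: normalizes with test_batch's own statistics (output mean near 0).
# Note that this call also nudges the moving averages toward the test statistics.
out_batch = bn(test_batch, training=True).numpy()

print("training=False mean:", out_moving.mean())
print("training=True  mean:", out_batch.mean())
```

Layers after BatchNormalization were trained on roughly normalized activations, so inputs shifted like out_moving are one plausible way the accuracy collapses.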

In my case, I wanted to use the batch statistics (mean and variance) of the test dataset at test time, so I should have set training=True when running the test.

The TensorFlow pix2pix tutorial also says the following:

Note: The training=True is intentional here since you want the batch statistics, while running the model on the test dataset. If you use training=False, you get the accumulated statistics learned from the training dataset (which you don’t want).


It feels unnatural when running inference with a model that contains a BatchNormalization layer, but in cases like this it may be better to set training=True.
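If you do this, note that training=True changes more than BatchNormalization: Dropout layers also become active, and the BatchNormalization moving averages are updated by the call. Below is a hedged sketch of a hypothetical helper, predict_with_batch_stats (not part of TensorFlow or the tutorial), that calls the model with training=True like the pix2pix tutorial does, then restores the moving statistics so repeated evaluation leaves the model unchanged. It assumes TF 2.x eager execution; the small Sequential model is made up for the example.

```python
import tensorflow as tf


def _batchnorm_layers(layer):
    """Collect BatchNormalization layers, descending into nested models."""
    found = []
    for sub in getattr(layer, "layers", []):
        if isinstance(sub, tf.keras.layers.BatchNormalization):
            found.append(sub)
        else:
            found.extend(_batchnorm_layers(sub))
    return found


def predict_with_batch_stats(model, inputs):
    """Hypothetical helper: run `model` with training=True, as in the pix2pix tutorial.

    BatchNormalization then normalizes with the statistics of `inputs`.
    Side effects of training=True:
      * the moving averages are updated by the call, so they are saved and restored;
      * Dropout layers are active too, so outputs can vary from call to call.
    """
    bn_layers = _batchnorm_layers(model)
    saved = [(layer.moving_mean.numpy().copy(), layer.moving_variance.numpy().copy())
             for layer in bn_layers]
    outputs = model(inputs, training=True)
    for layer, (mean, var) in zip(bn_layers, saved):
        layer.moving_mean.assign(mean)
        layer.moving_variance.assign(var)
    return outputs


# Usage with a small, made-up model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1),
])
bn = _batchnorm_layers(model)[0]
before = bn.moving_mean.numpy().copy()
preds = predict_with_batch_stats(model, tf.random.normal((16, 4), mean=3.0))
```

Whether restoring the moving statistics is desirable depends on your setup; if you never use training=False afterwards, the updates may not matter.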