深層学習における畳み込みのFLOPs(Floating Point Operations Per Secode)について紹介します。
まず、FLOPsとは、浮動小数点数の乗算・加算演算の数になります。つまり、何回演算を行うかになります。
畳み込みのFLOPsを計算ですが、イメージ図を使って説明します。
まず、単純に1度畳み込みの処理を行う際のFLOPsは以下のようになります。
カーネルのそれぞれの要素毎の乗算(青)と乗算した要素を加算する処理(緑)の総和になります。
上の操作が入力画像のチャネル数と出力サイズと出力のチャネル数分繰り返されるので、畳み込みのFLOPsは以下のようになります。
Pythonコードにすると以下のようになります。
W_in = 100
H_in = 100
C_in = 1
C_out = 1
K_width = 3
K_height = 3
S = 1
W_out = (W_in - K_width) // S + 1
H_out = (H_in - K_height) // S + 1
print("W_out:", W_out, "H_out:", H_out)
FLOPs_per_operation = (K_width * K_height + (K_width * K_height - 1)) * C_in
Total_FLOPs = FLOPs_per_operation * W_out * H_out * C_out
print("Total_FLOPs:",Total_FLOPs)
コメント