【入門】畳み込みのFLOPsについて

深層学習における畳み込みのFLOPs(Floating Point Operations Per Secode)について紹介します。

まず、FLOPsとは、浮動小数点数の乗算・加算演算の数になります。つまり、何回演算を行うかになります。

畳み込みのFLOPsを計算ですが、イメージ図を使って説明します。

まず、単純に1度畳み込みの処理を行う際のFLOPsは以下のようになります。

上の畳み込みの図の出典:https://stats.stackexchange.com/questions/280179/why-is-resnet-faster-than-vgg

カーネルのそれぞれの要素毎の乗算(青)と乗算した要素を加算する処理(緑)の総和になります。

上の操作が入力画像のチャネル数出力サイズ出力のチャネル数分繰り返されるので、畳み込みのFLOPsは以下のようになります。

Pythonコードにすると以下のようになります。

W_in = 100
H_in = 100
C_in = 1
C_out = 1
K_width = 3
K_height = 3
S = 1

W_out = (W_in - K_width) // S + 1
H_out = (H_in - K_height) // S + 1

print("W_out:", W_out, "H_out:", H_out)

FLOPs_per_operation = (K_width * K_height + (K_width * K_height - 1)) * C_in

Total_FLOPs = FLOPs_per_operation * W_out * H_out * C_out
print("Total_FLOPs:",Total_FLOPs)

コメント

タイトルとURLをコピーしました