【入門】GAN等の評価指標として使われるFIDについての紹介

今回は、GAN等で生成した画像の評価によく使われるFID(Fréchet Inception Distance)について紹介します。

FIDとは

FIDは生成された画像の分布と本物の画像の分布の類似性を比較できる指標の一つです。
二つの分布間の距離を計算します。

具体的にな手順としては、

  1. 事前学習済みモデル(Inceptionモデル)に、GAN等で生成した画像群Xと本物の画像群Yを入力してそれぞれの特徴量を得ます。
  2. 得られた特徴量群の分布は正規分布に従っていると仮定して、平均と共分散行列を計算します。
  3. それぞれの分布の平均と共分散行列を用いてフレシェ距離(Fréchet Distance)を計算します。下のような式になります。


実装と実験

試しに実験してみようと思います。
https://github.com/tocom242242/notebooks/blob/master/metrics/FID.ipynb

今回は、TensorFlowとCIFAR 10を使います。CIFAR10のラベル毎のデータ使ってFIDで比較してみます。

まず、学習済みのInceptionモデルを読み込みます。

import tensorflow as tf
import cv2
import numpy as np
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input

# 特徴量抽出用のInceptionモデルを読み込む
model = InceptionV3(weights='imagenet', include_top=False, pooling="avg")

CIFAR-10のデータセットを読み込みます。今回はtestデータだけを使います。

# CIFAR-10データセットを読み込む
(_, _), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

# データの正規化
test_images = test_images.astype('float32') / 255.0

データをクラスごとに分けます。

# クラスごとにデータを分ける
class_images = [[] for _ in range(10)]
for image, label in zip(test_images, test_labels):
    class_images[label[0]].append(image)

class_images = [np.array(images) for images in class_images]

Inceptionモデルに入力できるようにresizeする関数を用意しておきます。

# Inceptionモデルに入力できるサイズにリサイズする
target_size = (224, 224)
def resize(imgs):
    resized_imgs = []
    for img in imgs:
        resized = cv2.resize(img, target_size)
        resized_imgs.append(resized)
    return np.array(resized_imgs)

FIDを計算する関数を用意します。

# fidを計算する
def calc_fid(model, imgs1, imgs2):
    # 特徴量の抽出
    f1 = model.predict(preprocess_input(imgs1))
    f2 = model.predict(preprocess_input(imgs2))
    # 平均を求める
    f1_mean = np.mean(f1,axis=0)
    f2_mean = np.mean(f2,axis=0)
    # 平均の差を求める
    diff = f1_mean - f2_mean
    # 共分散行列を求める
    f1_sigma = np.cov(f1, rowvar=False)
    f2_sigma = np.cov(f2, rowvar=False)
    # 共分散行列の積を取り平方根を計算する
    sqrt_cov_dotted = sqrtm(f1_sigma.dot(f2_sigma))
    # 虚数が含まれる場合には実数のみ用いる
    if np.iscomplexobj(sqrt_cov_dotted):
        sqrt_cov_dotted = sqrt_cov_dotted.real
    fid = np.sum(diff**2.0) + np.trace(f1_sigma+f2_sigma - 2.0*sqrt_cov_dotted)
    return fid

メモリに余裕がないので、データセットから一部のデータを取り出して実験します。(本来であれば、このような少数のデータセットで比較することはありませんので、ご注意ください)

まず、同一クラスのFIDを求めてみます。非常に小さなFIDになることがわかります。

# 同一クラスのfid
resized_imgs1 = resize(class_images[0][:100])
resized_imgs2 = resize(class_images[0][:100])

fid = calc_fid(model, resized_imgs1, resized_imgs2)
print("fid1:",fid)
#=> fid1: -4.099820064954259e-07

次に別々のクラスのデータのFIDを計算します。上のスコアに比べてFIDが高くなることがわかります。

# 別クラスのfid
# 別クラスのfid
resized_imgs1 = resize(class_images[0][:100])
resized_imgs2 = resize(class_images[1][:100])

fid = calc_fid(model, resized_imgs1, resized_imgs2)
print("fid2:",fid)
# => fid2: 0.5701747729023202

参考文献

https://wandb.ai/wandb_fc/japanese/reports/-FID-GAN—Vmlldzo0MzY2ODY

コメント

タイトルとURLをコピーしました