【Tensorflow2 】UNetで入力画像と同じ画像を出力させる【メモ】

前回は畳込みオートエンコーダーを使って入力画像と出力画像が同じになるように学習しましたが、今回はUNetを使います。

UNet自体の説明は他にたくさんわかりやすい記事があるので、とりあえず実装していきます。

ソースコードはgithubにあげておきます。

https://github.com/tocom242242/aifx_blog_codes/blob/master/nn_tf2/conv/unet.ipynb

実装

まず必要なライブラリをimport します。

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.models import Model
import copy
from IPython import display

データセットの読み込みとデータの成形。データはcifar10を使います。

(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

x_train = x_train / 255
x_test = x_test / 255
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")

データの確認

def plot_imgs(imgs,shuffle=False):
    plt.figure(figsize=(10,10))
    plot_imgs = copy.deepcopy(imgs)
    if shuffle:
        np.random.shuffle(plot_imgs)
    for i in range(36):
        plt.subplot(6,6,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(np.squeeze(imgs[i]), cmap="gray", vmin=0, vmax=1)
    plt.show()
plot_imgs(x_train)

モデルの作成を行います。

def downsample(filters, size, apply_batchnorm=True, strides=2):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=strides, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  result.add(tf.keras.layers.LeakyReLU())
  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())


  return result

def upsample(filters, size, apply_batchnorm=True, strides=2):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=strides,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.ReLU())
  result.add(tf.keras.layers.BatchNormalization())

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  return result


input_shape = x_train.shape[1:]
output_channel = x_train.shape[-1]
size = 3 # カーネルサイズ

inputs = tf.keras.layers.Input(shape=input_shape)

def build_model():
  inputs = tf.keras.layers.Input(shape=input_shape)

  down1 = downsample(32, size, strides=1,apply_batchnorm=False)(inputs)
  down2 = downsample(64, size,strides=2,apply_batchnorm=True)(down1)
  down3 = downsample(128, size,strides=1,apply_batchnorm=False)(down2)

  up1 =upsample(64, size,  strides=1,apply_batchnorm=True)(down3)
  con1 = tf.keras.layers.Concatenate()([up1, down2])
  up2 = upsample(32, size, strides=2,apply_batchnorm=False)(con1)
  con1 = tf.keras.layers.Concatenate()([up2, down1])

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(3, size,
                                         strides=1,
                                         padding='same',
                                         kernel_initializer=initializer,
                                        #  activation='tanh')  # (batch_size, 256, 256, 3)
                                         activation='sigmoid')  # (batch_size, 256, 256, 3)

  x = last(con1)
  return tf.keras.Model(inputs=inputs, outputs=x)

model = build_model()
tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.summary()

学習させます。学習途中に結果を見たいのでコールバック関数を用意しておきます。

n = 10
input_num = 100
def plot_rec(input_images):
    plt.figure(figsize=(20, 4))
    decoded_imgs = model(input_images[:input_num], training=True)
    plt_index = np.random.randint(0,input_num,size=n)
    decoded_imgs = tf.squeeze(decoded_imgs)
    input_images = np.squeeze(input_images)

    for i in range(n):
        ax = plt.subplot(2, n, i + 1)
        idx = plt_index[i]
        plt.imshow(input_images[idx], cmap="gray", vmin=0, vmax=1)
        plt.title("input")
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(decoded_imgs[idx], cmap="gray", vmin=0, vmax=1)
        plt.title("output")
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 10 == 0:
            display.clear_output(wait=True)
            plot_rec(x_train)

model.fit(x_train,x_train,epochs=100, batch_size=120, shuffle=True, validation_data=(x_test,x_test),callbacks=[CustomCallback()])

終わりに

単純な畳込みオートエンコーダーよりかなり早い段階でcifar10の再構成に成功しているので、びっくりします。

他にも様々なところで使えると思うので使っていこうと思います。

参考文献

タイトルとURLをコピーしました