今さらですが、畳込みオートエンコーダーを使って、画像を入力したら同じ画像を出力できるようにしてみます。
ソースコードはgithubにあげておきます。
https://github.com/tocom242242/aifx_blog_codes/blob/master/nn_tf2/conv/cnn_autoencoder.ipynb
実装
まず必要なライブラリをimport します。
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.models import Model
import copy
from IPython import display
データセットの読み込みとデータの成形。データはmnistを使います。
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train / 255
x_test = x_test / 255
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
リサイズなどをします。(モデルの都合上です。cifar10ように最初に作ってしまったので)
def resize(input_image, height, width):
input_image = tf.image.resize(input_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return input_image
x_train = tf.image.resize(x_train[...,tf.newaxis], [32, 32])
x_test = tf.image.resize(x_test[...,tf.newaxis], [32, 32])
データの確認
def plot_imgs(imgs,shuffle=False):
plt.figure(figsize=(10,10))
plot_imgs = copy.deepcopy(imgs)
if shuffle:
np.random.shuffle(plot_imgs)
for i in range(36):
plt.subplot(6,6,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(np.squeeze(imgs[i]), cmap="gray", vmin=0, vmax=1)
plt.show()
plot_imgs(x_train)
モデルの作成を行います。ここはkeras のtutorialと同じです。
input_shape = x_train.shape[1:]
output_channel = x_train.shape[-1]
inputs = tf.keras.layers.Input(shape=input_shape)
x = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
x = layers.MaxPooling2D((2, 2), padding='same')(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = layers.MaxPooling2D((2, 2), padding='same')(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = layers.MaxPooling2D((2, 2), padding='same')(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(16, (3, 3), activation='relu', padding="same")(x)
x = layers.UpSampling2D((2, 2))(x)
decoded = layers.Conv2D(output_channel, (3, 3), activation='sigmoid', padding='same')(x)
model = tf.keras.Model(inputs, decoded)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.summary()
学習させます。学習途中に結果を見たいのでコールバック関数を用意しておきます。
n = 10
input_num = 100
def plot_rec(input_images):
plt.figure(figsize=(20, 4))
decoded_imgs = model(input_images[:input_num], training=True)
plt_index = np.random.randint(0,input_num,size=n)
decoded_imgs = tf.squeeze(decoded_imgs)
input_images = np.squeeze(input_images)
for i in range(n):
ax = plt.subplot(2, n, i + 1)
idx = plt_index[i]
plt.imshow(input_images[idx], cmap="gray", vmin=0, vmax=1)
plt.title("input")
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2, n, i + 1 + n)
plt.imshow(decoded_imgs[idx], cmap="gray", vmin=0, vmax=1)
plt.title("output")
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
class CustomCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
if epoch % 10 == 0:
display.clear_output(wait=True)
plot_rec(x_train)
model.fit(x_train,x_train,epochs=100, batch_size=20, shuffle=True, validation_data=(x_test,x_test),callbacks=[CustomCallback()])
その他の結果
mnistの他にもfashion mnist, cifar-10についてもやってみました。
fashion mnist
cifar10は今回のモデルでepoch数100ぐらいだとあまりうまくはいきませんでした。モデルを変えたりepoch数を増やせばうまく動くかも知れませんが、今回はここまで
cifar10に関してはUNetでもやってみました