【tensorflow,keras】入力画像と同じ画像を出力する畳込みオートエンコーダーを作る【メモ】

今さらですが、畳込みオートエンコーダーを使って、画像を入力したら同じ画像を出力できるようにしてみます。

ソースコードはgithubにあげておきます。

https://github.com/tocom242242/aifx_blog_codes/blob/master/nn_tf2/conv/cnn_autoencoder.ipynb

実装

まず必要なライブラリをimport します。

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.models import Model
import copy
from IPython import display

データセットの読み込みとデータの成形。データはmnistを使います。

(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

x_train = x_train / 255
x_test = x_test / 255
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")

リサイズなどをします。(モデルの都合上です。cifar10ように最初に作ってしまったので)

def resize(input_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  return input_image

x_train = tf.image.resize(x_train[...,tf.newaxis], [32, 32])
x_test = tf.image.resize(x_test[...,tf.newaxis], [32, 32])

データの確認

def plot_imgs(imgs,shuffle=False):
    plt.figure(figsize=(10,10))
    plot_imgs = copy.deepcopy(imgs)
    if shuffle:
        np.random.shuffle(plot_imgs)
    for i in range(36):
        plt.subplot(6,6,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(np.squeeze(imgs[i]), cmap="gray", vmin=0, vmax=1)
    plt.show()
plot_imgs(x_train)

モデルの作成を行います。ここはkeras のtutorialと同じです。

input_shape = x_train.shape[1:]
output_channel = x_train.shape[-1]

inputs = tf.keras.layers.Input(shape=input_shape)

x = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
x = layers.MaxPooling2D((2, 2), padding='same')(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = layers.MaxPooling2D((2, 2), padding='same')(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = layers.MaxPooling2D((2, 2), padding='same')(x)

x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = layers.UpSampling2D((2, 2))(x)
x = layers.Conv2D(16, (3, 3), activation='relu', padding="same")(x)
x = layers.UpSampling2D((2, 2))(x)
decoded = layers.Conv2D(output_channel, (3, 3), activation='sigmoid', padding='same')(x)

model = tf.keras.Model(inputs, decoded)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.summary()

学習させます。学習途中に結果を見たいのでコールバック関数を用意しておきます。


n = 10
input_num = 100
def plot_rec(input_images):
    plt.figure(figsize=(20, 4))
    decoded_imgs = model(input_images[:input_num], training=True)
    plt_index = np.random.randint(0,input_num,size=n)
    decoded_imgs = tf.squeeze(decoded_imgs)
    input_images = np.squeeze(input_images)

    for i in range(n):
        ax = plt.subplot(2, n, i + 1)
        idx = plt_index[i]
        plt.imshow(input_images[idx], cmap="gray", vmin=0, vmax=1)
        plt.title("input")
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(decoded_imgs[idx], cmap="gray", vmin=0, vmax=1)
        plt.title("output")
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if epoch % 10 == 0:
            display.clear_output(wait=True)
            plot_rec(x_train)

model.fit(x_train,x_train,epochs=100, batch_size=20, shuffle=True, validation_data=(x_test,x_test),callbacks=[CustomCallback()])

その他の結果

mnistの他にもfashion mnist, cifar-10についてもやってみました。

fashion mnist

cifar10は今回のモデルでepoch数100ぐらいだとあまりうまくはいきませんでした。モデルを変えたりepoch数を増やせばうまく動くかも知れませんが、今回はここまで

cifar10に関してはUNetでもやってみました

https://www.tcom242242.net/entry/ai-2/deeplearning/%e3%80%90tensorflow2-%e3%80%91unet%e3%81%a7%e5%85%a5%e5%8a%9b%e7%94%bb%e5%83%8f%e3%81%a8%e5%90%8c%e3%81%98%e7%94%bb%e5%83%8f%e3%82%92%e5%87%ba%e5%8a%9b%e3%81%95%e3%81%9b%e3%82%8b%e3%80%90%e3%83%a1/

参考文献

タイトルとURLをコピーしました