はじめに
工業製品の異常検知の自動化は人材不足の観点からも人件費の観点からも非常に重要です。今回は比較的シンプルな深層学習モデルであるオートエンコーダを使った異常検知手法を用いて、異常検知がどの程度できるか試してみます。
前提知識
この記事では
- Python
- 深層学習
- Keras
を少しは知っていることを前提としています。
(知らなくてもコピペすればプログラムは動きます)
オートエンコーダとは
オートエンコーダは深層学習モデルの一つで、入力データと同じデータを出力するように学習するモデルになります。
オートエンコーダはエンコーダ、デコーダから構成されていて、エンコーダで圧縮して、その圧縮したものをデコーダで再構成します。
オートエンコーダを使った異常検知
このオートエンコーダを使って異常検知を行っていくのですが、どのように使うのかをお話していきます。
まずは、正常データだけを使ってオートエンコーダを学習します。そうすると、正常データ「は」正しく再構成できるようになります。
一方で異常データが入力された際にはうまく入力された異常データを再構成できず、入力画像と出力画像の差異が大きくなることが期待できます。この差異によって正常データか異常かを判定していきます。
例えば、以下のように「1」という画像でオートエンコーダに学習させたとします。
そうするとこのオートエンコーダは「1」の画像は再構成できるようになります。
一方で、推論時に異常画像、今回は「0」の画像を入力すると、このオートエンコーダは「0」を再構成することを学習していないので、再構成に失敗します(するはずです)。以下のイラストように、「0」を入力すると「0」は学習していないので、再構成できず「1」を出力してしまいます。
ここで、入力と出力の差分をとることで異常を検知させます。
今回はMNISTデータセットという人が書いた数字の0から9までのデータセットを使って実験してみます。上の図のように1を正常データ、0を異常データとします。
つまり、1のデータだけで学習して、0が入力されたら、異常を判定させていきます。
実装
では、PythonとTensorFlowを使って実装していきます。Colabを使うので、Webブラウザだけで実験することが可能です。
深層学習モデルはTensorFlow(Keras)を使って書いていきます。
ソースコードは以下にあげてあります。
main_mnist.ipynb
まずは、必要なモジュールをimportします。
import os
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt
import numpy as np
import copy
import tensorflow as tf
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from tensorflow.keras import datasets, layers, losses, optimizers
from tensorflow.keras.models import Model
MNISTデータセットを読み込みます。MNISTはTensorFlowに組み込まれているので、以下のコードを実行するだけです。
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train / 255
x_test = x_test / 255
x_train、x_test
は学習用と評価用の画像データになります。下のように、深層学習モデルの学習用データとそのモデルを評価するためのデータに分けます。
今回はMNISTのデータの中でも使う画像は1と0だけなので、それらを取り出します。
正常データと異常データを以下のように指定しておきます。
normal_idx = 1
abnormal_idx = 0
x_trainの正常データ1だけを学習データに使います。
x_normal_train = x_train[np.where(y_train==normal_idx)]
評価では正常データと異常データの2つを使うので、2つを取り出してまとめます。
x_normal_test = x_test[np.where(y_test==normal_idx)]
x_abnormal_test = x_test[np.where(y_test==abnormal_idx)]
正常データ(1)と異常データ(0)を見てみると以下のようになっています。
まずは正常データ、
次に異常データ
次に先程の図のようなオートエンコーダ(深層学習モデル)を構築していきます。今回は畳み込み層も含めたオートエンコーダーを作ります。オートエンコーダーなので、エンコーダー、デコーダーから成ります。
latent_dim = 32
# エンコーダ
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Conv2D(128, 3, strides=1, padding="same",activation="relu")(inputs)
x = layers.Conv2D(64, 3, strides=1, padding="same", activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(1024, activation="relu")(x)
x = layers.Dense(64, activation="relu")(x)
x = layers.Dense(latent_dim)(x)
# デコーダ
x = layers.Dense(64, activation="relu")(x)
x = layers.Dense(1024, activation="relu")(x)
x = layers.Dense(7 * 7 * 64, activation="relu")(x)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(128, 3, activation="relu", strides=2, padding="same")(x)
outputs = layers.Conv2DTranspose(1, 3, strides=1,activation="sigmoid", padding="same")(x)
# モデルを作る
model = Model(inputs, outputs)
print(model.summary())
model.compile(optimizer=optimizers.Adam(), loss="binary_crossentropy")
今回は画像を入力とするので、畳み込み層と全結合層を使います。エンコーダではまず畳み込み層に入力してその後に全結合層に入力します。デコーダでは全結合層に入れた後に2Dにするために、Conv2DTransposeに入れます。このようにすることで、出力が画像のような2Dになります。
モデルを構築したので、学習していきます。今回は正常データだけ(つまり1だけ)を学習していきます。
モデルの入力は4次元(バッチサイズ、縦、横、チャネル)が必要で、MNISTにはチャネルがないので、その次元を追加します。MNISTはグレースケールですので、チャネルは1になります。また、学習用とvalidationデータに分けます。
x_normal_train = np.expand_dims(x_normal_train, axis=-1)
x_train, x_val = train_test_split(x_normal_train,train_size=0.9)
次に学習していきます。今回はvalidationデータでの誤差が一番小さいモデルを保存してその重みを評価に使うので、以下のようにvalidationのデータの評価をepoch毎に行うcallbackを用意して、学習(fit)を行います。学習が終わったら、保存しておいたvalidationデータに対して最適な重みをモデルに読み込ませます。
MODEL_DIR = "/content/drive/MyDrive/"
checkpoint_filepath = os.path.join(MODEL_DIR,'checkpoint')
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_loss',
mode='min',
verbose=1,
save_best_only=True)
history = model.fit(x_train,x_train, epochs=30, shuffle=True, validation_data=(x_val, x_val), callbacks=[model_checkpoint_callback])
plt.plot(np.arange(len(history.history["loss"])),history.history["loss"], label="loss")
plt.plot(np.arange(len(history.history["val_loss"])),history.history["val_loss"], label="val_loss")
plt.legend()
model.load_weights(checkpoint_filepath)
学習が終わりましたので、正常データと異常データを入力してみて評価してみます。
まずは入力とモデルの出力をプロットできるように関数を定義しておきます。
def plot_input_output(input_imgs):
plt.figure(figsize=(10,10))
plt_idx = 1
for i in range(36):
if plt_idx > 36:
break
plt.subplot(6,6,plt_idx)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(np.squeeze(input_imgs[i]), cmap="gray", vmin=0, vmax=1)
plt.xlabel("input_{}".format(i))
plt.subplot(6,6,plt_idx+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
output = model.predict(tf.expand_dims(input_imgs[i],axis=0))
plt.imshow(np.squeeze(output), cmap="gray", vmin=0, vmax=1)
plt.xlabel("output_{}".format(i))
plt_idx+=2
plt.show()
それでは最初に正常データを入力してみます。
plot_input_output(x_normal_test)
上の画像は左が入力、右が出力です。ある程度きれいに入力データを再構成できています。
では、異常データはどうでしょうか?
plot_input_output(x_abnormal_test)
0を入力しているのに、再構成されず、1を出力したり、ノイズのようなデータを出力していることがわかります。異常データは学習していないので、再構成できなかったと考えられます。
あとは、この入力と出力の差分を「異常度」とすれば異常検知器が完成します。
今回は異常検知器は作りませんが、すべての評価データの正常データと異常データの異常度のヒストグラムをプロットしてみます。
import seaborn as sns
sns.set(style='white', context='notebook', palette='deep')
def get_errors(input_imgs):
output_img = model.predict(input_imgs)
output_imgs = np.squeeze(output_img)
print(output_imgs.shape)
print(input_imgs.shape)
sub_imgs = np.abs(input_imgs-output_imgs)
errors = np.sum(sub_imgs, axis=(1,2))
return errors
x_normal_errors = get_errors(x_normal_test)
x_abnormal_errors = get_errors(x_abnormal_test)
ax=sns.distplot(x_normal_errors,bins=20, label="normal")
sns.distplot(x_abnormal_errors,ax=ax,bins=20, label="abnormal")
ax.set_xlabel("error")
plt.legend()
多くの正常データが異常度0付近に集まっていて、
異常データは異常度が高いところにあることがわかります。
非常にうまく異常データを検出できていることがわかります。
おわりに
今回はオートエンコーダーを用いたシンプルな異常検知をしてみました。結果としては、MNISTの0,1ぐらいであれば、うまく異常検知できることがわかります。
しかし、やってみればわかると思いますが、実際の工業製品の細かい異常の検知などに使うのは難しいです。というのも、シンプルなオートエンコーダーでは解像度が高い画像の再構成が難しいので、正常データでも入力と出力の差分が大きくなることが想定されるからです。
次回以降はもう少しリアルな実問題に近い設定のデータセットで試していこうと思います。
コメント