【入門,異常検知】fashion_mnistデータセットを使って単純にOneClassSVMで分類してみて失敗した【Python】

以前、簡単な2クラスのデータセットを使ってOneClassSVMを使って異常検知をやってみました。

今回はもう少し問題を難しくして、fashion_mnistデータセットのサンダルと靴のデータを使って異常検知してみます。

データセットと正常値・異常値

今回使うFashion MNISTはMNISTっていう数字のデータセットを靴とか衣服とかにしたデータセットです。
なので、MNISTと同様に0〜9までのクラスがあるデータセットです

fashion mnistの例


詳細はこの記事は参考にしてください。

今回はこのFashin MNISTのサンダルを正常値、ブーツを異常値として扱います。

ちなみにサンダルとブーツは以下のような感じです。

実装していく

今回は、単純にOneClassSVMを使って、異常検知をしていきます。

必要なモジュールのインポート

まずは必要なモジュール群をimportしていきます。

from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import datasets, layers, models
from sklearn.svm import OneClassSVM
from sklearn.metrics import roc_curve

学習・評価用データの作成

正常データ(学習用)、正常・異常データ(評価用)のデータセットを作成します。

まずは、fashion mnistデータセットを取得します。

(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()

画像データは正規化しておきます。

x_train = x_train / 255
x_test = x_test / 255

正常データはサンダル(左)、異常データ(右)はブーツとし、データを作成します。

normal_idx = 5 # サンダル
abnormal_idx = 7 # ブーツ

x_normal_train = x_train[np.where(y_train==normal_idx,True,False)]
x_normal_test = x_test[np.where(y_test==normal_idx,True,False)]

x_abnormal_train = x_train[np.where(y_train==abnormal_idx,True,False)]
x_abnormal_test = x_test[np.where(y_test==abnormal_idx,True,False)]

評価用の入力データを作成します。

x_test = np.concatenate((x_normal_test,x_abnormal_test))

今回は、各ピクセルの値を1次元にして使うので、以下のようにreshapeします。

x_test = x_test.reshape(x_test.shape[0], 28*28)
x_normal_train = x_normal_train.reshape(x_normal_train.shape[0], 28*28)

評価用にラベルデータを作成します。

# normalが1,異常が0
y_test = np.concatenate((np.ones(x_normal_test.shape[0]),np.zeros(x_abnormal_test.shape[0])))

OneClassSVMの学習と評価

では、さっそくOneClassSVMに学習させていきたいと思います。


from sklearn.svm import OneClassSVM

svm = OneClassSVM()
svm.fit(x_normal_train)

コードとしてはこれだけです。

では、評価していきます。

ROCカーブとAUCを見ていきます。

y_score = svm.decision_function(x_test)
from sklearn.metrics import roc_auc_score
auc = roc_auc_score(y_test,y_score)

from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_test, y_score)

plt.plot(fpr, tpr, label='baseline(AUC = %.2f)'%auc)
plt.plot([0,1],[0,1],'k--')
plt.legend()
plt.title('ROC curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.grid(True)
plt.show()

予想通りですが、あまりうまく学習出来ていませんね。


はやり今回の画像で似たような画像であると純粋な画像のデータをそのまま特徴量として扱うのでは中々うまくいかないようです。

終わりに

今回はFashion Mnistを使った単純なOne Class SVMを使った異常検知をやってみました。

結果としては純粋な画像データを一次元にしただけでは、うまく異常検知できないことがわかりました。

次回としては画像データを次元圧縮するなどの工夫を加えていこうと思います。
以下、次の記事

タイトルとURLをコピーしました