機械学習(特に強化学習)が好きな人のノート

機械学習関連のことをまとめていきます。強化学習関連が多いかもしれません

【深層学習、keras】mnistの数字の回帰モデルをkerasで作る

kerasでmnistの数字の画像を入力したら、その数字を出力するような回帰モデルを
作ってみます。 0と書いてある画像を入力したら、0を出力し、
9と書いてある画像を入力したら、9と出力するようにモデルを作成します。
以下はイメージ図

f:id:ttt242242:20190331135232j:plain

ちなみにCNNは使いません!

データセット

有名なmnistを用います。
手書き文字の認識用データセットです。
0〜9までの手書きの数字の画像が、学習用、テスト用でそれぞれ60000枚、10000枚用意されています。
特別ダウンロードする必要はありません。
下のプログラムを実行すれば、自動的にダウンロードされます。

プログラム

import keras
from keras.datasets import mnist
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Activation

# mnistのデータの取得
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 二次元配列から一次元に変換
x_train = np.array(x_train).reshape(len(x_train), 784)  
x_test = np.array(x_test).reshape(len(x_test), 784)
x_train = np.array(x_train).astype("float32")
x_test = np.array(x_test).astype("float32")
# 0〜1に正規化
x_train /= 255
x_test /= 255

y_train = np.array(y_train)
y_test = np.array(y_test)

# モデルの構築
model = Sequential()
model.add(Dense(256, input_shape=(784,)))
model.add(Activation('relu'))

model.add(Dense(32))
model.add(Activation('relu'))

model.add(Dense(1))
model.add(Activation('linear'))

# 誤差関数は平均二乗誤差、最適化手法はrmsprop
model.compile(loss="mean_squared_error", optimizer="rmsprop")

# 学習
history = model.fit(x_train, y_train,
                    batch_size=32, nb_epoch=100,
                    verbose=1, validation_split=0.2)

# 20個程度のテストデータを使って予測結果の確認
print("正解:予測")
for x, y in zip(x_test[0:20], y_test[0:20]):
    predicted_y = model.predict(np.array([x]))[0][0]
    print("{}:{}".format(y,predicted_y))

結果

プログラムにある通り、いくつかtestデータを取り出して、結果を簡易的に見てみます。

正解:予測
7:7.071030616760254
2:1.9824368953704834
1:1.0134353637695312
0:0.13662612438201904
4:4.098000526428223
1:0.9860554933547974
4:4.018962383270264
9:7.225198745727539
5:6.004607200622559
9:9.107109069824219
0:0.09038352966308594
6:6.015869140625
9:9.079833984375
0:0.18159937858581543
1:0.9922631978988647
5:4.972866058349609
9:9.171490669250488
7:7.072229385375977
3:2.5052080154418945
4:3.9980337619781494

今回は過学習とかは特に考慮していませんが、ある程度うまくいってそうです。
あと、正解データも正規化したほうが良かったのですが、
今回はしませんでした。

終わりに

さすがにkerasを使うとシンプルに記述できます。
kerasでもう少しいろいろやってみたいです。