【ディープラーニング】kerasで多クラス分類 〜irisデータセットを用いて〜

前回はKerasを使って2値分類をしました。

【超入門、ニューラルネットワーク】Kerasで2値分類 〜超簡単な例〜
今回はKerasを使ってニューラルネットワークを実装して、2値分類をやってみます。 超初心者向けの内容になります。 KerasとはKerasは、ニューラルネットワークを非常にシンプルに構築できるライブラリです。TensorF...

今回はirisデータをセットを使って、多クラス分類をしてみました。
本記事では、pythonの基本操作はできることを想定しています。

IT・Web・ゲーム業界のエンジニア転職なら【Tech Stars Agent】

Kerasとは

Kerasは、ニューラルネットワークを非常にシンプルに構築できるライブラリです。

Home - Keras Documentation
Documentation for Keras, the Python Deep Learning library.

TensorFlow等で書くとかなり長くなってしまうコードがKerasを使うことでシンプルなコードとなります。

ちなみにKerasは下でTensorFlow等が動いています。

多クラス分類とは

多クラス分類とは、データの性質によって複数の対象(クラス)に分類することいいます。

例えば、ある果物が与えられたらその果物をりんご、バナナ、みかん等に分類することです。

問題設定

今回はiris(あやめ)の分類問題を扱います。
ここでは、あやめのデータを取得し、そのデータをモデル(ニューラルネットワーク)に学習させます。
そして、入力情報からどのあやめかを出力させます。

irisデータセット

irisデータセットとは、
あやめの花に関するデータセットになります。
3種類のあやめのデータがそれぞれ50個づつ格納されています。

あやめの情報として、

  • がく片の長さ(sepal length)
  • がく片の幅(sepal width)
  • 花びらの長さ(patal length)
  • 花びらの幅(petal width)

が保存されています。

データセットを見てみましょう。


実装

では、実装していきます。

まず必要なモジュールをimportします。

irisデータセットの読み込みやデータの加工などを行います。

ラベルデータをone-hotエンコードをします。

学習用データとテスト用データに分けます。

モデルを生成します。
今回は非常にシンプルなモデルを用います。

モデルをコンパイルします。
ここで、誤差関数や最適化手法を設定します。

モデルに学習させていきます

学習過程のプロットしてみます。

accが正答率、lossが誤差(binary_crossentropy)を表しています。
epoch(学習回数)が進むに連れて、正答率が向上して、誤差が小さくなっていることがわかります。
うまく学習が進んでいるようです。

モデルをテストデータで評価してみます。

Test scoreが誤差、Test accuracyは正答率を表しています。
正答率が93.3%ですのでけっこう良くできていることがわかります。

いくつかデータを予測してみます。

以下が出力になります。

correct answerが正解データ、predict answerがモデルからの予想値になります。
正しく予想できていることがわかります。

ソースコードの全体像

最後にソースコードの全体像を以下に示します。

参考文献


コメント

タイトルとURLをコピーしました