【PyTorch】STL-10 データセット

機械学習でよく使われるSTL-10データセットを紹介します。

STL-10データセットはスタンフォード大学が公開しているデータセットで、以下のような特徴があります。

  • ラベルありとなしのデータがある
    • ラベルあり
      • 10クラス(飛行機、鳥、車、猫、熊、犬、馬、猿、船、トラック)のデータ
      • それぞれのクラスにつき、500枚の学習用画像、800枚のテスト用画像が格納
    • ラベルなし
      • 100000のラベルなしの画像データ
  • 解像度は96×96

PyTorchでSTL-10データセットを使ってみる

PyTorchにSTL-10が内包されているので、使ってみて、とりあえず表示してみます。

import torch
import torchvision
import matplotlib.pyplot as plt

ds = torchvision.datasets.STL10

# splitに{‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’} が設定できる
trainset = ds(root='data', split="train", download=True)
testset = ds(root='data', split="test", download=True)

def plot_ds(dataset, row=10, col=1, figsize=(10,40)):
    fig_img, ax_img = plt.subplots(row, col, figsize=figsize, tight_layout=True)
    plt.figure()
    for i in range(row):
        img, label = dataset[i]
        ax_img[i].imshow(img)
        ax_img[i].set_title(label)

    fig_img.savefig("data_sample.png", dpi=100)
    plt.close()

plot_ds(trainset)

画像とタイトルにラベルを出力しています。

参考文献

コメント

タイトルとURLをコピーしました