機械学習でよく使われるSTL-10データセットを紹介します。
STL-10データセットはスタンフォード大学が公開しているデータセットで、以下のような特徴があります。
- ラベルありとなしのデータがある
- ラベルあり
- 10クラス(飛行機、鳥、車、猫、熊、犬、馬、猿、船、トラック)のデータ
- それぞれのクラスにつき、500枚の学習用画像、800枚のテスト用画像が格納
- ラベルなし
- 100000のラベルなしの画像データ
- ラベルあり
- 解像度は96×96
PyTorchでSTL-10データセットを使ってみる
PyTorchにSTL-10が内包されているので、使ってみて、とりあえず表示してみます。
import torch
import torchvision
import matplotlib.pyplot as plt
ds = torchvision.datasets.STL10
# splitに{‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’} が設定できる
trainset = ds(root='data', split="train", download=True)
testset = ds(root='data', split="test", download=True)
def plot_ds(dataset, row=10, col=1, figsize=(10,40)):
fig_img, ax_img = plt.subplots(row, col, figsize=figsize, tight_layout=True)
plt.figure()
for i in range(row):
img, label = dataset[i]
ax_img[i].imshow(img)
ax_img[i].set_title(label)
fig_img.savefig("data_sample.png", dpi=100)
plt.close()
plot_ds(trainset)
画像とタイトルにラベルを出力しています。
コメント