【PyTorch】データセットを拡張する


CIFAR10のデータセットを自分なりに改造して新しいデータセットを作ってみます。

先にコードを載せておきます。
(CIFAR10のデータをランダムで回転させるようにしています。)

import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

class MyCIFAR10(torch.utils.data.Dataset):
    def __init__(self):
         self.dataset = torchvision.datasets.CIFAR10(root="data", train="train", download=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        (img, target) = self.dataset[idx]
        # ランダムで回転させる
        rot_num = np.random.choice([Image.ROTATE_90,Image.ROTATE_180,Image.ROTATE_270])
        img = img.transpose(rot_num)
        return img, target

ds = MyCIFAR10()
img, target = ds[0]
plt.imshow(img)

やっていることは単純で、新しいデータセットのクラスを用意して、
その中でCIFAR10のデータセットを用意して、getitemで取り出して、
何かしらの操作をしているだけです。

コメント

タイトルとURLをコピーしました