【入門,PyTorch】ディープラーニングで多クラス分類をしてみる【PyTorch,Cifar10】

PyTorch を使って多クラス分類を実装してみようと思います。

やってみること

今回はCifar10というデータセットを使って、他クラス分類の実装をしてみます。

Cifa10は10クラス分のラベル付きの画像データが格納されているデータセットです。
以下のような飛行機、自動車、猫などの10クラス分のデータが格納されています。

今回は畳み込み層と全結合層のニューラルネットワークを作成(以下の図)し、Cifar10の画像データのクラスを予測できるプログラムを作成します。(チュートリアルにたくさんありそうですが笑)

model,モデル
作成するモデルのイメージ図

実装

では、実装していきます。
全体のコードはGithubで公開しています。

https://github.com/tocom242242/pytorch-cifar10-simple-example/blob/main/pytorch_cifar10.ipynb

まずは必要なモジュールをimportします。

import torch
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torchvision.transforms import InterpolationMode
import matplotlib.pyplot as plt
from tqdm import tqdm

# 補間手法
BICUBIC = InterpolationMode.BICUBIC

データローダーを取得する関数を書いていきます。

def get_loaders(batch_size):
    ds = torchvision.datasets.CIFAR10
    transform = transforms.Compose([
        transforms.Resize(32, interpolation=BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


    trainset = ds(root='data', train=True, download=True, transform=transform)
    indices = torch.arange(10000) 
    trainset = Subset(trainset, indices)

    n_samples = len(trainset)
    train_size = int(len(trainset) * 0.9)
    val_size = n_samples - train_size
    trainset, valset = torch.utils.data.random_split(trainset, [train_size, val_size])


    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2,
                                                drop_last=False)
    val_loader = DataLoader(valset, batch_size=batch_size, shuffle=True, num_workers=2,
                                                drop_last=False)

    testset = ds(root='data', train=False, download=True, transform=transform)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2,
                                                drop_last=False)
    return train_loader, val_loader, test_loader

DataLoaderは指定したバッチ数でデータセットからデータをロードしてくれます。
基本的にはCifar10は公式のDatasetがあるので、それを利用します。
今回はこちらの計算資源の都合上学習データは10000枚程度にしています。(実際には10000枚のうち1000枚はバリデーションに用いるので、9000枚程度しか学習に使いません)

cifar10のデータをプロットしてみます。(これは任意)

def plot_ds(dataset, row=10, col=1, figsize=(20,10)):
    fig_img, ax_img = plt.subplots(row, col, figsize=figsize, tight_layout=True)
    plt.figure()
    for i in range(row):
        img1,_ = dataset[i]
        img1 = denormalization(img1)
        img1 = np.squeeze(img1)
        ax_img[i].imshow(img1)
        
    fig_img.savefig("data_sample.png", dpi=100)
    plt.close()

def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.mul_(std).add_(mean)
    return tensor

def denormalization(x):
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    x = inverse_normalize(x, mean, std)
    x = x.cpu().detach().numpy()

    x = (x.transpose(1, 2, 0) * 255.0).astype(np.uint8)

    return x

train_loader, val_loader, test_loader = get_loaders(batch_size=32)
plot_ds(train_loader.dataset)

上を実行すると以下のようにプロットされると思います。

cifar10_2

次にモデルを作成します。

class ClassifierModel(torch.nn.Module):
    def __init__(self):
        super(ClassifierModel, self).__init__()

        # 畳み込み層
        self.cnn_layers = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=128,
                      kernel_size=3, stride=2, padding=1),
            torch.nn.ReLU(True),
            torch.nn.BatchNorm2d(128),
            torch.nn.Conv2d(in_channels=128, out_channels=32,
                      kernel_size=3, stride=2, padding=1),
            torch.nn.ReLU(True),
            torch.nn.BatchNorm2d(32),
            torch.nn.Conv2d(in_channels=32, out_channels=16,
                      kernel_size=3, stride=2, padding=1),
            torch.nn.BatchNorm2d(16),
            torch.nn.Flatten()
        )
        # 全結合層
        self.mlp_layers = torch.nn.Sequential(
            torch.nn.Linear(256, 50),
            torch.nn.ReLU(True),
            torch.nn.Linear(50, 10),
            torch.nn.Softmax(dim=1),
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        y = self.mlp_layers(x)
        return y

最初に紹介したように畳み込み層と全結合層から構成されています。
forward関数でデータの処理の流れを書いています。基本的には畳み込み層に通した後に全結合層に通しているだけです。
最後にSoftmax層に通すことで確率分布にしています。

次にlossや正答率を計算するためのAverageMeterを書きます。これはpytorchの公式のexampleそのままです。

class AverageMeter(object):
    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

モデルのインスタンスの作成などを行います。

epochs = 200
model = ClassifierModel()
train_loss = AverageMeter("train_loss")
train_acc = AverageMeter("train_acc")
val_loss = AverageMeter("val_loss")
val_acc = AverageMeter("val_acc")
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=30, gamma=0.1
)

min_loss = np.inf

学習するコードを書いていきます。
各epochで学習とバリデーションデータでの評価を行っています。

for epoch in tqdm(range(epochs)):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs.squeeze(), y)
        _, predicted = torch.max(outputs.data, 1)
        accuracy = (predicted==y).sum().item()/y.size(0)
        train_loss.update(loss.data)
        train_acc.update(accuracy)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    for x, y in val_loader:
        outputs = model(x)
        loss = criterion(outputs.squeeze(), y)
        val_loss.update(loss.data)

        _, predicted = torch.max(outputs.data, 1)
        accuracy = (predicted==y).sum().item()/y.size(0)
        val_acc.update(accuracy)

    if val_loss.avg < min_loss:
        torch.save(model.state_dict(), "model.pth")
        min_loss = val_loss.avg

    print(
        "[epoch :{:.1f} train_loss: {} val_loss: {} train_acc: {} val_acc: {}] ".format(
            epoch, train_loss.avg, val_loss.avg,
            train_acc.avg, val_acc.avg
        )
    )
    scheduler.step()
    train_loss.reset()
    val_loss.reset()

では、上のコードを実行し学習させていきます。
徐々に誤差が小さくなっていき、正答率が向上していることがわかります。

..... 省略

[epoch :95.0 train_loss: 1.629211187362671 val_loss: 1.9153571128845215 train_acc: 0.7793041550186368 val_acc: 0.5395710613019169] 
 97%|█████████▋| 97/100 [24:32<00:45, 15.12s/it][epoch :96.0 train_loss: 1.6286406517028809 val_loss: 1.914373755455017 train_acc: 0.780023097065314 val_acc: 0.5396643720562564] 
 98%|█████████▊| 98/100 [24:47<00:30, 15.10s/it][epoch :97.0 train_loss: 1.6299864053726196 val_loss: 1.9141188859939575 train_acc: 0.7807253292690878 val_acc: 0.53975985362196] 
 99%|█████████▉| 99/100 [25:03<00:15, 15.10s/it][epoch :98.0 train_loss: 1.6289007663726807 val_loss: 1.9147344827651978 train_acc: 0.7814244683254268 val_acc: 0.5398312195436796] 
100%|██████████| 100/100 [25:18<00:00, 15.18s/it][epoch :99.0 train_loss: 1.6294656991958618 val_loss: 1.9138569831848145 train_acc: 0.782098642172524 val_acc: 0.5399291134185303] 

学習データに対しては70%の正答率ですが、バリデーションデータに対しては50%程度になっています。

最後にテストデータでテストしてみます。

model.eval()
accuracy = 0
num_total = 0
for x, y in test_loader:
    outputs = model(x)
    _, predicted = torch.max(outputs.data, 1)
    accuracy += (predicted==y).sum().item()
    num_total += y.size(0)

accracy = accuracy/num_total
print(f"accuracy:{accracy}")

上を実行するとテストデータでの正答率が算出されます。

accuracy:0.552

バリデーションデータに対しての正答率と同等です。ベースラインとしてはまぁまぁだと思います。

おわりに

今回はPyTorchを用いて他クラス分類を実装してみました。
他クラス分類の入門としては良いのかなと思います。
今後はさらに精度を上げていくためにいろいろやっていきたいです。

コメント

タイトルとURLをコピーしました