学習済みのViTを使ってCIFAR10の画像を可視化してみる

学習済みのViTを使って、CIfar10の画像の特徴量を抽出して、特徴量をプロットしてみました。
コードは以下にもあげてあります。
https://github.com/tocom242242/notebooks/blob/master/pytorch/vit_feature_extractor.ipynb

実装

Colab上で実行していきます。
今回はtimmというライブラリにある学習済みViTを使うのでインストールします。

!pip install timm

学習済みモデルと、transform等を読みこみます

import timm
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
model.eval()
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

CIFAR10のデータローダーを用意する

import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)


test_loader = DataLoader(
    testset, batch_size=64, shuffle=False, **kwargs)

TSNEで2次元まで圧縮して可視化する関数

from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.manifold import TSNE
import matplotlib.pyplot  as plt
import numpy as np
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


def plot_scatter(x_test,y_test):
    scatter_x = x_test[:, 0]
    scatter_y = x_test[:, 1]
    

    fig, ax = plt.subplots()
    for g in np.unique(y_test):
        ix = np.where(y_test == g)
        ax.scatter(scatter_x[ix], scatter_y[ix],  label = g, s = 10)
    plt.legend()


def plot_features(model, device, data_loader):
    model.to(device)
    feature_space = []
    labels = []
    with torch.no_grad():
        for (imgs, label) in tqdm(data_loader, desc='feature extracting'):
            imgs = imgs.to(device)
            features = model(imgs)
            feature_space.append(features)
            labels.append(label)
        feature_space = torch.cat(feature_space, dim=0).contiguous().cpu().numpy()
        labels = torch.cat(labels, dim=0).cpu().numpy()
    X_reduced_test = TSNE(n_components=2, random_state=0).fit_transform(feature_space)
    plot_scatter(X_reduced_test,labels)

プロットする。

plot_features(model, device, test_loader)

ここまで実行すると以下のような画像がプロットされます。

参考文献

コメント

タイトルとURLをコピーしました