学習済みのViTを使って、CIfar10の画像の特徴量を抽出して、特徴量をプロットしてみました。
コードは以下にもあげてあります。
https://github.com/tocom242242/notebooks/blob/master/pytorch/vit_feature_extractor.ipynb
実装
Colab上で実行していきます。
今回はtimmというライブラリにある学習済みViTを使うのでインストールします。
!pip install timm
学習済みモデルと、transform等を読みこみます
import timm
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
model.eval()
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
CIFAR10のデータローダーを用意する
import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(
testset, batch_size=64, shuffle=False, **kwargs)
TSNEで2次元まで圧縮して可視化する関数
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
def plot_scatter(x_test,y_test):
scatter_x = x_test[:, 0]
scatter_y = x_test[:, 1]
fig, ax = plt.subplots()
for g in np.unique(y_test):
ix = np.where(y_test == g)
ax.scatter(scatter_x[ix], scatter_y[ix], label = g, s = 10)
plt.legend()
def plot_features(model, device, data_loader):
model.to(device)
feature_space = []
labels = []
with torch.no_grad():
for (imgs, label) in tqdm(data_loader, desc='feature extracting'):
imgs = imgs.to(device)
features = model(imgs)
feature_space.append(features)
labels.append(label)
feature_space = torch.cat(feature_space, dim=0).contiguous().cpu().numpy()
labels = torch.cat(labels, dim=0).cpu().numpy()
X_reduced_test = TSNE(n_components=2, random_state=0).fit_transform(feature_space)
plot_scatter(X_reduced_test,labels)
プロットする。
plot_features(model, device, test_loader)
ここまで実行すると以下のような画像がプロットされます。
コメント