今回は、PyTorchを使って深層学習モデルがどこを注視して判断したかを可視化してみます。公開されているライブラリを使うだけですので、簡単にできます。
ライブラリについて
pytorch-grad-camというライブラリを使います。
可視化する手法がまとまっているのライブラリです。
簡単にインストールして使うことができます。
pip install grad-cam
実験(実装)
前回同様にPyTorchに分類問題を解かせた後に、そのモデルがどこを見て判断したかを見ていきます。
データセットはCIFAR10を使います。
コードはgithubにあげてあります。
https://github.com/tocom242242/pytorch-cifar10-simple-example/blob/main/gradcam_cifar10.ipynb
前準備:モデルの学習まで
まずは、必要なモジュール群をインポートします。
import torch
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import models
import torchvision.transforms as transforms
import numpy as np
from torchvision.transforms import InterpolationMode
import matplotlib.pyplot as plt
from tqdm import tqdm
BICUBIC = InterpolationMode.BICUBIC
データローダーを取得します。
def get_loaders(batch_size):
ds = torchvision.datasets.CIFAR10
transform = transforms.Compose([
transforms.Resize(32, interpolation=BICUBIC),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
trainset = ds(root='data', train=True, download=True, transform=transform)
indices = torch.arange(5000)
trainset = Subset(trainset, indices)
n_samples = len(trainset)
train_size = int(len(trainset) * 0.9)
val_size = n_samples - train_size
trainset, valset = torch.utils.data.random_split(trainset, [train_size, val_size])
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2,
drop_last=False)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=True, num_workers=2,
drop_last=False)
testset = ds(root='data', train=False, download=True, transform=transform)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2,
drop_last=False)
return train_loader, val_loader, test_loader
def plot_ds(dataset, row=10, col=1, figsize=(20,10)):
fig_img, ax_img = plt.subplots(row, col, figsize=figsize, tight_layout=True)
plt.figure()
for i in range(row):
img1,_ = dataset[i]
img1 = denormalization(img1)
img1 = np.squeeze(img1)
ax_img[i].imshow(img1)
fig_img.savefig("data_sample.png", dpi=100)
plt.close()
def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
tensor.mul_(std).add_(mean)
return tensor
def denormalization(x):
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
x = inverse_normalize(x, mean, std)
x = x.cpu().detach().numpy()
x = (x.transpose(1, 2, 0) * 255.0).astype(np.uint8)
return x
train_loader, val_loader, test_loader = get_loaders(batch_size=32)
plot_ds(train_loader.dataset)
前回同様にCIFAR10を分類するモデルを作成します。
class AverageMeter(object):
def __init__(self, name, fmt=":f"):
self.name = name
self.fmt = fmt
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)
class ClassifierModel(torch.nn.Module):
def __init__(self):
super(ClassifierModel, self).__init__()
# 畳み込み層
self.backbone = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=3, out_channels=128,
kernel_size=3, stride=2, padding=1),
torch.nn.ReLU(True),
torch.nn.BatchNorm2d(128),
torch.nn.Conv2d(in_channels=128, out_channels=32,
kernel_size=3, stride=2, padding=1),
torch.nn.ReLU(True),
torch.nn.BatchNorm2d(32),
torch.nn.Conv2d(in_channels=32, out_channels=16,
kernel_size=3, stride=2, padding=1),
torch.nn.BatchNorm2d(16),
torch.nn.Flatten()
)
# 全結合層
self.mlp_layers = torch.nn.Sequential(
torch.nn.Linear(256, 50),
torch.nn.ReLU(True),
torch.nn.Linear(50, 10),
torch.nn.Softmax(dim=1),
)
def forward(self, x):
x = self.backbone(x)
y = self.mlp_layers(x)
return y
では、学習させます。
# 学習など
model = ClassifierModel()
epochs = 10
train_loss = AverageMeter("train_loss")
train_acc = AverageMeter("train_acc")
val_loss = AverageMeter("val_loss")
val_acc = AverageMeter("val_acc")
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=30, gamma=0.1
)
min_loss = np.inf
for epoch in tqdm(range(epochs)):
model.train()
for x, y in train_loader:
optimizer.zero_grad()
outputs = model(x)
loss = criterion(outputs.squeeze(), y)
_, predicted = torch.max(outputs.data, 1)
accuracy = (predicted==y).sum().item()/y.size(0)
train_loss.update(loss.data)
train_acc.update(accuracy)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.eval()
for x, y in val_loader:
outputs = model(x)
loss = criterion(outputs.squeeze(), y)
val_loss.update(loss.data)
_, predicted = torch.max(outputs.data, 1)
accuracy = (predicted==y).sum().item()/y.size(0)
val_acc.update(accuracy)
if val_loss.avg < min_loss:
print("save model")
torch.save(model.state_dict(), "model.pth")
min_loss = val_loss.avg
print(
"[epoch :{:.1f} train_loss: {} val_loss: {} train_acc: {} val_acc: {}] ".format(
epoch, train_loss.avg, val_loss.avg,
train_acc.avg, val_acc.avg
)
)
scheduler.step()
train_loss.reset()
val_loss.reset()
model.load_state_dict(torch.load("model.pth"))
可視化
ようやく可視化していきます。まずは使うモジュール群をimportします。
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam import GradCAM
データセット読み込んで、モデルに入力するようのtransformと単純に表示するtransformを用意します。
また、今回は学習データの60番目のデータを可視化してみます。(車です)
model.eval()
ds = torchvision.datasets.CIFAR10
input_transform = transforms.Compose([
transforms.Resize(32, interpolation=BICUBIC),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img_transform = transforms.Compose([
transforms.Resize(32, interpolation=BICUBIC),
transforms.ToTensor(),
])
trainset = ds(root='data', train=True, download=True)
img, label = trainset[60]
input_img = input_transform(img)
img = img_transform(img)
どこの層を使うかと決める必要があります。今回は畳み込みネットワークの出力をつかいます。
target_layers = [model.backbone[-3]]
cam = GradCAM(
model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available()
)
入力画像をGradCamのインスタンスに入力し、heatmapを取得し結果を出力します。
grayscale_cam = cam(
input_tensor=input_img.unsqueeze(0),
targets=[ClassifierOutputTarget(label)],
)
# 最初の出力だけ取得
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(img.permute(1, 2, 0).numpy(), grayscale_cam, use_rgb=True)
fig, ax = plt.subplots(1,2)
ax[0].imshow(img.permute(1, 2, 0).numpy())
ax[1].imshow(visualization)
ここまでを実行すると以下のように出力されます。

車のタイヤ部分に主に反応しているようです。
ついでにカエルも可視化してみます。

終わりに
今回はGrad-Camを使って判断根拠の可視化をしてみました。まぁまぁうまく可視化できていると思います。ただ、たまに変なところに反応していることもあるので、もしかしたら何かが間違っているかもしれません。
コメント