【入門】PyTorchでモデルの中間層を取得する

事前学習済みモデルを使って、中間層を出力してみます。

環境

  • pytorch: 1.12.1+cu113
  • torchvision: 0.13.1+cu113

実装

https://github.com/tocom242242/notebooks/blob/master/pytorch/feature_extractor_example.ipynb

まずは必要なモジュールをimportします。

import torch
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor

resnet18を読み込みます。

model = models.resnet18()

今回はこのresnet18のlayer1とlayer3の特徴量を取得してみます。
この中間層の特徴量を取得するために、torchvisionのcreate_feature_exractorを使います。

extractor = create_feature_extractor(model, ["layer1", "layer3"])

create_feature_extractorには対象のmodelと取得したいlayerを配列でいれます。

あとは、freature_extractorにデータを入力すれば、指定した中間層の出力を取得できます。

x = torch.randn(1, 3, 224, 224)
features = extractor(x)
print(features["layer1"].shape)
print(features["layer3"].shape)
torch.Size([1, 64, 56, 56])
torch.Size([1, 256, 14, 14])

参考文献

コメント

タイトルとURLをコピーしました