【PyTorch】リスト形式でモデルを保持できるModuleList

ModuleListを使うことで、サブモジュールをリスト形式で保持でできます。

公式ドキュメントを少し変更して使ってみます。

環境

  • python: 3.8.16
  • torch: 1.13.0+cu116

コード

import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        for l in self.linears:
            x = l(x)
        # 配列なので、以下でも同様に動く
        # for i, l in enumerate(self.linears):
        #     x = self.linears[i](x)
    return x

net = MyModule()
print(net)

以下は出力になります。
MyModuleにサブモジュールがfor文で回した分だけ保持されていることがわかります。

MyModule(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): Linear(in_features=10, out_features=10, bias=True)
    (6): Linear(in_features=10, out_features=10, bias=True)
    (7): Linear(in_features=10, out_features=10, bias=True)
    (8): Linear(in_features=10, out_features=10, bias=True)
    (9): Linear(in_features=10, out_features=10, bias=True)
  )
)

入力してみます。

x = torch.randn(10)
print(net(x))
#=> tensor([ 0.3028, -0.2671,  0.2963,  0.3608,  0.3891, -0.3487, -0.0378,  0.2440,
#         0.3610, -0.1338], grad_fn=<AddBackward0>)

参考文献

コメント

タイトルとURLをコピーしました