ModuleListを使うことで、サブモジュールをリスト形式で保持でできます。
公式ドキュメントを少し変更して使ってみます。
環境
- python: 3.8.16
- torch: 1.13.0+cu116
コード
import torch
from torch import nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
for l in self.linears:
x = l(x)
# 配列なので、以下でも同様に動く
# for i, l in enumerate(self.linears):
# x = self.linears[i](x)
return x
net = MyModule()
print(net)
以下は出力になります。
MyModuleにサブモジュールがfor文で回した分だけ保持されていることがわかります。
MyModule(
(linears): ModuleList(
(0): Linear(in_features=10, out_features=10, bias=True)
(1): Linear(in_features=10, out_features=10, bias=True)
(2): Linear(in_features=10, out_features=10, bias=True)
(3): Linear(in_features=10, out_features=10, bias=True)
(4): Linear(in_features=10, out_features=10, bias=True)
(5): Linear(in_features=10, out_features=10, bias=True)
(6): Linear(in_features=10, out_features=10, bias=True)
(7): Linear(in_features=10, out_features=10, bias=True)
(8): Linear(in_features=10, out_features=10, bias=True)
(9): Linear(in_features=10, out_features=10, bias=True)
)
)
入力してみます。
x = torch.randn(10)
print(net(x))
#=> tensor([ 0.3028, -0.2671, 0.2963, 0.3608, 0.3891, -0.3487, -0.0378, 0.2440,
# 0.3610, -0.1338], grad_fn=<AddBackward0>)
コメント