画像データとその他の数値データを入力可能なモデルをPyTorchで作ってみます。
ここで紹介したコードはgithubにあげてあります。
https://github.com/tocom242242/example-input-img-numerical_value-model
データの用意
画像path, meta1, labelの情報があるcsvがあるとします。
path,info1,label
imgs\img1.png,1,0
imgs\img2.png,2,1
imgs\img3.png,3,1
データセットの作成
csvを読み込むデータセットを作成します。
mydataset.py
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224, 224)),
])
class MyDataset(Dataset):
def __init__(self, csv_path):
df = pd.read_csv(csv_path, encoding="shift-jis")
self.img_paths = df["path"].tolist()
self.meta = df["info1"].tolist()
self.y = df["label"].tolist()
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
img_path = self.img_paths[idx]
img = Image.open(img_path).convert("RGB")
img = transform(img)
_meta = self.meta[idx]
label = self.y[idx]
return img, _meta, label
モデルの作成
画像と数値データを入力できるモデルを作成します。
構成としては、
- 画像から特徴量を抽出するモデル(resnet18)
- 数値データを入力するモデル
- それをconcatするレイヤー
になります。
画像から特徴量を抽出するモデルはresnet18を用いています。
model.py
import torch
import torchvision.models as models
from torch import nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.resnet = models.resnet18(pretrained=True)
self.mlp = nn.Sequential(
nn.Linear(1, 30), nn.BatchNorm1d(30), nn.ReLU()
)
self.output_layer = nn.Linear(1000 + 30, 1)
def forward(self, x, meta):
z = self.resnet(x)
z2 = self.mlp(meta)
z3 = torch.cat((z, z2), dim=1)
output = self.output_layer(z3)
return output
動かしてみる
もろもろを組み合わせて動かしてみます。
ここでは、学習などはせずに、単純にデータを入力するだけです。
(csvとかデータローダーを作る必要なかったのだけれど・・・)
main.py
import torch
from mydataset import MyDataset
from torch.utils.data import DataLoader
mydataset = MyDataset("data.csv")
train_dataloader = DataLoader(mydataset, batch_size=3, shuffle=True)
from model import MyModel
model = MyModel()
for x, meta, y in train_dataloader:
meta = meta.unsqueeze(1).float()
y = y.type(torch.float64)
output = model(x, meta)
print(output)
これを実行すると、とりあえず推論した結果が返ってくるはずです。
コメント
githubのコードが閲覧することが出来ません.
ご連絡ありがとうございます。
また、ご迷惑おかけして申し訳ありません。
githubの設定がprivateになっていました。先ほど、publicに変更しました。
ご確認よろしくお願いいたします。