Flask-RESTX を使って PyTorch で作ったモデルで推論した結果を返すだけの API を作ってみた

タイトルにある通り、Flask-RESTX を使って、
PyTorch で作ったモデルで推論した結果を返すだけの API を作ってみます。
※筆者は Flask-RESTX の入門者なので、間違い等があれば教えてください

今回は、入力は画像で、PyTorch のモデルは特に学習しません。

環境

  • python 3.10.5
  • flask-restx 1.0.6
  • pytorch 1.12.1
  • Flask 2.2.3

インストール

以下のようにインストールしていきます。
注意点としては、flask-restx は Python 3.7 以上でないと動きません。

pip install Flask
pip install flask-restx
pip install torch torchvision torchaudio

実装

コードは github にあげてあります。
https://github.com/tocom242242/example-api-flask_restx-pytorch

model.py

モデルは公式のチュートリアルのものをそのまま使います。
学習も何もしていないので、出力は適当です。

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

main.py

API の部分は以下のようになります。

import torch
import torchvision.transforms as transforms
from flask import Flask
from flask_restx import Api, Resource
from PIL import Image
from werkzeug.datastructures import FileStorage

from model import Net

transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# モデルを作成
model = Net()

# APIを生成
app = Flask(__name__)
api = Api(app, doc="/doc/")

upload_parser = api.parser()
upload_parser.add_argument("file", location="files", type=FileStorage, required=True)


def pre_process(image):
    image = transform(image)
    return image


def post_process(model_output):
    return model_output


@api.route("/prediction")
@api.expect(upload_parser)
class ExampleResource(Resource):
    def post(self):
        args = upload_parser.parse_args()
        uploaded_file = args["file"]
        img = Image.open(uploaded_file).convert("RGB")
        image = pre_process(img)
        pred = model(image.unsqueeze(0))
        pred = int(torch.argmax(pred))
        return {"pred": pred}, 201


if __name__ == "__main__":
    app.run(debug=True)

あまり難しい部分はないので、説明は省略します。

実行

python main.py

自動で OPEN API に従ったドキュメントが生成されるので、
Swagger UI を使って動かしてみます。

http://127.0.0.1:5000/doc/
にアクセスして、サンプルのリクエストを送信してみてください。

推論した結果が返ってくるはずです。

参考文献

コメント

タイトルとURLをコピーしました