タイトルにある通り、Flask-RESTX を使って、
PyTorch で作ったモデルで推論した結果を返すだけの API を作ってみます。
※筆者は Flask-RESTX の入門者なので、間違い等があれば教えてください
今回は、入力は画像で、PyTorch のモデルは特に学習しません。
環境
- python 3.10.5
- flask-restx 1.0.6
- pytorch 1.12.1
- Flask 2.2.3
インストール
以下のようにインストールしていきます。
注意点としては、flask-restx は Python 3.7 以上でないと動きません。
pip install Flask
pip install flask-restx
pip install torch torchvision torchaudio
実装
コードは github にあげてあります。
https://github.com/tocom242242/example-api-flask_restx-pytorch
model.py
モデルは公式のチュートリアルのものをそのまま使います。
学習も何もしていないので、出力は適当です。
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
main.py
API の部分は以下のようになります。
import torch
import torchvision.transforms as transforms
from flask import Flask
from flask_restx import Api, Resource
from PIL import Image
from werkzeug.datastructures import FileStorage
from model import Net
transform = transforms.Compose(
[
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
# モデルを作成
model = Net()
# APIを生成
app = Flask(__name__)
api = Api(app, doc="/doc/")
upload_parser = api.parser()
upload_parser.add_argument("file", location="files", type=FileStorage, required=True)
def pre_process(image):
image = transform(image)
return image
def post_process(model_output):
return model_output
@api.route("/prediction")
@api.expect(upload_parser)
class ExampleResource(Resource):
def post(self):
args = upload_parser.parse_args()
uploaded_file = args["file"]
img = Image.open(uploaded_file).convert("RGB")
image = pre_process(img)
pred = model(image.unsqueeze(0))
pred = int(torch.argmax(pred))
return {"pred": pred}, 201
if __name__ == "__main__":
app.run(debug=True)
あまり難しい部分はないので、説明は省略します。
実行
python main.py
自動で OPEN API に従ったドキュメントが生成されるので、
Swagger UI を使って動かしてみます。
http://127.0.0.1:5000/doc/
にアクセスして、サンプルのリクエストを送信してみてください。
推論した結果が返ってくるはずです。
参考文献
- https://flask-restx.readthedocs.io/en/latest/
- https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
コメント