【PyTorch】画像を分割する【Kornia】

これまで、Unfoldなどを使って画像を分割していましたが、
korniaに簡単に画像を分割できるメソッドがあったので、
それを使って分割してみます。
padding等も自動で決めてくれるので、ものすごい楽になりました。

インストール

korniaというPyTorch用の画像処理ライブラリを用いるので、以下のコマンドでインストールします。

pip install kornia

分割してみる

全コード

先に全コードを張っておきます。
https://github.com/tocom242242/notebooks/blob/master/pytorch/korina_split_img.ipynb

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import kornia
import cv2
import numpy as np
from PIL import Image

with Image.open("<path/to/your/image>") as img:
    img = img.resize((224,224))
    display(img)
    original = np.expand_dims(np.asarray(img, np.float32).transpose([2, 0, 1]), axis=0) / 255.0

patch_size = 50
stride = 50
x = torch.as_tensor(original)
print(x.shape)
padding = kornia.contrib.compute_padding(
            x.shape[-2:], (patch_size, patch_size)
        )
patches = kornia.contrib.extract_tensor_patches(
        x, (patch_size, patch_size), stride=stride, padding=padding
    )
print(patches.shape)
patches = patches.contiguous().view(-1, 3, *patches.shape[-2:])
print(patches.shape)

img = torchvision.utils.make_grid(patches, nrow=5)
img = transforms.functional.to_pil_image(img)
display(img)

解説

今回は以下のような画像を分割したいと思います。

まず、必要なライブラリをインポートして、画像を読み込みます。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import kornia
import cv2
import numpy as np
from PIL import Image

with Image.open("<path/to/your/image>") as img:
    img = img.resize((224,224))
    display(img)
    original = np.expand_dims(np.asarray(img, np.float32).transpose([2, 0, 1]), axis=0) / 255.0

x = torch.as_tensor(original)

patch_sizeとstrideを下のように設定します。

patch_size = 50
stride = 50

paddingを計算します。kornia.contrib.compute_paddingを使うことで、必要なpadを自動で計算してくれます。

padding = kornia.contrib.compute_padding(
            x.shape[-2:], (patch_size, patch_size)
        )

元画像と上で計算したpaddingを使って画像を分割します。
kornia.contrib.extract_tensor_patchesを使います。

patches = kornia.contrib.extract_tensor_patches(
        x, (patch_size, patch_size), stride=stride, padding=padding
    )

patchesが分割したパッチになります。
プロットしてみます。

patches = patches.contiguous().view(-1, 3, *patches.shape[-2:])
print(patches.shape)

img = torchvision.utils.make_grid(patches, nrow=5)
img = transforms.functional.to_pil_image(img)
display(img)

strideを変更したバージョン

参考

https://kornia.readthedocs.io/en/latest/_modules/kornia/contrib/extract_patches.html

コメント

タイトルとURLをコピーしました