データオーギュメンテーションをプロットするnotebook【tf2, pytorch】

以前にtensorflowとpytorchでデータオーギュメンテーションをいくつか確認する記事を書きました。

今回も同じようなものなのですが、データオーギュメンテーションを確認するjupyter notebookを共有します。https://gist.github.com/tocom242242/00877e1a904d86cf785482b39805649a
コードとしては以下のようになります。

import tensorflow as tf
import matplotlib.pyplot as plt
import requests
from PIL import Image
import numpy as np
from io import BytesIO

def apply_augmentation_tensorflow(x):
    if x.shape[-1] == 4:
        x = x[..., :3]
    # 以下に試したい処理を加えk
    x = tf.image.fip_left_right(x)
    x = tf.image.flip_up_down(x)
    x = tf.image.adjust_brightness(x, delta=0.5)
    x = tf.image.adjust_contrast(x, contrast_factor=0.5)
    x = tf.image.adjust_saturation(x, saturation_factor=0.5)
    x = tf.image.adjust_hue(x, delta=0.2)

    return x

def load_image_tensorflow(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    img = np.array(img)
    img_tensor = tf.convert_to_tensor(img, dtype=tf.float32) / 255.0
    return img_tensor

image_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"

original_image_tensorflow = load_image_tensorflow(image_url)

augmented_image_tensorflow = apply_augmentation_tensorflow(original_image_tensorflow)

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(original_image_tensorflow)
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(augmented_image_tensorflow)
plt.title("Augmented Image")
plt.axis('off')

plt.show()

コメントに書いてあるところに任意のデータオーギュメンテーションを加えていく形になります。

以下はPyTorchバージョンになります。

import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO

def apply_augmentation_pytorch(image):
    # 以下を変更する
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(1),  # 常に水平フリップ
        transforms.RandomVerticalFlip(1),    # 常に垂直フリップ
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2)  # 色調整
    ])
    return transform(image)

def load_image_pytorch(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    return img

image_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"

original_image_pytorch = load_image_pytorch(image_url)

augmented_image_pytorch = apply_augmentation_pytorch(original_image_pytorch)

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(original_image_pytorch)
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(augmented_image_pytorch)
plt.title("Augmented Image")
plt.axis('off')

plt.show()

コメント

タイトルとURLをコピーしました