以前にtensorflowとpytorchでデータオーギュメンテーションをいくつか確認する記事を書きました。
今回も同じようなものなのですが、データオーギュメンテーションを確認するjupyter notebookを共有します。https://gist.github.com/tocom242242/00877e1a904d86cf785482b39805649a
コードとしては以下のようになります。
import tensorflow as tf
import matplotlib.pyplot as plt
import requests
from PIL import Image
import numpy as np
from io import BytesIO
def apply_augmentation_tensorflow(x):
if x.shape[-1] == 4:
x = x[..., :3]
# 以下に試したい処理を加えk
x = tf.image.fip_left_right(x)
x = tf.image.flip_up_down(x)
x = tf.image.adjust_brightness(x, delta=0.5)
x = tf.image.adjust_contrast(x, contrast_factor=0.5)
x = tf.image.adjust_saturation(x, saturation_factor=0.5)
x = tf.image.adjust_hue(x, delta=0.2)
return x
def load_image_tensorflow(url):
response = requests.get(url)
img = Image.open(BytesIO(response.content))
img = np.array(img)
img_tensor = tf.convert_to_tensor(img, dtype=tf.float32) / 255.0
return img_tensor
image_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
original_image_tensorflow = load_image_tensorflow(image_url)
augmented_image_tensorflow = apply_augmentation_tensorflow(original_image_tensorflow)
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(original_image_tensorflow)
plt.title("Original Image")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(augmented_image_tensorflow)
plt.title("Augmented Image")
plt.axis('off')
plt.show()
コメントに書いてあるところに任意のデータオーギュメンテーションを加えていく形になります。
以下はPyTorchバージョンになります。
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
def apply_augmentation_pytorch(image):
# 以下を変更する
transform = transforms.Compose([
transforms.RandomHorizontalFlip(1), # 常に水平フリップ
transforms.RandomVerticalFlip(1), # 常に垂直フリップ
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2) # 色調整
])
return transform(image)
def load_image_pytorch(url):
response = requests.get(url)
img = Image.open(BytesIO(response.content))
return img
image_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
original_image_pytorch = load_image_pytorch(image_url)
augmented_image_pytorch = apply_augmentation_pytorch(original_image_pytorch)
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(original_image_pytorch)
plt.title("Original Image")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(augmented_image_pytorch)
plt.title("Augmented Image")
plt.axis('off')
plt.show()
コメント