# NamedCurves / data/image_transformations.py
# Uploaded by davidserra9 — commit 117183e: "First commit from github repo"
import random
import torchvision.transforms.functional as F
from torchvision import transforms
class RandomCropPair:
    """Randomly crop two images with one shared crop window.

    The window is sampled once from ``img1`` and applied unchanged to
    ``img2``, so a pair (e.g. input and target) stays spatially aligned.

    Args:
        size: Output crop size. An int ``s``, or a one-element sequence
            ``[s]``, gives a square ``s x s`` crop. A two-element sequence
            is interpreted as ``(height, width)``.

    Raises:
        ValueError: If ``size`` is a sequence whose length is not 1 or 2.
    """

    def __init__(self, size):
        # Keep the value exactly as configured for anyone inspecting it.
        self.size = size
        # RandomCrop.get_params unpacks its output size as ``th, tw``, so a
        # plain int must be expanded to a pair before it is passed in.
        self._output_size = self._to_hw(size)

    @staticmethod
    def _to_hw(size):
        """Normalize ``size`` to a ``(height, width)`` tuple."""
        if isinstance(size, int):
            return (size, size)
        if len(size) == 1:
            return (size[0], size[0])
        if len(size) == 2:
            return tuple(size)
        raise ValueError(
            f"size must be an int or a sequence of length 1 or 2, got {size!r}"
        )

    def __call__(self, img1, img2):
        """Crop both images at the same randomly chosen location.

        Returns:
            The tuple ``(cropped_img1, cropped_img2)``.
        """
        # NOTE(review): the window is sampled from img1 only; this assumes
        # img1 and img2 share the same spatial size — confirm with callers.
        i, j, h, w = transforms.RandomCrop.get_params(img1, self._output_size)
        img1 = F.crop(img1, i, j, h, w)
        img2 = F.crop(img2, i, j, h, w)
        return img1, img2
class ResizePair:
    """Resize two images to the same target size.

    Args:
        size: Target size, forwarded unchanged to
            ``torchvision.transforms.functional.resize``.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img1, img2):
        """Resize ``img1`` then ``img2`` and return them as a pair."""
        target = self.size
        # Explicit antialias=True avoids a torchvision warning.
        first = F.resize(img1, target, antialias=True)
        second = F.resize(img2, target, antialias=True)
        return first, second
class RandomHorizontalFlipPair:
    """Mirror two images left-to-right together, with probability ``p``.

    Args:
        p: Probability of flipping the pair on each call.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img1, img2):
        """Return the pair, horizontally flipped or unchanged."""
        # A single random draw decides for both images so they stay aligned.
        should_flip = random.random() < self.p
        if not should_flip:
            return img1, img2
        return F.hflip(img1), F.hflip(img2)
class RandomVerticalFlipPair:
    """Flip two images top-to-bottom together, with probability ``p``.

    Args:
        p: Probability of flipping the pair on each call.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img1, img2):
        """Return the pair, vertically flipped or unchanged."""
        pair = (img1, img2)
        # One draw for the whole pair keeps the two images consistent.
        if random.random() < self.p:
            pair = (F.vflip(pair[0]), F.vflip(pair[1]))
        return pair
def get_transforms(transforms_config):
    """Build the paired transforms described by a config.

    Args:
        transforms_config: Iterable of mappings. Each mapping has a
            ``'type'`` key naming the transform and an optional
            ``'params'`` mapping of constructor keyword arguments. A missing
            or ``None`` ``'params'`` value means "use constructor defaults".

    Returns:
        A list of paired-transform instances, in the same order as
        ``transforms_config``. Each one is called as
        ``transform(img1, img2)`` and returns the transformed pair.

    Raises:
        ValueError: If ``'type'`` is not one of ``'RandomCrop'``,
            ``'Resize'``, ``'RandomHorizontalFlip'`` or
            ``'RandomVerticalFlip'``.
        KeyError: If an entry has no ``'type'`` key.
    """
    transform_list = []
    for transform in transforms_config:
        transform_type = transform['type']
        # Previously a missing 'params' key raised KeyError, even for
        # transforms with usable defaults (the flips). Treat missing/None
        # as "no keyword arguments".
        params = transform.get('params') or {}
        if transform_type == 'RandomCrop':
            transform_cls = RandomCropPair
        elif transform_type == 'Resize':
            transform_cls = ResizePair
        elif transform_type == 'RandomHorizontalFlip':
            transform_cls = RandomHorizontalFlipPair
        elif transform_type == 'RandomVerticalFlip':
            transform_cls = RandomVerticalFlipPair
        else:
            raise ValueError(f"Unsupported transform type: {transform_type}")
        transform_list.append(transform_cls(**params))
    return transform_list