Spaces:
Running
on
Zero
Running
on
Zero
import random | |
import torchvision.transforms.functional as F | |
from torchvision import transforms | |
class RandomCropPair:
    """Apply one identical random crop to a pair of images.

    The crop window is sampled once, from ``img1``, and the same window is
    then applied to both images so they stay spatially aligned.

    Args:
        size: Desired crop size. Either an int (square crop), a one-element
            sequence (square crop), or a two-element ``(height, width)``
            sequence.
    """

    def __init__(self, size):
        self.size = size

    @staticmethod
    def _output_size(size):
        """Return *size* as an ``(height, width)`` tuple.

        ``transforms.RandomCrop.get_params`` unpacks its ``output_size``
        argument into two values, so a bare int has to be expanded first.

        Raises:
            ValueError: If *size* is a sequence whose length is not 1 or 2.
        """
        if isinstance(size, int):
            return (size, size)
        dims = tuple(size)
        if len(dims) == 1:
            return (dims[0], dims[0])
        if len(dims) != 2:
            raise ValueError(
                f"size must be an int or a sequence of length 1 or 2, got {size!r}"
            )
        return dims

    def __call__(self, img1, img2):
        """Crop both images with the same randomly sampled window.

        Returns:
            A ``(img1, img2)`` tuple of the cropped images.
        """
        # The window is derived from img1 only.
        # NOTE(review): presumably img2 has the same spatial size as img1 --
        # confirm with the caller.
        top, left, height, width = transforms.RandomCrop.get_params(
            img1, self._output_size(self.size)
        )
        img1 = F.crop(img1, top, left, height, width)
        img2 = F.crop(img2, top, left, height, width)
        return img1, img2
class ResizePair:
    """Resize a pair of images to the same target size.

    Args:
        size: Target size, forwarded unchanged to ``F.resize``.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img1, img2):
        """Resize both images and return them as ``(img1, img2)``."""
        # antialias=True is passed explicitly to avoid the torchvision warning.
        return tuple(
            F.resize(image, self.size, antialias=True) for image in (img1, img2)
        )
class RandomHorizontalFlipPair:
    """Horizontally flip a pair of images together with probability ``p``.

    A single random draw decides the outcome for both images, so either
    both are flipped or neither is.

    Args:
        p: Probability of flipping the pair. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img1, img2):
        """Return ``(img1, img2)``, both flipped or both untouched."""
        should_flip = random.random() < self.p
        if not should_flip:
            return img1, img2
        return F.hflip(img1), F.hflip(img2)
class RandomVerticalFlipPair:
    """Vertically flip a pair of images together with probability ``p``.

    A single random draw decides the outcome for both images, so either
    both are flipped or neither is.

    Args:
        p: Probability of flipping the pair. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img1, img2):
        """Return ``(img1, img2)``, both flipped or both untouched."""
        should_flip = random.random() < self.p
        if not should_flip:
            return img1, img2
        return F.vflip(img1), F.vflip(img2)
def get_transforms(transforms_config):
    """Build the list of paired transforms described by a configuration.

    Args:
        transforms_config: Iterable of mappings. Each mapping has a ``'type'``
            key (one of ``'RandomCrop'``, ``'Resize'``,
            ``'RandomHorizontalFlip'``, ``'RandomVerticalFlip'``) and an
            optional ``'params'`` mapping of keyword arguments for the
            transform's constructor.

    Returns:
        A list of transform instances, in the same order as the configuration.

    Raises:
        ValueError: If an entry names an unsupported transform type.
    """
    transform_list = []
    for transform in transforms_config:
        transform_type = transform['type']

        # A missing or None 'params' entry means "use the constructor
        # defaults" rather than an error.
        params = transform.get('params')
        if params is None:
            params = {}

        if transform_type == 'RandomCrop':
            pair_cls = RandomCropPair
        elif transform_type == 'Resize':
            pair_cls = ResizePair
        elif transform_type == 'RandomHorizontalFlip':
            pair_cls = RandomHorizontalFlipPair
        elif transform_type == 'RandomVerticalFlip':
            pair_cls = RandomVerticalFlipPair
        else:
            raise ValueError(f"Unsupported transform type: {transform_type}")

        transform_list.append(pair_cls(**params))
    return transform_list