import math
import random
from PIL import Image
import blobfile as bf
#from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
import os
import torchvision.transforms as transforms
import torch as th
from .degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
from functools import partial
import cv2


def get_params(size, resize_size, crop_size):
    """Sample random crop-position and flip parameters for one image.

    The shorter side of ``size`` is conceptually rescaled to ``resize_size``.
    The longer side is scaled by the same factor, truncated to an int. A
    random top-left corner is then drawn so that a ``crop_size`` square fits
    inside the rescaled image whenever it can.

    Args:
        size: ``(w, h)`` of the source image. Presumably ``PIL.Image.size``;
            TODO confirm with callers.
        resize_size: target length of the shorter side after rescaling.
        crop_size: side length of the square crop.

    Returns:
        dict with:
            ``'crop_pos'``: ``(x, y)`` top-left corner. ``x`` is drawn from
            ``[0, max(0, new_w - crop_size)]`` and ``y`` likewise for height.
            ``'flip'``: bool, True when ``random.random() > 0.5``.

    Raises:
        ZeroDivisionError: if the shorter side of ``size`` is 0.

    Randomness comes from the module-level ``random`` state, so seeding
    ``random`` makes the result reproducible.
    """
    w, h = size
    short_side, long_side = min(w, h), max(w, h)
    width_is_shorter = w == short_side
    # Preserve aspect ratio: scale the long side by the same factor that
    # maps the short side onto resize_size.
    long_side = int(resize_size * long_side / short_side)
    short_side = resize_size
    if width_is_shorter:
        new_w, new_h = short_side, long_side
    else:
        new_w, new_h = long_side, short_side
    # Clamp the upper bound at 0 so an image smaller than the crop still
    # gives a valid randint range (the crop then starts at the origin).
    x = random.randint(0, max(0, new_w - crop_size))
    y = random.randint(0, max(0, new_h - crop_size))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop=True):
    """Build an image transform: square resize, then an optional horizontal flip.

    NOTE(review): ``resize_size``, ``crop`` and ``params['crop_pos']`` are
    accepted but not used. The image is resized straight to
    ``crop_size`` x ``crop_size``, so the aspect ratio is not preserved,
    and it is never cropped. They are kept for interface compatibility;
    confirm this is intended.

    Args:
        params: dict as returned by ``get_params``. Only ``params['flip']``
            is read, and only when the returned transform is applied.
        resize_size: unused.
        crop_size: output width and height, in pixels.
        method: resampling filter passed to ``img.resize``.
        flip: if False, the flip step is left out of the pipeline entirely.
        crop: unused.

    Returns:
        ``transforms.Compose`` whose steps call ``img.resize`` and
        ``img.transpose``. Presumably it expects PIL images.
    """
    transform_list = [transforms.Lambda(lambda img: __scale(img, crop_size, method))]
    if flip:
        # params['flip'] is looked up lazily, each time the transform runs.
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    return transforms.Compose(transform_list)


def get_tensor(normalize=True, toTensor=True):
    """Build the tensor-conversion / normalization transform.

    Args:
        normalize: if True, append ``Normalize`` with mean 0.5 and std 0.5 on
            each of 3 channels, i.e. ``x -> 2 * x - 1`` per channel.
            This maps values in [0, 1] to [-1, 1].
        toTensor: if True, add ``transforms.ToTensor()`` as the first step.

    Returns:
        ``transforms.Compose`` of the selected steps. It may be empty.
    """
    transform_list = []
    if toTensor:
        transform_list.append(transforms.ToTensor())
    if normalize:
        # The ``normalize`` argument shadows the module-level normalize()
        # helper inside this function, so the transform is built inline.
        transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(transform_list)


def normalize():
    """Return a 3-channel ``Normalize`` transform with mean 0.5 and std 0.5."""
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __scale(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to a ``target_width`` x ``target_width`` square.

    The aspect ratio is not preserved.
    """
    return img.resize((target_width, target_width), method)


def __flip(img, flip):
    """Return ``img`` mirrored left-to-right if ``flip`` is truthy, else unchanged."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img