import math
import random
from PIL import Image
import blobfile as bf
# from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
import os
import torchvision.transforms as transforms
import torch as th
from functools import partial
import cv2


def get_params(size, resize_size, crop_size):
    """Sample a random crop position and flip flag for an image of `size` = (w, h)."""
    w, h = size
    # Resize so the shorter side equals resize_size, preserving aspect ratio.
    ss, ls = min(w, h), max(w, h)  # short side and long side
    width_is_shorter = w == ss
    ls = int(resize_size * ls / ss)
    ss = resize_size
    new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
    # Random top-left corner for a crop_size x crop_size crop, plus a random flip.
    x = random.randint(0, np.maximum(0, new_w - crop_size))
    y = random.randint(0, np.maximum(0, new_h - crop_size))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop=True):
    """Resize to a crop_size x crop_size square, then apply the pre-sampled horizontal flip."""
    # Note: `resize_size`, `crop`, and params['crop_pos'] are accepted but not used here.
    transform_list = []
    transform_list.append(transforms.Lambda(lambda img: __scale(img, crop_size, method)))
    if flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    return transforms.Compose(transform_list)


def get_tensor(normalize=True, toTensor=True):
    """Convert a PIL image to a tensor and normalize each channel to [-1, 1]."""
    transform_list = []
    if toTensor:
        transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    """Channel-wise normalization mapping [0, 1] to [-1, 1]."""
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __scale(img, target_width, method=Image.BICUBIC):
    # Resize to a square of side target_width; this ignores the original aspect ratio.
    return img.resize((target_width, target_width), method)


def __flip(img, flip):
    # Horizontally flip the image when the pre-sampled flag is set.
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
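

# Usage sketch (added for illustration, not part of the original module): how these
# helpers are typically chained to preprocess a single image. The random dummy image
# and the 256-pixel resize/crop sizes below are assumptions, not values from the source.
if __name__ == "__main__":
    # Self-contained dummy RGB image instead of a real file path.
    pil_image = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))

    params = get_params(pil_image.size, resize_size=256, crop_size=256)
    spatial = get_transform(params, resize_size=256, crop_size=256)  # resize to 256x256 + random flip
    to_tensor = get_tensor()                                         # ToTensor + normalize to [-1, 1]

    tensor = to_tensor(spatial(pil_image))
    print(tensor.shape)  # torch.Size([3, 256, 256])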