# PITI-Synthesis/glide_text2im/image_datasets_sketch.py
import math
import random
from PIL import Image
import blobfile as bf
#from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
import os
import torchvision.transforms as transforms
import torch as th
from .degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
from functools import partial
import cv2


def get_params(size, resize_size, crop_size):
    # Compute random augmentation parameters for a resize-then-crop pipeline:
    # scale so the shorter side equals resize_size, then draw a random crop
    # position and a random horizontal-flip decision.
    w, h = size
    ss, ls = min(w, h), max(w, h)  # short side and long side
    width_is_shorter = w == ss
    ls = int(resize_size * ls / ss)
    ss = resize_size
    new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
    x = random.randint(0, max(0, new_w - crop_size))
    y = random.randint(0, max(0, new_h - crop_size))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop=True):
    # Build a transform that resizes the image to a crop_size x crop_size square
    # and, if enabled, applies the horizontal flip chosen in get_params.
    # Note: resize_size, crop and params['crop_pos'] are not used in this version.
    transform_list = []
    transform_list.append(transforms.Lambda(lambda img: __scale(img, crop_size, method)))
    if flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    return transforms.Compose(transform_list)


def get_tensor(normalize=True, toTensor=True):
    # Convert a PIL image to a tensor and normalize each channel to [-1, 1].
    transform_list = []
    if toTensor:
        transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    # Channel-wise normalization mapping values from [0, 1] to [-1, 1].
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __scale(img, target_width, method=Image.BICUBIC):
    # Resize the image to a square of side target_width.
    return img.resize((target_width, target_width), method)


def __flip(img, flip):
    # Mirror the image horizontally when flip is True.
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
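

# A minimal usage sketch, not part of the original file: it assumes a local RGB
# image at "example.jpg" and shows how get_params, get_transform and get_tensor
# are typically chained to produce a normalized training tensor.
if __name__ == "__main__":
    pil_image = Image.open("example.jpg").convert("RGB")
    params = get_params(pil_image.size, resize_size=256, crop_size=256)
    image_transform = get_transform(params, resize_size=256, crop_size=256)
    tensor = get_tensor()(image_transform(pil_image))
    # Expected: shape (3, 256, 256) with values roughly in [-1, 1].
    print(tensor.shape, tensor.min().item(), tensor.max().item())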