import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import random


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        # Randomly enlarge the base 400x600 size by 0-40% before cropping,
        # so crops of size opt.fineSize are sampled at varying scales.
        zoom = 1 + 0.1 * random.randint(0, 4)
        osize = [int(400 * zoom), int(600 * zoom)]
        # transforms.Scale was renamed to transforms.Resize in torchvision.
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    # Convert to a tensor in [0, 1], then normalize each channel to [-1, 1].
    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __scale_width(img, target_width):
    # Resize so the output width equals target_width, preserving aspect ratio.
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)
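

if __name__ == '__main__':
    # Minimal usage sketch: build a transform from a hypothetical `opt`
    # namespace whose attribute names mirror the options read above
    # (resize_or_crop, fineSize, loadSize, isTrain, no_flip); the concrete
    # values and the synthetic image are illustrative only.
    from argparse import Namespace

    opt = Namespace(resize_or_crop='resize_and_crop', fineSize=256,
                    loadSize=286, isTrain=True, no_flip=False)
    transform = get_transform(opt)
    img = Image.new('RGB', (600, 400))  # stands in for a real dataset sample
    tensor = transform(img)
    print(tensor.shape)  # torch.Size([3, 256, 256]), values in [-1, 1]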