import glob
import os
import random

import numpy as np
from PIL import Image
from torch import nn
from torch.utils import data as data
import torch.nn.functional as F
from torchvision import transforms

from .realesrgan import RealESRGAN_degradation
class PairedCaptionDataset(data.Dataset):
    def __init__(
        self,
        root_folders=None,
        tokenizer=None,
        gt_ratio=0,  # probability of using the GT image itself as the LQ input
    ):
        super(PairedCaptionDataset, self).__init__()
        self.gt_ratio = gt_ratio

        # root_folders is a text file listing one GT image path per line
        with open(root_folders, 'r') as f:
            self.gt_list = [line.strip() for line in f.readlines()]

        self.img_preproc = transforms.Compose([
            transforms.RandomCrop((512, 512)),
            transforms.Resize((512, 512)),
            transforms.RandomHorizontalFlip(),
        ])

        self.degradation = RealESRGAN_degradation('dataloaders/params_ccsr.yml', device='cuda')
        self.tokenizer = tokenizer
    def tokenize_caption(self, caption=""):
        # Tokenize the caption to fixed-length input ids for the text encoder
        inputs = self.tokenizer(
            caption, max_length=self.tokenizer.model_max_length, padding="max_length",
            truncation=True, return_tensors="pt"
        )
        return inputs.input_ids
    def __getitem__(self, index):
        gt_path = self.gt_list[index]
        gt_img = Image.open(gt_path).convert('RGB')
        gt_img = self.img_preproc(gt_img)

        # Apply the Real-ESRGAN degradation pipeline; returns the clean GT tensor
        # and its degraded (then resized-back) counterpart, both in [0, 1]
        gt_img, img_t = self.degradation.degrade_process(np.asarray(gt_img) / 255., resize_bak=True)

        # With probability gt_ratio, use the GT image itself as the LQ conditioning input
        if random.random() < self.gt_ratio:
            lq_img = gt_img
        else:
            lq_img = img_t

        # No caption is used; tokenize an empty string
        lq_caption = ''

        example = dict()
        example["conditioning_pixel_values"] = lq_img.squeeze(0)  # [0, 1]
        example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0   # [-1, 1]
        example["input_caption"] = self.tokenize_caption(caption=lq_caption).squeeze(0)

        return example
    def __len__(self):
        return len(self.gt_list)
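

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training code; everything
# below is an assumption for illustration). The dataset only requires a
# Hugging Face tokenizer exposing model_max_length; CLIPTokenizer is used
# here as an example, and 'gt_list.txt' is a hypothetical manifest file with
# one GT image path per line. A CUDA device is assumed, since the degradation
# pipeline above is built with device='cuda'. Run as a module
# (python -m <package>.<this_module>) so the relative realesrgan import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = PairedCaptionDataset(
        root_folders="gt_list.txt",  # hypothetical path to the GT manifest
        tokenizer=tokenizer,
        gt_ratio=0.2,                # 20% of samples use the GT image as the LQ input
    )
    # num_workers=0 keeps the CUDA-based degradation in the main process
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    batch = next(iter(loader))
    print(batch["pixel_values"].shape)               # GT images in [-1, 1]
    print(batch["conditioning_pixel_values"].shape)  # degraded LQ images in [0, 1]
    print(batch["input_caption"].shape)              # tokenized (empty) captions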