# Spaces: Running on Zero
import os | |
import numpy as np | |
import cv2 | |
import glob | |
import math | |
import yaml | |
import random | |
from collections import OrderedDict | |
import torch | |
import torch.nn.functional as F | |
from basicsr.data.transforms import augment | |
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels | |
from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img | |
from basicsr.utils.img_process_util import filter2D | |
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt | |
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, | |
normalize, rgb_to_grayscale) | |
# Absolute path of the directory that contains this module.
cur_path = os.path.split(os.path.abspath(__file__))[0]
def ordered_yaml():
    """Build a yaml (Loader, Dumper) pair that handles mappings as OrderedDict.

    The LibYAML-backed ``CLoader``/``CDumper`` are preferred; if either cannot
    be imported, the pure-Python ``Loader``/``Dumper`` are used instead.

    NOTE(review): the constructor/representer are registered directly on the
    yaml classes (not on private subclasses), so the registration presumably
    affects any other code in the process that uses the same classes.

    Returns:
        tuple: ``(Loader, Dumper)`` yaml classes.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def _represent_ordered(dumper, mapping):
        # Dump an OrderedDict as an ordinary YAML mapping, in insertion order.
        return dumper.represent_dict(mapping.items())

    def _construct_ordered(loader, node):
        # Load every YAML mapping as an OrderedDict so file key order is kept.
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, _represent_ordered)
    Loader.add_constructor(mapping_tag, _construct_ordered)
    return Loader, Dumper
def opt_parse(opt_path):
    """Load a YAML option file, keeping mapping key order.

    Args:
        opt_path (str): Path to the YAML configuration file.

    Returns:
        The parsed YAML document. Mappings are constructed by the loader
        returned from ``ordered_yaml()`` (i.e. as ``OrderedDict``).

    Raises:
        OSError: If ``opt_path`` cannot be opened.
    """
    # YAML is UTF-8 by specification; pin the encoding so parsing does not
    # depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(opt_path, mode='r', encoding='utf-8') as f:
        Loader, _ = ordered_yaml()
        # NOTE(review): the loader from ordered_yaml() is a full (non-safe)
        # yaml loader; only use this on trusted configuration files.
        opt = yaml.load(f, Loader=Loader)
    return opt
class RealESRGAN_degradation(object):
    """Real-ESRGAN-style synthetic degradation pipeline.

    ``degrade_process`` turns a high-quality (GT) image into a low-quality (LQ)
    one through the following stages:

    1. A first round of blur, random resize, noise and JPEG.
    2. A second round with optional blur, random resize and noise.
    3. JPEG and [resize to the LQ size + sinc filter], applied in random order.
    4. Optional grayscale conversion and color jitter.

    All hyper-parameters come from the YAML file at ``opt_path``. Blur settings
    live under its ``kernel_info`` mapping.
    """

    def __init__(self, opt_path='', device='cpu'):
        """
        Args:
            opt_path (str): Path to the YAML degradation config. It must point
                to a readable file; the empty default cannot be opened.
            device (str | torch.device): Device used for the degradation ops.
        """
        self.opt = opt_parse(opt_path)
        self.device = device  # e.g. torch.device('cpu')
        optk = self.opt['kernel_info']

        # blur settings for the first degradation
        self.blur_kernel_size = optk['blur_kernel_size']
        self.kernel_list = optk['kernel_list']
        self.kernel_prob = optk['kernel_prob']
        self.blur_sigma = optk['blur_sigma']
        self.betag_range = optk['betag_range']
        self.betap_range = optk['betap_range']
        self.sinc_prob = optk['sinc_prob']

        # blur settings for the second degradation
        self.blur_kernel_size2 = optk['blur_kernel_size2']
        self.kernel_list2 = optk['kernel_list2']
        self.kernel_prob2 = optk['kernel_prob2']
        self.blur_sigma2 = optk['blur_sigma2']
        self.betag_range2 = optk['betag_range2']
        self.betap_range2 = optk['betap_range2']
        self.sinc_prob2 = optk['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = optk['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # 21x21 pulse (identity) kernel: convolving with it brings no blurry effect.
        self.pulse_tensor = torch.zeros(21, 21).float()
        self.pulse_tensor[10, 10] = 1

        self.jpeger = DiffJPEG(differentiable=False).to(self.device)
        self.usm_shaper = USMSharp().to(self.device)

    def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
        """Apply brightness/contrast/saturation/hue jitter in a random order.

        Each factor is drawn uniformly from its ``(min, max)`` range. Passing
        ``None`` for a range skips that adjustment.

        Args:
            img (torch.Tensor): Image tensor accepted by torchvision's
                ``adjust_*`` functions.
            brightness, contrast, saturation, hue: ``(min, max)`` ranges or None.

        Returns:
            torch.Tensor: The jittered image.
        """
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)
            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)
            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)
            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def random_augment(self, img_gt):
        """Randomly flip the GT image horizontally and convert it to a tensor.

        Args:
            img_gt: GT image as accepted by basicsr ``augment``/``img2tensor``
                (presumably an HWC float32 numpy array in [0, 1] — TODO confirm
                with the caller).

        Returns:
            torch.Tensor: The image with a leading batch dimension added.
        """
        # random horizontal flip (rotation disabled)
        img_gt, _ = augment(img_gt, hflip=True, rotation=False, return_status=True)
        # Color jitter and grayscale are applied to the LQ tensor in
        # degrade_process, not to the GT here.
        # HWC to CHW, numpy to float32 tensor; channel order is left unchanged
        # (bgr2rgb=False). unsqueeze(0) adds the batch dimension.
        img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)
        return img_gt

    def _random_kernel(self, sinc_prob, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range):
        """Sample one blur kernel (sinc or mixed) and zero-pad it to 21x21 (numpy array)."""
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)
        # pad kernel to 21x21 so every sampled kernel has the same shape
        pad_size = (21 - kernel_size) // 2
        return np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    def random_kernels(self):
        """Sample the kernels used by ``degrade_process``.

        Returns:
            tuple: ``(kernel, kernel2, sinc_kernel)``, each a 21x21 float tensor
            on CPU. ``kernel``/``kernel2`` are the first/second degradation blur
            kernels. ``sinc_kernel`` is either a sinc filter or the identity
            pulse kernel.
        """
        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel = self._random_kernel(self.sinc_prob, self.kernel_list, self.kernel_prob,
                                     self.blur_sigma, self.betag_range, self.betap_range)
        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel2 = self._random_kernel(self.sinc_prob2, self.kernel_list2, self.kernel_prob2,
                                      self.blur_sigma2, self.betag_range2, self.betap_range2)
        # ------------------------------------- sinc kernel ------------------------------------- #
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor
        return torch.FloatTensor(kernel), torch.FloatTensor(kernel2), sinc_kernel

    def _random_scale(self, suffix=''):
        """Pick 'up'/'down'/'keep' and return a matching random resize factor.

        ``suffix`` selects the option keys: '' uses ``resize_prob``/``resize_range``,
        and '2' uses ``resize_prob2``/``resize_range2``. The range key is read only
        when it is needed.
        """
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob' + suffix])[0]
        if updown_type == 'up':
            return np.random.uniform(1, self.opt['resize_range' + suffix][1])
        if updown_type == 'down':
            return np.random.uniform(self.opt['resize_range' + suffix][0], 1)
        return 1

    def _add_random_noise(self, out, suffix=''):
        """Add clipped Gaussian noise (with prob ``gaussian_noise_prob<suffix>``), else Poisson noise.

        ``suffix`` is '' for the first degradation stage and '2' for the second.
        Only the option keys of the chosen noise type are read.
        """
        gray_noise_prob = self.opt['gray_noise_prob' + suffix]
        if np.random.uniform() < self.opt['gaussian_noise_prob' + suffix]:
            return random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range' + suffix], clip=True, rounds=False,
                gray_prob=gray_noise_prob)
        return random_add_poisson_noise_pt(
            out,
            scale_range=self.opt['poisson_scale_range' + suffix],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False)

    def _random_jpeg(self, out, jpeg_range):
        """Clamp to [0, 1] and JPEG-compress each batch item at a quality drawn from ``jpeg_range``."""
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*jpeg_range)
        out = torch.clamp(out, 0, 1)
        return self.jpeger(out, quality=jpeg_p)

    def _resize_and_sinc(self, out, size, sinc_kernel):
        """Resize to ``size`` with a random interpolation mode, then apply ``sinc_kernel``."""
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, size=size, mode=mode)
        return filter2D(out, sinc_kernel)

    def degrade_process(self, img_gt, resize_bak=False, scale_final=4):
        """Synthesize a degraded (LQ) image from a GT image.

        Args:
            img_gt: GT image accepted by ``random_augment``.
            resize_bak (bool): If True, resize the LQ result back to the GT size.
            scale_final (int): Downscaling factor of the LQ image relative to
                the GT (default 4, the previously hard-coded value).

        Returns:
            tuple: ``(img_gt, img_lq)``. ``img_gt`` is the augmented GT tensor on
            ``self.device``. ``img_lq`` is the degraded tensor, clamped to
            [0, 1] and quantized to multiples of 1/255.
        """
        img_gt = self.random_augment(img_gt)
        kernel1, kernel2, sinc_kernel = self.random_kernels()
        img_gt = img_gt.to(self.device)
        kernel1 = kernel1.to(self.device)
        kernel2 = kernel2.to(self.device)
        sinc_kernel = sinc_kernel.to(self.device)

        ori_h, ori_w = img_gt.size()[2:4]

        # ----------------------- The first degradation process ----------------------- #
        out = filter2D(img_gt, kernel1)  # blur
        scale = self._random_scale()
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        out = self._add_random_noise(out)
        out = self._random_jpeg(out, self.opt['jpeg_range'])

        # ----------------------- The second degradation process ----------------------- #
        if np.random.uniform() < self.opt['second_blur_prob']:
            out = filter2D(out, kernel2)
        scale = self._random_scale('2')
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(
            out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)
        out = self._add_random_noise(out, '2')

        # JPEG compression + the final sinc filter.
        # [resize back + sinc filter] is treated as one operation, in one of two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, other combinations (sinc + JPEG + Resize) introduce twisted lines.
        lq_size = (int(ori_h // scale_final), int(ori_w // scale_final))
        if np.random.uniform() < 0.5:
            out = self._resize_and_sinc(out, lq_size, sinc_kernel)
            out = self._random_jpeg(out, self.opt['jpeg_range2'])
        else:
            out = self._random_jpeg(out, self.opt['jpeg_range2'])
            out = self._resize_and_sinc(out, lq_size, sinc_kernel)

        if np.random.uniform() < self.opt['gray_prob']:
            # NOTE(review): num_output_channels=1 gives a single-channel LQ
            # tensor while img_gt keeps its channels — confirm callers expect this.
            out = rgb_to_grayscale(out, num_output_channels=1)

        if np.random.uniform() < self.opt['color_jitter_prob']:
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)

        if resize_bak:
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)

        # clamp and round to 8-bit levels
        img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
        return img_gt, img_lq