import glob
import sys
from collections import OrderedDict
from typing import Optional
import tqdm
from natsort import natsort
import argparse
import models.llflow.option_ as option
from models.llflow import Measure, psnr
from models.llflow import imresize
from models import create_model
import torch
from util import opt_get
import numpy as np
import pandas as pd
import os
import cv2
from rich.console import Console

# NOTE(review): sys, OrderedDict, tqdm, Measure, psnr, imresize and pd are not
# used by the code in this file; they are kept in case other modules import
# them from here.


def fiFindByWildcard(wildcard):
    """Return the paths matching ``wildcard`` in natural sort order.

    ``recursive=True`` lets ``**`` in the pattern match nested directories.
    """
    return natsort.natsorted(glob.glob(wildcard, recursive=True))


def load_model(conf_path):
    """Build the model described by an option file and load its weights.

    Args:
        conf_path: path to the option file passed to ``option.parse``.

    Returns:
        ``(model, opt)``: the model, with ``model.netG`` loaded from
        ``opt['model_path']``, and the option dict after
        ``option.dict_to_nonedict`` (presumably so missing keys read as None
        -- confirm against ``option_``).
    """
    opt = option.parse(conf_path, is_train=False)
    # NOTE(review): gpu_ids is cleared here; main() moves netG to CUDA itself.
    opt['gpu_ids'] = None
    opt = option.dict_to_nonedict(opt)
    model = create_model(opt)

    model_path = opt_get(opt, ['model_path'], None)
    model.load_network(load_path=model_path, network=model.netG)
    return model, opt


def predict(model, lr):
    """Run ``model`` on one H x W x C image array and return its output visual.

    The image is converted with ``t`` (scaled by 1/255). Returns the
    ``'rlt'`` visual, falling back to ``'NORMAL'`` when ``'rlt'`` is absent.
    """
    model.feed_data({"LQ": t(lr)}, need_GT=False)
    model.test()
    visuals = model.get_current_visuals(need_GT=False)
    return visuals.get('rlt', visuals.get('NORMAL'))


def t(array):
    """Convert an H x W x C array to a 1 x C x H x W float32 tensor scaled by 1/255."""
    return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255


def rgb(t):
    """Convert a C x H x W (or 1 x C x H x W) tensor to an H x W x C uint8 array.

    Values are clipped to [0, 1] and then scaled by 255. For a 4-D input only
    the first batch element is used.
    """
    array = (t[0] if len(t.shape) == 4 else t).detach().cpu().numpy()
    return (np.clip(array.transpose([1, 2, 0]), 0, 1) * 255).astype(np.uint8)


def imread(path):
    """Read an image file and return it with channels in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise surface
            as an opaque TypeError on the indexing below.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; reorder to RGB.
    return img[:, :, [2, 1, 0]]


def imwrite(path, img):
    """Write an RGB image to ``path``, creating the parent directory if needed.

    Returns:
        The boolean returned by ``cv2.imwrite`` (False when writing failed).
    """
    parent = os.path.dirname(path)
    # A bare filename has no directory component; os.makedirs('') would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    # OpenCV expects BGR channel order.
    return cv2.imwrite(path, img[:, :, [2, 1, 0]])


def imCropCenter(img, size):
    """Return the central ``size`` x ``size`` crop of ``img``.

    The crop is clamped to the image bounds, so it is smaller when the image is.
    """
    h, w = img.shape[:2]
    h_start = max(h // 2 - size // 2, 0)
    h_end = min(h_start + size, h)
    w_start = max(w // 2 - size // 2, 0)
    w_end = min(w_start + size, w)
    return img[h_start:h_end, w_start:w_end]


def impad(img, top=0, bottom=0, left=0, right=0, color=255):
    """Pad an H x W x C image by reflection on the spatial axes.

    ``color`` is accepted for signature compatibility but is ignored: the
    padding mode is always ``'reflect'``, never a constant fill.
    """
    return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect')


def hiseq_color_cv2_img(img):
    """Histogram-equalize each channel of an 8-bit multi-channel image independently."""
    channels = cv2.split(img)
    return cv2.merge(tuple(cv2.equalizeHist(ch) for ch in channels))


def auto_padding(img, times=16):
    """Reflect-pad ``img`` so its height and width are multiples of ``times``.

    The padding on each axis is split between its two sides; an odd extra
    pixel goes to the bottom/right. Dimensions that are already multiples of
    ``times`` are not padded.

    Args:
        img: numpy image with shape H x W (x C).
        times: required divisor of the padded height and width.

    Returns:
        ``(padded_img, [top, bottom, left, right])``. The padding amounts are
        needed to crop the model output back to the original size.
    """
    h, w = img.shape[:2]
    # (-n) % times is the smallest pad that makes n divisible by times. It is
    # 0 when n already is, whereas (times - n % times) would add a whole extra
    # block of `times` pixels.
    pad_h = (-h) % times
    pad_w = (-w) % times
    h1, w1 = pad_h // 2, pad_w // 2
    h2, w2 = pad_h - h1, pad_w - w1
    img = cv2.copyMakeBorder(img, h1, h2, w1, w2, cv2.BORDER_REFLECT)
    return img, [h1, h2, w1, w2]


def main(path: Optional[str] = None, save: bool = False):
    """Enhance a single image with the model configured by ``--opt``.

    Options are read from ``sys.argv``:

    * ``--opt``: option file (default ``./models/llflow/LOL_smallNet.yml``).
    * ``-n``/``--name``: sub-directory under ``results/<conf>/`` (default ``unpaired``).
    * ``--input``: image to enhance, used only when ``path`` is None.

    Unrecognized arguments are ignored, so ``main`` can be called from a host
    program whose own command line this parser does not know.

    Args:
        path: image to enhance; falls back to ``--input`` when None.
        save: if true, also write the result to
            ``<this file's dir>/../results/<conf>/<name>/<basename of path>``.

    Returns:
        The enhanced image as an H x W x 3 uint8 array with the same shape as
        the input, channels reversed (RGB -> BGR) so it can be passed directly
        to ``cv2.imwrite``.

    Raises:
        SystemExit: (via ``parser.error``) if no image path is given.
        FileNotFoundError: if the image cannot be read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--opt", default="./models/llflow/LOL_smallNet.yml")
    parser.add_argument("-n", "--name", default="unpaired")
    parser.add_argument("--input", default=None,
                        help="image to enhance when main() is called without `path`")
    # parse_known_args: when main(path) is called from another program,
    # sys.argv holds that program's arguments and parse_args() would exit.
    args, _ = parser.parse_known_args()
    if path is None:
        path = args.input
    if path is None:
        parser.error("no input image given: pass --input or call main(path)")

    Console().log(f"🛠️\tLoading model from {args.opt}")
    conf_path = args.opt
    conf = os.path.basename(conf_path).replace('.yml', '')
    model, opt = load_model(conf_path)
    model.netG = model.netG.cuda()

    this_dir = os.path.dirname(os.path.realpath(__file__))
    test_dir = os.path.join(this_dir, '..', 'results', conf, args.name)
    print(f"Out dir: {test_dir}")

    lr = imread(path)
    raw_shape = lr.shape
    lr, padding_params = auto_padding(lr)
    his = hiseq_color_cv2_img(lr)
    if opt.get("histeq_as_input", False):
        lr = his

    lr_t = t(lr)
    if opt["datasets"]["train"].get("log_low", False):
        lr_t = torch.log(torch.clamp(lr_t + 1e-3, min=1e-3))
    if opt.get("concat_histeq", False):
        # Stack the equalized image as extra input channels.
        lr_t = torch.cat([lr_t, t(his)], dim=1)

    # heat is passed as None; the config's 'heat' entry is not read here.
    with torch.cuda.amp.autocast():
        sr_t = model.get_sr(lq=lr_t.cuda(), heat=None)

    # Crop the padding back off. Use `size - pad` rather than `-pad` so that a
    # zero pad keeps the full extent instead of producing an empty slice.
    top, bottom, left, right = padding_params
    sr = rgb(torch.clamp(sr_t, 0, 1)[:, :, top:sr_t.shape[2] - bottom, left:sr_t.shape[3] - right])
    assert raw_shape == sr.shape

    if save:
        imwrite(os.path.join(test_dir, os.path.basename(path)), sr)

    return sr[:, :, [2, 1, 0]]


def format_measurements(meas):
    """Render a mapping of metric names to values as ``"k1: v1, k2: v2"``.

    Float values are shown with two decimal places; other values use ``str``.
    """
    return ", ".join(
        f"{k}: {v:0.2f}" if isinstance(v, float) else f"{k}: {v}"
        for k, v in meas.items()
    )


if __name__ == "__main__":
    # Command-line use: read the image from --input and save the result under
    # results/<conf>/<name>/.
    main(save=True)