'''Gradio demo for DarkIR low-light image restoration.

Loads a pretrained DarkIR checkpoint and serves a single-image
input -> restored-image output interface.
'''
import gradio as gr
from PIL import Image
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F

from archs import DarkIR

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Auxiliary converters between PIL images and tensors.
pil_to_tensor = transforms.ToTensor()
tensor_to_pil = transforms.ToPILImage()

network = 'DarkIR'

PATH_MODEL = './darkir_v2real+lsrw.pt'

model = DarkIR(img_channel=3,
               width=32,
               middle_blk_num_enc=2,
               middle_blk_num_dec=2,
               enc_blk_nums=[1, 2, 3],
               dec_blk_nums=[3, 1, 1],
               dilations=[1, 4, 9],
               extra_depth_wise=True)

# NOTE(review): torch.load unpickles the file. That is acceptable only because the
# checkpoint is a local file shipped with the app; never point PATH_MODEL at
# untrusted or user-supplied data.
checkpoints = torch.load(PATH_MODEL, map_location=device)
model.load_state_dict(checkpoints['params'])
model = model.to(device)
# Switch to inference mode so any train-only layers (e.g. dropout or
# normalisation, if DarkIR contains them) behave deterministically in the demo.
model.eval()


def pad_tensor(tensor, multiple=8):
    '''Zero-pad a 4-D (N, C, H, W) tensor on the bottom/right edges.

    After padding, H and W are both multiples of ``multiple``. Presumably the
    network needs spatial sizes divisible by 8 -- TODO confirm against DarkIR's
    down-sampling factor.

    Args:
        tensor: 4-D tensor whose last two dims are height and width.
        multiple: target divisor for the spatial dimensions.

    Returns:
        The padded tensor. Padding values are 0.
    '''
    _, _, H, W = tensor.shape
    pad_h = (multiple - H % multiple) % multiple
    pad_w = (multiple - W % multiple) % multiple
    # F.pad pads the last dims first: (left, right, top, bottom).
    return F.pad(tensor, (0, pad_w, 0, pad_h), value=0)


def process_img(image):
    '''Restore a low-light PIL image with DarkIR and return it as a PIL image.

    Args:
        image: PIL image supplied by the Gradio input component, or None when
            the user submits without an image.

    Returns:
        The restored PIL image with the same size as the input, or None if no
        image was provided.
    '''
    if image is None:
        return None
    # The model is built with img_channel=3, so guarantee a 3-channel input even
    # for grayscale / RGBA / palette images.
    tensor = pil_to_tensor(image.convert('RGB')).unsqueeze(0).to(device)
    _, _, H, W = tensor.shape
    tensor = pad_tensor(tensor)

    with torch.no_grad():
        output = model(tensor, side_loss=False)

    output = torch.clamp(output, 0., 1.)
    # Crop away the padding added by pad_tensor, drop the batch dim, move to CPU.
    output = output[:, :, :H, :W].squeeze(0).cpu()
    return tensor_to_pil(output)


title = "DarkIR ✏️🖼️ 🤗"
description = ''' ## [ DarkIR: Robust Low-Light Image Restoration](https://github.com/cidautai/DarkIR)

[Daniel Feijoo](https://github.com/danifei)

Fundación Cidaut

> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some Low-Light degradations.**
'''

examples = [['examples/0010.png'],
            ['examples/r13073518t_low.png'],
            ['examples/low00733_low.png'],
            ['examples/0087.png']]

# Show images at their natural size instead of letting Gradio shrink them.
css = """
    .image-frame img, .image-container img {
        width: auto;
        height: auto;
        max-width: none;
    }
"""

demo = gr.Interface(
    fn=process_img,
    inputs=[gr.Image(type='pil', label='input')],
    outputs=[gr.Image(type='pil', label='output')],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

if __name__ == '__main__':
    demo.launch()