"""Gradio demo: turn a hand-drawn sketch into a face with a pSp encoder.

The checkpoint ``psp_celebs_sketch_to_face.pt`` is fetched from the
``huggan/TediGAN_sketch`` Hugging Face repo with ``hf_hub_download``, a pSp
network is built from it, and the Gradio interface is launched once.
"""
from argparse import Namespace

import gradio as gr
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

from models.psp import pSp

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Preprocessing applied to every canvas drawing before it reaches the network:
# resize to 256x256, then convert the PIL image to a float tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
# Backward-compatible alias for the original (misspelled) module-level name.
transfroms = transform


def tensor2im(var):
    """Convert one channel-first image tensor with values in [-1, 1] to PIL.

    Args:
        var: a 3-D tensor laid out as (C, H, W) -- the layout is implied by
            the permute below; presumably RGB for this model (confirm against
            the pSp decoder output).

    Returns:
        A ``PIL.Image`` whose pixel values are rescaled to [0, 255], with
        out-of-range values clipped.
    """
    # (C, H, W) -> (H, W, C), the layout PIL expects.
    arr = var.detach().cpu().permute(1, 2, 0)
    # Map [-1, 1] to [0, 1], clip anything outside that range, scale to 8 bits.
    arr = ((arr + 1) / 2).clamp(0, 1) * 255
    return Image.fromarray(arr.numpy().astype('uint8'))


def sketch_recognition(img):
    """Run the pSp sketch-to-face model on a single canvas drawing.

    Args:
        img: numpy array coming from the Gradio canvas input (configured with
            ``image_mode="L"``, so presumably a 2-D grayscale array), or
            ``None`` if nothing was submitted.

    Returns:
        The generated face as a ``PIL.Image``, or ``None`` when ``img`` is None.
    """
    if img is None:
        return None
    from_im = transform(Image.fromarray(img))
    with torch.no_grad():
        # unsqueeze(0) adds a batch dimension of one; the output is indexed
        # with [0] below, presumably selecting that single batch item.
        res = net(from_im.unsqueeze(0).to(device))
    return tensor2im(res[0])


path = hf_hub_download('huggan/TediGAN_sketch', 'psp_celebs_sketch_to_face.pt')
ckpt = torch.load(path, map_location=device)
opts = ckpt['opts']
# Point the options at the downloaded file; presumably pSp reads
# checkpoint_path to load its weights -- confirm in models/psp.py.
opts['checkpoint_path'] = path
opts = Namespace(**opts)
net = pSp(opts)
net.eval()
net.to(device)

# NOTE(review): gr.inputs.Image and its shape/source/tool/optional kwargs are
# the legacy Gradio input API (removed in Gradio 4) -- confirm the pinned
# gradio version before upgrading.
iface = gr.Interface(
    fn=sketch_recognition,
    inputs=gr.inputs.Image(
        shape=(256, 256),
        image_mode="L",
        invert_colors=False,
        source="canvas",
        tool="editor",
        type="numpy",
        label=None,
        optional=False,
    ),
    outputs="image",
)
# Keep `iface` bound to the Interface itself (not launch()'s return value)
# and launch it exactly once.
iface.launch()