Spaces:
Runtime error
Runtime error
from argparse import Namespace | |
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from huggingface_hub import hf_hub_download | |
from PIL import Image | |
from models.psp import pSp | |
# Run on the GPU when one is present, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocessing applied to the canvas drawing before it reaches the network:
# resize to 256x256, then convert the PIL image to a tensor.
# NOTE: the global's name is misspelled ("transfroms") but must stay as-is,
# because sketch_recognition looks it up by this name.
transfroms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)
def tensor2im(var):
    """Turn a single image tensor from the network into a PIL image.

    The tensor is detached and moved to the CPU, and its axes are reordered
    from (C, H, W) to (H, W, C). Values are then mapped from [-1, 1] to
    [0, 1], with anything outside that range clipped, scaled to [0, 255]
    and truncated to uint8.

    Args:
        var: a 3-D image tensor laid out channel-first.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 array.
    """
    # Swapping axes 0<->2 and then 0<->1 turns (C, H, W) into (H, W, C).
    hwc = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
    # Shift [-1, 1] into [0, 1]; clip leaves NaN untouched and saturates +/-inf.
    unit = ((hwc + 1) / 2).clip(0, 1)
    pixels = (unit * 255).astype('uint8')
    return Image.fromarray(pixels)
def sketch_recognition(img):
    """Translate a hand-drawn sketch into a face image with the pSp network.

    Args:
        img: the canvas drawing as a numpy array (the Gradio input in this
            file is configured with ``type="numpy"``), or ``None`` when
            nothing was submitted.

    Returns:
        The PIL image produced by ``tensor2im`` from the first item of the
        network's output batch, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing was drawn/submitted: return an empty output instead of
        # crashing inside Image.fromarray.
        return None
    # The interface asks for a single-channel ("L") drawing; converting here
    # keeps the network input single-channel even if an RGB(A) array arrives.
    sketch = Image.fromarray(img).convert("L")
    from_im = transfroms(sketch)
    with torch.no_grad():
        # Add a batch dimension of 1 and move the input to the model's device.
        res = net(from_im.unsqueeze(0).to(device))
    return tensor2im(res[0])
# Fetch the pretrained sketch-to-face pSp checkpoint from the Hugging Face Hub.
path = hf_hub_download("huggan/TediGAN_sketch", "psp_celebs_sketch_to_face.pt")

# Only the saved training options are read from this checkpoint, so load it on
# the CPU. This keeps a second full copy of the weights from being parked on
# the GPU for the lifetime of the module-level `ckpt`.
# NOTE(review): presumably pSp reloads the weights itself from
# opts.checkpoint_path — confirm in models/psp.py.
ckpt = torch.load(path, map_location="cpu")
opts = ckpt["opts"]
opts["checkpoint_path"] = path
# The stored options may record the device used for training (e.g. "cuda:0").
# Override it so the model is built on the device actually available here;
# otherwise a CPU-only host can fail while pSp is being constructed.
# NOTE(review): assumes pSp reads opts.device (e.g. when placing latent_avg,
# which net.to() would not move) — confirm in models/psp.py.
opts["device"] = device
opts = Namespace(**opts)

net = pSp(opts)
net.eval()
net.to(device)
# Web UI: the user draws on a 256x256 single-channel canvas, and the generated
# face is shown as the output image.
iface = gr.Interface(
    fn=sketch_recognition,
    inputs=gr.inputs.Image(
        shape=(256, 256),
        image_mode="L",
        invert_colors=False,
        source="canvas",
        tool="editor",
        type="numpy",
        label=None,
        optional=False,
    ),
    outputs="image",
)
# Launch exactly once. Chaining `.launch()` onto the constructor would bind
# `iface` to launch()'s return value rather than the Interface, and a second
# `iface.launch()` would then be called on the wrong object.
iface.launch()