import gradio as gr
from PIL import Image
import numpy as np
from torchvision import transforms
from load_model import sample
import torch
import random

device = "cuda" if torch.cuda.is_available() else "cpu"
# Prefer Apple's MPS backend whenever it is available.
device = "mps" if torch.backends.mps.is_available() else device

image_size = 128
# Module-level default; process_images receives its own `upscale` argument
# from the "Stretch" checkbox and does not read this value.
upscale = False
# Set to True while process_images runs, reset to False when it finishes.
clicked = False

# 8-bit grey (127, 127, 127) that marks "no scribble" when the scribbles are
# drawn on a grey background, expressed in ToTensor's [0, 1] range.
SCRIBBLE_BACKGROUND_GREY = 127 / 255


def _to_model_range(t):
    """Rescale a tensor from ToTensor's [0, 1] range to [-1, 1]."""
    return (t * 2) - 1


# Preprocessing used when sketch and scribbles are already merged:
# resize to image_size x image_size, convert to a tensor, scale to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Lambda(_to_model_range),
    ]
)


def make_scribbles(sketch, scribbles):
    """Overlay scribbles drawn on a grey background onto the sketch.

    Both PIL images are converted to RGB (as the non-scribbles path in
    process_images does) and resized to image_size x image_size. Every pixel
    of the scribbles image that is background grey (127, 127, 127) in all
    three channels is replaced by the corresponding sketch pixel; every other
    pixel keeps its scribble colour.

    Args:
        sketch: PIL image of the sketch.
        scribbles: PIL image of the scribbles on a grey background.

    Returns:
        A ``(sketch, merged)`` pair of tensors of shape
        ``(3, image_size, image_size)`` on ``device``, scaled to [-1, 1].
    """
    resize = transforms.Resize((image_size, image_size))
    to_tensor = transforms.ToTensor()

    sketch_t = to_tensor(resize(sketch.convert("RGB"))).to(device)
    scribbles_t = to_tensor(resize(scribbles.convert("RGB"))).to(device)

    # A pixel is background only if ALL channels are grey. A per-channel test
    # would splice single sketch channels into coloured scribble pixels.
    # The tolerance of half an 8-bit step matches exactly the value 127 while
    # not depending on how 127/255 rounds in float32.
    is_grey = (scribbles_t - SCRIBBLE_BACKGROUND_GREY).abs() <= 0.5 / 255
    is_background = is_grey.all(dim=0, keepdim=True)  # (1, H, W), broadcasts

    merged = torch.where(is_background, sketch_t, scribbles_t)
    return _to_model_range(sketch_t), _to_model_range(merged)


def process_images(sketch, scribbles, sampling_steps, is_scribbles, seed_nr, upscale):
    """Gradio callback: preprocess the inputs and run the sampler.

    Args:
        sketch: PIL image of the sketch (``None`` if the input is empty).
        scribbles: PIL image of the scribbles (``None`` if the input is empty).
        sampling_steps: Number of sampling steps, forwarded to ``sample``.
        is_scribbles: If true, the scribbles are on a grey background and are
            merged onto the sketch via make_scribbles; otherwise both images
            go through ``transform`` unchanged.
        seed_nr: Random seed, forwarded to ``sample``.
            NOTE(review): gr.Number may deliver a float here; confirm that
            ``sample`` accepts it (or give the Number ``precision=0``).
        upscale: If true, resize the sampler output back to the sketch's
            original (width, height).

    Returns:
        Whatever ``sample`` returns (presumably a PIL image, since the output
        component is ``gr.Image(type="pil")``), optionally resized.

    Raises:
        gr.Error: If either image input is empty.
    """
    global clicked
    clicked = True
    try:
        if sketch is None or scribbles is None:
            raise gr.Error("Please provide both a sketch and a scribbles image.")

        width, height = sketch.size
        if is_scribbles:
            sketch_t, scribbles_t = make_scribbles(sketch, scribbles)
        else:
            sketch_t = transform(sketch.convert("RGB"))
            scribbles_t = transform(scribbles.convert("RGB"))

        output = sample(sketch_t, scribbles_t, sampling_steps, seed_nr)
        if upscale:
            # Resize takes (height, width).
            output = transforms.Resize((height, width))(output)
        return output
    finally:
        # Reset even if sampling fails, so the flag never stays stuck at True.
        clicked = False


theme = gr.themes.Monochrome()

with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        with gr.Column():
            # NOTE(review): sketch_input / scribbles_input are wired into
            # generate_button.click below but were not defined in the source
            # this was rebuilt from. process_images needs PIL images (it reads
            # .size and calls .convert), hence type="pil". Confirm the labels,
            # layout and any canvas/sketch options against the original app.
            sketch_input = gr.Image(type="pil", label="Sketch")
            scribbles_input = gr.Image(type="pil", label="Scribbles")
            gr.Markdown(
                "\n"
                "By default the scribbles are assumed to be merged with the sketch, if they appear on a grey background check the box below. "
                "\n"
            )
            is_scribbles = gr.Checkbox(label="Is Scribbles", value=False)
        with gr.Column():
            output = gr.Image(type="pil", label="Output")
            upscale_info = gr.Markdown(
                ""
                f"If you want to stretch the downloadable output, check the box below, the default output of neural networks is {image_size}x{image_size} "
                "\n"
            )
            upscale_button = gr.Checkbox(label="Stretch", value=False)
    with gr.Row():
        with gr.Column():
            seed_slider = gr.Number(
                label="Random Seed 🎲",
                value=random.randint(1, 1000),
            )
        with gr.Column():
            sampling_slider = gr.Slider(
                minimum=1,
                maximum=250,
                step=1,
                label="DDPM Sampling Steps 🔄",
                value=50,
            )
    with gr.Row():
        # `clicked` is read only here, once, when the layout is built; flipping
        # it later in process_images does not change the button's state.
        generate_button = gr.Button(value="Generate", interactive=not clicked)

    generate_button.click(
        process_images,
        inputs=[
            sketch_input,
            scribbles_input,
            sampling_slider,
            is_scribbles,
            seed_slider,
            upscale_button,
        ],
        outputs=output,
        show_progress=True,
    )

demo.launch(server_port=3000, max_threads=1)