Spaces:
Running
Running
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
from torchvision import transforms | |
from load_model import sample | |
import torch | |
import random | |
# Pick the torch device: CUDA when available, otherwise CPU; Apple's MPS
# backend takes precedence over either whenever it is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"

# Side length (in pixels) of the square images handed to the model.
image_size = 128

# Module-level defaults. ``clicked`` is toggled by ``process_images`` while a
# generation request is running.
upscale = False
clicked = False
# Preprocessing applied to images that are used as-is (no scribble merging):
# resize to the square model resolution, convert to a tensor, then rescale
# every value with t -> 2 * t - 1.
transform = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda tensor: tensor * 2 - 1),
    ]
)
def make_scribbles(sketch, scribbles):
    """Overlay scribbles drawn on a grey canvas onto the sketch.

    Both images are converted to RGB, resized to ``image_size`` x
    ``image_size`` and turned into tensors on ``device``. Every scribble pixel
    whose three channels all equal the grey background value is replaced by
    the sketch pixel at the same position; every other scribble pixel is kept.

    Args:
        sketch: PIL image containing the sketch.
        scribbles: PIL image containing the scribbles on a grey background.

    Returns:
        A ``(sketch, merged)`` tuple of tensors, each rescaled with
        ``t -> 2 * t - 1``.
    """
    resize = transforms.Resize((image_size, image_size))
    to_tensor = transforms.ToTensor()

    # Force three channels, as the non-scribble path in ``process_images``
    # does; an RGBA or greyscale upload would otherwise produce a tensor whose
    # channel count does not line up with the RGB comparison below.
    sketch = to_tensor(resize(sketch.convert("RGB"))).to(device)
    scribbles = to_tensor(resize(scribbles.convert("RGB"))).to(device)

    # 0.49803922 is 127 / 255, i.e. the 8-bit grey level 127 after ToTensor
    # scaling. The exact comparison relies on the background being untouched.
    grey = torch.tensor(0.49803922, device=device)

    # A pixel is background only when *all* of its channels are grey. Reducing
    # over the channel axis keeps a coloured scribble that happens to share a
    # single channel value with the background from being partially replaced.
    is_background = torch.eq(scribbles, grey).all(dim=0, keepdim=True)
    merged = torch.where(is_background, sketch, scribbles)

    return sketch * 2 - 1, merged * 2 - 1
def process_images(sketch, scribbles, sampling_steps, is_scribbles, seed_nr, upscale):
    """Run the model on a sketch/scribbles pair and return its output.

    Args:
        sketch: PIL image with the sketch.
        scribbles: PIL image with the scribbles, either already merged with
            the sketch or drawn on a grey background.
        sampling_steps: Number of sampling steps, forwarded to ``sample``.
        is_scribbles: When true, the scribbles are on a grey background and
            are merged with the sketch through ``make_scribbles``; otherwise
            both images only go through ``transform``.
        seed_nr: Random seed, forwarded to ``sample``.
        upscale: When true, the result is resized back to the height and
            width of the input sketch.

    Returns:
        Whatever ``sample`` returns, optionally resized.

    Raises:
        gr.Error: If either image is missing.
    """
    global clicked

    # Without this guard a missing upload fails on ``sketch.size`` with an
    # AttributeError that tells the user nothing.
    if sketch is None or scribbles is None:
        raise gr.Error("Please provide both a sketch and a scribbles image.")

    clicked = True
    try:
        w, h = sketch.size
        if is_scribbles:
            sketch, scribbles = make_scribbles(sketch, scribbles)
        else:
            # NOTE(review): unlike make_scribbles, this path does not move the
            # tensors to ``device``; presumably ``sample`` handles placement.
            sketch = transform(sketch.convert("RGB"))
            scribbles = transform(scribbles.convert("RGB"))

        output = sample(sketch, scribbles, sampling_steps, seed_nr)
        if upscale:
            # PIL reports (width, height); Resize expects (height, width).
            output = transforms.Resize((h, w))(output)
        return output
    finally:
        # Reset the flag even when preprocessing or sampling raises, so a
        # failed request cannot leave it stuck at True.
        clicked = False
# Build the Gradio interface: two image inputs, one image output, the seed and
# sampling-step controls, and a button wired to ``process_images``.
theme = gr.themes.Monochrome()
with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        gr.Markdown(
            "<h1 style='text-align: center; font-size: 30px;'>Image Inpainting with Conditional Diffusion by MedicAI</h1>"
        )
    with gr.Row():
        with gr.Column():
            sketch_input = gr.Image(type="pil", label="Sketch", height=500)
        with gr.Column():
            scribbles_input = gr.Image(type="pil", label="Scribbles", height=500)
            info = gr.Markdown(
                "<p style='text-align: center; font-size: 12px;'>"
                "By default the scribbles are assumed to be merged with the sketch, if they appear on a grey background check the box below. "
                "</p>"
            )
            is_scribbles = gr.Checkbox(label="Is Scribbles", value=False)
        with gr.Column():
            output = gr.Image(type="pil", label="Output")
            upscale_info = gr.Markdown(
                "<p style='text-align: center; font-size: 12px;'>"
                f"If you want to stretch the downloadable output, check the box below, the default output of neural networks is {image_size}x{image_size} "
                "</p>"
            )
            upscale_button = gr.Checkbox(label="Stretch", value=False)
    with gr.Row():
        with gr.Column():
            # precision=0 makes the component deliver an int, so the seed
            # handed to ``process_images`` is never a fractional float.
            seed_slider = gr.Number(
                label="Random Seed 🎲",
                value=random.randint(
                    1,
                    1000,
                ),
                precision=0,
            )
        with gr.Column():
            # NOTE(review): the trailing "π" looks like a mis-decoded emoji;
            # the original glyph cannot be recovered from this file, so the
            # label is left unchanged.
            sampling_slider = gr.Slider(
                minimum=1,
                maximum=250,
                step=1,
                label="DDPM Sampling Steps π",
                value=50,
            )
    with gr.Row():
        # ``clicked`` is read once, while the UI is being built, so the button
        # always starts out interactive.
        generate_button = gr.Button(value="Generate", interactive=not clicked)
        generate_button.click(
            process_images,
            inputs=[
                sketch_input,
                scribbles_input,
                sampling_slider,
                is_scribbles,
                seed_slider,
                upscale_button,
            ],
            outputs=output,
            show_progress=True,
        )

demo.launch(server_port=3000, max_threads=1)