# anime_diffusion / app.py
# (Hugging Face Spaces page residue from extraction:
#  pawlo2013 — "init commit" 0b2b0ab · raw · history · blame · 4.11 kB)
import gradio as gr
from PIL import Image
import numpy as np
from torchvision import transforms
from load_model import sample
import torch
import random
# Pick the fastest available backend: Apple Metal first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model's native resolution; all inputs are resized to this square.
image_size = 128
# UI state defaults: stretch-output toggle and "generation in progress" flag.
upscale = False
clicked = False

# PIL -> float tensor in [0, 1] -> normalized to [-1, 1] (model's input range).
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),
    ]
)
def make_scribbles(sketch, scribbles):
    """Merge colour scribbles onto a sketch and normalize both to [-1, 1].

    Pixels of the scribble image that are exactly the neutral grey canvas
    background (127/255) are treated as "unpainted" and replaced by the
    corresponding sketch pixel; everything else keeps the scribble colour.

    Args:
        sketch: PIL image of the line sketch.
        scribbles: PIL image of scribbles on a grey background.

    Returns:
        Tuple of (sketch_tensor, merged_tensor), both 3x`image_size`x`image_size`
        tensors on `device`, scaled to [-1, 1].
    """
    # Force 3-channel RGB so the tensors line up with the (3, H, W) grey
    # reference below — consistent with process_images' non-scribble path.
    # Without this, RGBA/greyscale uploads cause a channel mismatch.
    sketch = sketch.convert("RGB")
    scribbles = scribbles.convert("RGB")
    sketch = transforms.Resize((image_size, image_size))(sketch)
    scribbles = transforms.Resize((image_size, image_size))(scribbles)
    # 0.49803922 == 127/255: the grey canvas value ToTensor produces.
    grey_tensor = torch.tensor(0.49803922, device=device)
    grey_tensor = grey_tensor.expand(3, image_size, image_size)
    sketch = transforms.ToTensor()(sketch).to(device)
    scribbles = transforms.ToTensor()(scribbles).to(device)
    # Wherever the scribble canvas is still background grey, fall back to
    # the sketch pixel; painted pixels win.
    scribble_where_grey_mask = torch.eq(scribbles, grey_tensor)
    merged = torch.where(scribble_where_grey_mask, sketch, scribbles)
    # Scale [0, 1] -> [-1, 1] to match the model's training normalization.
    return transforms.Lambda(lambda t: (t * 2) - 1)(sketch), transforms.Lambda(
        lambda t: (t * 2) - 1
    )(merged)
def process_images(sketch, scribbles, sampling_steps, is_scribbles, seed_nr, upscale):
    """Gradio callback: run diffusion sampling on a sketch (+ scribbles).

    Args:
        sketch: PIL sketch image.
        scribbles: PIL scribble image (merged with sketch, or on grey canvas).
        sampling_steps: number of DDPM sampling steps.
        is_scribbles: True when `scribbles` sits on a grey background and
            must be merged via make_scribbles.
        seed_nr: RNG seed for reproducible sampling.
        upscale: when True, stretch the output back to the sketch's
            original resolution.

    Returns:
        The sampled image (PIL), optionally resized to the input size.
    """
    global clicked
    clicked = True
    try:
        w, h = sketch.size
        if is_scribbles:
            sketch, scribbles = make_scribbles(sketch, scribbles)
        else:
            sketch = transform(sketch.convert("RGB"))
            scribbles = transform(scribbles.convert("RGB"))
        output = sample(sketch, scribbles, sampling_steps, seed_nr)
        if upscale:
            # Resize expects (height, width); restore original dimensions.
            output = transforms.Resize((h, w))(output)
        return output
    finally:
        # Always clear the busy flag — the original left it stuck at True
        # if sampling raised, permanently marking the app as busy.
        clicked = False
# ---- Gradio UI layout: two input images -> one output, plus controls. ----
theme = gr.themes.Monochrome()
with gr.Blocks(theme=theme) as demo:
    # Page title.
    with gr.Row():
        gr.Markdown(
            "<h1 style='text-align: center; font-size: 30px;'>Image Inpainting with Conditional Diffusion by MedicAI</h1>"
        )
    # Input columns (sketch + scribbles) and the output column.
    with gr.Row():
        with gr.Column():
            sketch_input = gr.Image(type="pil", label="Sketch", height=500)
        with gr.Column():
            scribbles_input = gr.Image(type="pil", label="Scribbles", height=500)
            info = gr.Markdown(
                "<p style='text-align: center; font-size: 12px;'>"
                "By default the scribbles are assumed to be merged with the sketch, if they appear on a grey background check the box below. "
                "</p>"
            )
            # Checked => scribbles are on a grey canvas and need merging
            # (routed to make_scribbles inside process_images).
            is_scribbles = gr.Checkbox(label="Is Scribbles", value=False)
        with gr.Column():
            output = gr.Image(type="pil", label="Output")
            upscale_info = gr.Markdown(
                "<p style='text-align: center; font-size: 12px;'>"
                f"If you want to stretch the downloadable output, check the box below, the default output of neural networks is {image_size}x{image_size} "
                "</p>"
            )
            # Checked => resize the model output back to the input resolution.
            upscale_button = gr.Checkbox(label="Stretch", value=False)
    # Sampling controls: RNG seed and DDPM step count.
    with gr.Row():
        with gr.Column():
            seed_slider = gr.Number(
                label="Random Seed 🎲",
                value=random.randint(
                    1,
                    1000,
                ),
            )
        with gr.Column():
            sampling_slider = gr.Slider(
                minimum=1,
                maximum=250,
                step=1,
                label="DDPM Sampling Steps 🔄",
                value=50,
            )
    with gr.Row():
        # NOTE(review): `interactive=not clicked` is evaluated once at build
        # time, so the flag toggled inside process_images never actually
        # disables this button — presumably intended as a busy guard.
        generate_button = gr.Button(value="Generate", interactive=not clicked)

    # Wire the button to the sampling callback; inputs are passed positionally
    # in process_images' parameter order.
    generate_button.click(
        process_images,
        inputs=[
            sketch_input,
            scribbles_input,
            sampling_slider,
            is_scribbles,
            seed_slider,
            upscale_button,
        ],
        outputs=output,
        show_progress=True,
    )

# Single-threaded server on port 3000 (one generation at a time).
demo.launch(server_port=3000, max_threads=1)