import math
import random

import spaces
import gradio as gr
import torch
from PIL import Image, ImageOps
from diffusers import StableDiffusionInstructPix2PixPipeline
help_text = """ | |
Considerations while editing: | |
1. The Base-Model, trained on the PIPE dataset, is great for some tasks, while the Finetuned-MB-Model, fine-tuned on the MagicBrush dataset, can be better for others. Please try both until you are satisfied. | |
2. Image CFG controls how much to deviate from the original image. Higher values keep the image more consistent with the original. | |
3. Text CFG does the opposite. Higher values lead to more changes in the image. | |
4. Using different seed values will produce varied outputs. | |
5. Increasing the number of steps can enhance the results. | |
6. The Stable Diffusion autoencoder struggles with small faces in images. | |
""" | |
article = """ | |
<p style='text-align: center'> | |
<a href='https://arxiv.org/abs/2404.18212' target='_blank'> | |
Paint by Inpaint: Learning to Add Image Objects by Removing Them First</a> | |
</p> | |
""" | |
description = """ | |
<p style="text-align: center;"> | |
Gradio demo for <strong>Paint by Inpaint: Learning to Add Image Objects by Removing Them First</strong>, visit our <a href='https://rotsteinnoam.github.io/Paint-by-Inpaint/' target='_blank'>project page</a>. <br> | |
The demo involves two models: one trained for image object addition using the <a href='https://huggingface.co/datasets/paint-by-inpaint/PIPE' target='_blank'>PIPE dataset</a>, and another model further fine-tuned on the MagicBrush dataset. | |
</p> | |
""" | |
# Base models
object_addition_base_model_id = "paint-by-inpaint/add-base"
# general_editing_base_model_id = "paint-by-inpaint/general-base"

# MagicBrush finetuned models
object_addition_finetuned_model_id = "paint-by-inpaint/add-finetuned-mb"
# general_editing_finetuned_model_id = "paint-by-inpaint/general-finetuned-mb"

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU for speed and memory; full precision on CPU, where float16
# is poorly supported by the pipeline.
dtype = torch.float16 if "cuda" in device else torch.float32
def load_model(model_id):
    return StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)


pipe_object_addition_base = load_model(object_addition_base_model_id)
pipe_object_addition_finetuned = load_model(object_addition_finetuned_model_id)
# pipe_general_editing_base = load_model(general_editing_base_model_id)
# pipe_general_editing_finetuned = load_model(general_editing_finetuned_model_id)
# On ZeroGPU Spaces a GPU is attached only while a @spaces.GPU-decorated function
# runs, hence the decorator on the inference entry point.
@spaces.GPU
def generate(
    input_image: Image.Image,
    instruction: str,
    model_choice: int,
    steps: int,
    randomize_seed: bool,
    seed: int,
    text_cfg_scale: float,
    image_cfg_scale: float,
    task_type: str,
):
    seed = random.randint(0, 100000) if randomize_seed else seed
    if task_type == "object_addition":
        pipe = pipe_object_addition_base if model_choice == 0 else pipe_object_addition_finetuned
    else:
        # Only reachable if the commented-out "General Editing" tab (and the matching
        # pipelines above) are re-enabled.
        pipe = pipe_general_editing_base if model_choice == 0 else pipe_general_editing_finetuned
    # Rescale so the longer side is close to 512 and both sides snap to multiples
    # of 64, matching the resolution granularity the Stable Diffusion UNet expects.
    width, height = input_image.size
    factor = 512 / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width = int((width * factor) // 64) * 64
    height = int((height * factor) // 64) * 64
    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
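    # Worked example of the resize above (illustrative): a 768x512 input gives
    # factor = 512/768, then factor = ceil(512 * (512/768) / 64) * 64 / 512 = 384/512
    # = 0.75, so ImageOps.fit yields 576x384 -- both multiples of 64, longer side near 512.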
    if instruction == "":
        # Return values must match the click handler's outputs:
        # [seed, text_cfg_scale, image_cfg_scale, edited_image].
        return [seed, text_cfg_scale, image_cfg_scale, input_image]
    generator = torch.manual_seed(seed)
    edited_image = pipe(
        instruction, image=input_image,
        guidance_scale=text_cfg_scale, image_guidance_scale=image_cfg_scale,
        num_inference_steps=steps, generator=generator,
    ).images[0]
    return [seed, text_cfg_scale, image_cfg_scale, edited_image]
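
# For a quick smoke test without the UI, generate() can be called directly -- a
# minimal sketch (assumes examples/coffee.jpg exists, as in the demo examples below):
#
#   img = Image.open("examples/coffee.jpg").convert("RGB")
#   _, _, _, out = generate(img, "Add steamed milk", 0, 50, False, 2024, 7.5, 1.5, "object_addition")
#   out.save("edited.jpg")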
def reset():
    # Defaults must line up with the reset button's outputs:
    # [steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale, edited_image].
    return [50, "Randomize Seed", 2024, 7.5, 1.5, None]
# elem_id="compact-box" renders as an HTML id, so the CSS selectors use #compact-box.
with gr.Blocks(css="#compact-box .gr-row { margin-bottom: 5px; } #compact-box .gr-number input, #compact-box .gr-radio label { padding: 5px 10px; }") as demo:
    gr.HTML("""
    <div style="text-align: center;">
        <h1 style="font-weight: 900; margin-bottom: 7px;">Paint by Inpaint</h1>
        {description}
    </div>
    """.format(description=description))
    # with gr.Tabs():
    #     with gr.Tab("Object Addition"):
    if 1:  # keeps the tab body's indentation while gr.Tabs is disabled
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                instruction = gr.Textbox(lines=1, label="Addition Instruction", interactive=True, max_lines=1, placeholder="Enter addition instruction here")
                model_choice = gr.Radio(
                    ["Base-Model", "Finetuned-MB-Model"],
                    value="Base-Model",
                    type="index",
                    label="Choose Model",
                    interactive=True,
                )
                with gr.Group(elem_id="compact-box"):
                    with gr.Row():
                        with gr.Column():
                            steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
                        with gr.Column():
                            with gr.Row():
                                seed = gr.Number(value=2024, precision=0, label="Seed", interactive=True)
                                randomize_seed = gr.Radio(
                                    ["Fix Seed", "Randomize Seed"],
                                    value="Randomize Seed",
                                    type="index",
                                    show_label=False,
                                    interactive=True,
                                )
                    with gr.Row():
                        text_cfg_scale = gr.Number(value=7.5, label="Text CFG", interactive=True)
                        image_cfg_scale = gr.Number(value=1.5, label="Image CFG", interactive=True)
                with gr.Row():
                    generate_button = gr.Button("Generate")
                    reset_button = gr.Button("Reset")
            with gr.Column():
                edited_image = gr.Image(label="Edited Image", type="pil", interactive=False)
        generate_button.click(
            fn=lambda *args: generate(*args, task_type="object_addition"),
            inputs=[
                input_image,
                instruction,
                model_choice,
                steps,
                randomize_seed,
                seed,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        reset_button.click(
            fn=reset,
            inputs=[],
            outputs=[steps, randomize_seed, seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
# with gr.Tab("General Editing"): | |
# with gr.Row(): | |
# with gr.Column(): | |
# input_image_editing = gr.Image(label="Input Image", type="pil", interactive=True) | |
# instruction_editing = gr.Textbox(lines=1, label="Editing Instruction", interactive=True, max_lines=1, placeholder="Enter editing instruction here") | |
# model_choice_editing = gr.Radio( | |
# ["Base-Model", "Finetuned-MB-Model"], | |
# value="Base-Model", | |
# type="index", | |
# label="Choose Model", | |
# interactive=True, | |
# ) | |
# with gr.Group(elem_id="compact-box"): | |
# with gr.Row(): | |
# steps_editing = gr.Number(value=50, precision=0, label="Steps", interactive=True) | |
# with gr.Column(): | |
# with gr.Row(): | |
# seed_editing = gr.Number(value=2024, precision=0, label="Seed", interactive=True) | |
# randomize_seed_editing = gr.Radio( | |
# ["Fix Seed", "Randomize Seed"], | |
# value="Randomize Seed", | |
# type="index", | |
# show_label=False, | |
# interactive=True, | |
# ) | |
# with gr.Row(): | |
# text_cfg_scale_editing = gr.Number(value=7.5, label="Text CFG", interactive=True) | |
# image_cfg_scale_editing = gr.Number(value=1.5, label="Image CFG", interactive=True) | |
# with gr.Row(): | |
# generate_button_editing = gr.Button("Generate") | |
# reset_button_editing = gr.Button("Reset") | |
# with gr.Column(): | |
# edited_image_editing = gr.Image(label="Edited Image", type="pil", interactive=False) | |
# generate_button_editing.click( | |
# fn=lambda *args: generate(*args, task_type="general_editing"), | |
# inputs=[ | |
# input_image_editing, | |
# instruction_editing, | |
# model_choice_editing, | |
# steps_editing, | |
# randomize_seed_editing, | |
# seed_editing, | |
# text_cfg_scale_editing, | |
# image_cfg_scale_editing, | |
# ], | |
# outputs=[seed_editing, text_cfg_scale_editing, image_cfg_scale_editing, edited_image_editing], | |
# ) | |
# reset_button_editing.click( | |
# fn=reset, | |
# inputs=[], | |
# outputs=[steps_editing, randomize_seed_editing, seed_editing, text_cfg_scale_editing, image_cfg_scale_editing, edited_image_editing], | |
# ) | |
    gr.Markdown(help_text)

    examples = [
        ["examples/messi.jpeg", "Add a royal silver crown"],
        ["examples/coffee.jpg", "Add steamed milk"],
    ]
    # With no fn given, clicking an example only fills the inputs; the user still
    # presses Generate to run the model.
    gr.Examples(
        examples=examples,
        inputs=[input_image, instruction],
        outputs=[edited_image],
    )
    gr.HTML(article)

demo.queue()
demo.launch(share=False, max_threads=1)