Spaces:
Runtime error
Runtime error
# ************************************************************************* | |
# Copyright (2023) Bytedance Inc. | |
# | |
# Copyright (2023) DragDiffusion Authors | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ************************************************************************* | |
import os | |
import gradio as gr | |
from utils.ui_utils import get_points, undo_points | |
from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag | |
from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen | |
# Side length (pixels) of the square widgets that display/edit images.
LENGTH = 480

# Folder scanned for user-provided checkpoints; each sub-directory is offered
# as an extra choice in the model / VAE dropdowns of both tabs.
local_models_dir = 'local_pretrained_models'


def _list_local_models(models_dir):
    """Return the sub-directories of *models_dir* as dropdown choices.

    Each entry is ``os.path.join(models_dir, name)``. Entries are sorted so
    the dropdown order does not depend on ``os.listdir``'s arbitrary order.
    If *models_dir* is missing (or is not a directory) an empty list is
    returned instead of letting ``os.listdir`` raise at import time, which
    would keep the whole demo from starting.
    """
    if not os.path.isdir(models_dir):
        return []
    return sorted(
        os.path.join(models_dir, name)
        for name in os.listdir(models_dir)
        if os.path.isdir(os.path.join(models_dir, name))
    )


# Computed once; both tabs reuse the same list of local model directories.
local_models_choice = _list_local_models(local_models_dir)

with gr.Blocks() as demo:
    # ------------------------------------------------------------------
    # Layout definition
    # ------------------------------------------------------------------
    with gr.Row():
        gr.Markdown("""
        # Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435)
        """)

    # UI components for editing real images
    with gr.Tab(label="Editing Real Image"):
        mask = gr.State(value=None)  # store mask
        selected_points = gr.State([])  # store points
        original_image = gr.State(value=None)  # store original input image
        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                    show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
                train_lora_button = gr.Button("Train LoRA")
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
                input_image = gr.Image(type="numpy", label="Click Points",
                    show_label=True, height=LENGTH, width=LENGTH)  # for points clicking
                undo_button = gr.Button("Undo point")
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
                output_image = gr.Image(type="numpy", label="Editing Results",
                    show_label=True, height=LENGTH, width=LENGTH)
                with gr.Row():
                    run_button = gr.Button("Run")
                    clear_all_button = gr.Button("Clear All")

        # general parameters
        with gr.Row():
            prompt = gr.Textbox(label="Prompt")
            lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
            lora_status_bar = gr.Textbox(label="display LoRA training status")

        # algorithm specific parameters
        with gr.Tab("Drag Config"):
            with gr.Row():
                n_pix_step = gr.Number(
                    value=40,
                    label="number of pixel steps",
                    info="Number of gradient descent (motion supervision) steps on latent.",
                    precision=0)
                lam = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
                inversion_strength = gr.Slider(0, 1.0,
                    value=0.75,
                    label="inversion strength",
                    info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
                latent_lr = gr.Number(value=0.01, label="latent lr")
                # Hidden knobs: not user-editable, but still passed to run_drag.
                start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
                start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)

        with gr.Tab("Base Model Config"):
            with gr.Row():
                model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
                    label="Diffusion Model Path",
                    choices=[
                        "runwayml/stable-diffusion-v1-5",
                    ] + local_models_choice
                )
                vae_path = gr.Dropdown(value="default",
                    label="VAE choice",
                    choices=["default",
                    "stabilityai/sd-vae-ft-mse"] + local_models_choice
                )

        with gr.Tab("LoRA Parameters"):
            with gr.Row():
                lora_step = gr.Number(value=200, label="LoRA training steps", precision=0)
                lora_lr = gr.Number(value=0.0002, label="LoRA learning rate")
                lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)

    # UI components for editing generated images
    with gr.Tab(label="Editing Generated Image"):
        mask_gen = gr.State(value=None)  # store mask
        selected_points_gen = gr.State([])  # store points
        original_image_gen = gr.State(value=None)  # store the diffusion-generated image
        intermediate_latents_gen = gr.State(value=None)  # store the intermediate diffusion latent during generation
        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                    show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
                gen_img_button = gr.Button("Generate Image")
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
                input_image_gen = gr.Image(type="numpy", label="Click Points",
                    show_label=True, height=LENGTH, width=LENGTH)  # for points clicking
                undo_button_gen = gr.Button("Undo point")
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
                output_image_gen = gr.Image(type="numpy", label="Editing Results",
                    show_label=True, height=LENGTH, width=LENGTH)
                with gr.Row():
                    run_button_gen = gr.Button("Run")
                    clear_all_button_gen = gr.Button("Clear All")

        # general parameters
        with gr.Row():
            pos_prompt_gen = gr.Textbox(label="Positive Prompt")
            neg_prompt_gen = gr.Textbox(label="Negative Prompt")

        with gr.Tab("Generation Config"):
            with gr.Row():
                model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
                    label="Diffusion Model Path",
                    choices=[
                        "runwayml/stable-diffusion-v1-5",
                        "gsdf/Counterfeit-V2.5",
                        "emilianJR/majicMIX_realistic",
                        "SG161222/Realistic_Vision_V2.0",
                        "stablediffusionapi/landscapesupermix",
                        "huangzhe0803/ArchitectureRealMix",
                        "stablediffusionapi/interiordesignsuperm"
                    ] + local_models_choice
                )
                vae_path_gen = gr.Dropdown(value="default",
                    label="VAE choice",
                    choices=["default",
                    "stabilityai/sd-vae-ft-mse"] + local_models_choice
                )
                lora_path_gen = gr.Textbox(value="", label="LoRA path")
                gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0)
                height = gr.Number(value=512, label="Height", precision=0)
                width = gr.Number(value=512, label="Width", precision=0)
                guidance_scale = gr.Number(value=7.5, label="CFG Scale")
                scheduler_name_gen = gr.Dropdown(
                    value="DDIM",
                    label="Scheduler",
                    choices=[
                        "DDIM",
                        "DPM++2M",
                        "DPM++2M_karras"
                    ]
                )
                n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0)

        with gr.Tab(label="Drag Config"):
            with gr.Row():
                n_pix_step_gen = gr.Number(
                    value=40,
                    label="Number of Pixel Steps",
                    info="Number of gradient descent (motion supervision) steps on latent.",
                    precision=0)
                lam_gen = gr.Number(value=0.1, label="lam", info="regularization strength on unmasked areas")
                inversion_strength_gen = gr.Slider(0, 1.0,
                    value=0.75,
                    label="Inversion Strength",
                    info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
                latent_lr_gen = gr.Number(value=0.01, label="latent lr")
                # Hidden knobs: not user-editable, but still passed to run_drag_gen.
                start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False)
                start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False)
            # GIF / tracking-point output options.
            # NOTE(review): these controls are laid out in the generated-image
            # tab, yet create_gif_checkbox and gif_interval are ALSO inputs to
            # the real-image run_drag below (which gets neither the
            # tracking-point flag nor gif_fps). Confirm this is intended.
            with gr.Row():
                create_gif_checkbox = gr.Checkbox(label="create_GIF", value=False)
                create_tracking_point_checkbox = gr.Checkbox(label="create_tracking_point", value=False)
                gif_interval = gr.Number(value=10, label="interval_GIF", precision=0, info="The interval of the GIF, i.e. the number of steps between each frame of the GIF.")
                gif_fps = gr.Number(value=1, label="fps_GIF", precision=0, info="The fps of the GIF, i.e. the number of frames per second of the GIF.")

    # ------------------------------------------------------------------
    # Event definition
    # ------------------------------------------------------------------
    # Hidden constant input carrying LENGTH (integer, precision=0). A single
    # component is shared by every handler below that takes it.
    length_input = gr.Number(value=LENGTH, visible=False, precision=0)

    # --- events for dragging a user-input real image ---
    # Editing the mask canvas refreshes the stored image, points, click view
    # and mask from store_img's outputs.
    canvas.edit(
        store_img,
        [canvas],
        [original_image, selected_points, input_image, mask]
    )
    # Clicking on the image hands the click view and point list to get_points.
    input_image.select(
        get_points,
        [input_image, selected_points],
        [input_image],
    )
    undo_button.click(
        undo_points,
        [original_image, mask],
        [input_image, selected_points]
    )
    train_lora_button.click(
        train_lora_interface,
        [original_image,
        prompt,
        model_path,
        vae_path,
        lora_path,
        lora_step,
        lora_lr,
        lora_rank],
        [lora_status_bar]
    )
    # Input order must match run_drag's positional parameters.
    run_button.click(
        run_drag,
        [original_image,
        input_image,
        mask,
        prompt,
        selected_points,
        inversion_strength,
        lam,
        latent_lr,
        n_pix_step,
        model_path,
        vae_path,
        lora_path,
        start_step,
        start_layer,
        create_gif_checkbox,
        gif_interval,
        ],
        [output_image]
    )
    clear_all_button.click(
        clear_all,
        [length_input],
        [canvas,
        input_image,
        output_image,
        selected_points,
        original_image,
        mask]
    )

    # --- events for dragging a diffusion-generated image ---
    canvas_gen.edit(
        store_img_gen,
        [canvas_gen],
        [original_image_gen, selected_points_gen, input_image_gen, mask_gen]
    )
    input_image_gen.select(
        get_points,
        [input_image_gen, selected_points_gen],
        [input_image_gen],
    )
    gen_img_button.click(
        gen_img,
        [
            length_input,
            height,
            width,
            n_inference_step_gen,
            scheduler_name_gen,
            gen_seed,
            guidance_scale,
            pos_prompt_gen,
            neg_prompt_gen,
            model_path_gen,
            vae_path_gen,
            lora_path_gen,
        ],
        [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen]
    )
    undo_button_gen.click(
        undo_points,
        [original_image_gen, mask_gen],
        [input_image_gen, selected_points_gen]
    )
    # Input order must match run_drag_gen's positional parameters.
    run_button_gen.click(
        run_drag_gen,
        [
            n_inference_step_gen,
            scheduler_name_gen,
            original_image_gen,  # the original image generated by the diffusion model
            input_image_gen,  # image with clicking, masking, etc.
            intermediate_latents_gen,
            guidance_scale,
            mask_gen,
            pos_prompt_gen,
            neg_prompt_gen,
            selected_points_gen,
            inversion_strength_gen,
            lam_gen,
            latent_lr_gen,
            n_pix_step_gen,
            model_path_gen,
            vae_path_gen,
            lora_path_gen,
            start_step_gen,
            start_layer_gen,
            create_gif_checkbox,
            create_tracking_point_checkbox,
            gif_interval,
            gif_fps
        ],
        [output_image_gen]
    )
    clear_all_button_gen.click(
        clear_all_gen,
        [length_input],
        [canvas_gen,
        input_image_gen,
        output_image_gen,
        selected_points_gen,
        original_image_gen,
        mask_gen,
        intermediate_latents_gen,
        ]
    )
if __name__ == "__main__":
    # Start the server only when this file is executed as a script, so that
    # importing the module (e.g. to reuse `demo`) has no launch side effect.
    # Queue and launch flags are unchanged.
    demo.queue().launch(share=True, debug=True)