import argparse
import time, os, sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
os.system('python scripts/download_models.py')
import gradio as gr
from PIL import Image
import numpy as np
import torch
from typing import List, Literal, Dict, Optional
from draw_utils import draw_points_on_image, draw_mask_on_image
import cv2
from models.streamdiffusion.wrapper import StreamDiffusionWrapper
from models.animatediff.pipelines import I2VPipeline
from omegaconf import OmegaConf
from models.draggan.viz.renderer import Renderer
from models.draggan.gan_inv.lpips.util import PerceptualLoss
import models.draggan.dnnlib as dnnlib
from models.draggan.gan_inv.inversion import PTI
import imageio
import torchvision
from einops import rearrange
# =========================== Model Implementation Start ===================================
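# Arrange a batch of videos (b, c, t, h, w) into a per-frame grid and save the result
# as an animated file with imageio. Pixel values are expected to already be in [0, 255].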
def save_videos_grid_255(videos: torch.Tensor, path: str, n_rows=6, fps=8):
videos = rearrange(videos, "b c t h w -> t b c h w")
outputs = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=n_rows)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
x = x.numpy().astype(np.uint8)
outputs.append(x)
os.makedirs(os.path.dirname(path), exist_ok=True)
imageio.mimsave(path, outputs, fps=fps)
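# The UI stores drag points as (x, y); swap each pair into (row, col) order before
# handing it to the DragGAN renderer.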
def reverse_point_pairs(points):
new_points = []
for p in points:
new_points.append([p[1], p[0]])
return new_points
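# Overlay the current drag points (and, optionally, the editable-region mask) on the image
# and append an opaque alpha channel for display in the Gradio image widget.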
def render_view_image(img, drag_markers, show_mask=False):
img = draw_points_on_image(img, drag_markers['points'])
if show_mask:
img = draw_mask_on_image(img, drag_markers['mask'])
img = np.array(img).astype(np.uint8)
img = np.concatenate([
img,
255 * np.ones((img.shape[0], img.shape[1], 1), dtype=img.dtype)
], axis=2)
return Image.fromarray(img)
def update_state_image(state):
state['generated_image_show'] = render_view_image(
state['generated_image'],
state['drag_markers'][0],
state['is_show_mask'],
)
return state['generated_image_show']
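# Wraps the three backbone models used by the demo:
#   * StreamDiffusionWrapper - low-latency image-to-image refinement of the user's sketch
#   * I2VPipeline (PIA)      - image-to-video generation
#   * DragGAN Renderer       - point-based dragging of the generated image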
class GeneratePipeline:
def __init__(
self,
i2i_body_ckpt: str = "checkpoints/diffusion_body/kohaku-v2.1",
# i2i_body_ckpt: str = "checkpoints/diffusion_body/stable-diffusion-v1-5",
i2i_lora_dict: Optional[Dict[str, float]] = {'checkpoints/i2i/lora/lcm-lora-sdv1-5.safetensors': 1.0},
prompt: str = "",
negative_prompt: str = "low quality, bad quality, blurry, low resolution",
frame_buffer_size: int = 1,
width: int = 512,
height: int = 512,
acceleration: Literal["none", "xformers", "tensorrt"] = "xformers",
use_denoising_batch: bool = True,
seed: int = 2,
cfg_type: Literal["none", "full", "self", "initialize"] = "self",
guidance_scale: float = 1.4,
delta: float = 0.5,
do_add_noise: bool = False,
enable_similar_image_filter: bool = True,
similar_image_filter_threshold: float = 0.99,
similar_image_filter_max_skip_frame: float = 10,
):
super(GeneratePipeline, self).__init__()
if not torch.cuda.is_available():
acceleration = "none"  # fall back to no acceleration when CUDA is unavailable
self.img2img_model = None
self.img2video_model = None
self.img2video_generator = None
self.sim_ranges = None
# set parameters
self.i2i_body_ckpt = i2i_body_ckpt
self.i2i_lora_dict = i2i_lora_dict
self.prompt = prompt
self.negative_prompt = negative_prompt
self.frame_buffer_size = frame_buffer_size
self.width = width
self.height = height
self.acceleration = acceleration
self.use_denoising_batch = use_denoising_batch
self.seed = seed
self.cfg_type = cfg_type
self.guidance_scale = guidance_scale
self.delta = delta
self.do_add_noise = do_add_noise
self.enable_similar_image_filter = enable_similar_image_filter
self.similar_image_filter_threshold = similar_image_filter_threshold
self.similar_image_filter_max_skip_frame = similar_image_filter_max_skip_frame
self.i2v_config = OmegaConf.load('demo/configs/i2v_config.yaml')
self.i2v_body_ckpt = self.i2v_config.pretrained_model_path
self.i2v_unet_path = self.i2v_config.generate.model_path
self.i2v_dreambooth_ckpt = self.i2v_config.generate.db_path
self.lora_alpha = 0
assert self.frame_buffer_size == 1
def init_model(self):
# StreamDiffusion
self.img2img_model = StreamDiffusionWrapper(
model_id_or_path=self.i2i_body_ckpt,
lora_dict=self.i2i_lora_dict,
t_index_list=[32, 45],
frame_buffer_size=self.frame_buffer_size,
width=self.width,
height=self.height,
warmup=10,
acceleration=self.acceleration,
do_add_noise=self.do_add_noise,
enable_similar_image_filter=self.enable_similar_image_filter,
similar_image_filter_threshold=self.similar_image_filter_threshold,
similar_image_filter_max_skip_frame=self.similar_image_filter_max_skip_frame,
mode="img2img",
use_denoising_batch=self.use_denoising_batch,
cfg_type=self.cfg_type,
seed=self.seed,
use_lcm_lora=False,
)
self.img2img_model.prepare(
prompt=self.prompt,
negative_prompt=self.negative_prompt,
num_inference_steps=50,
guidance_scale=self.guidance_scale,
delta=self.delta,
)
# PIA
self.img2video_model = I2VPipeline.build_pipeline(
self.i2v_config,
self.i2v_body_ckpt,
self.i2v_unet_path,
self.i2v_dreambooth_ckpt,
None, # lora path
self.lora_alpha,
)
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
self.img2video_generator = torch.Generator(device=device)
self.img2video_generator.manual_seed(self.i2v_config.generate.global_seed)
self.sim_ranges = self.i2v_config.validation_data.mask_sim_range
# Drag GAN
self.drag_model = Renderer(disable_timing=True)
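# Run one streaming img2img step. When `text` is given, the stream is re-prepared with
# the new positive/negative prompts before generating.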
def generate_image(self, image, text, start_time=None):
if text is not None:
pos_prompt, neg_prompt = text
self.img2img_model.prepare(
prompt=pos_prompt,
negative_prompt=neg_prompt,
num_inference_steps=50,
guidance_scale=self.guidance_scale,
delta=self.delta,
)
sampled_inputs = [image]
input_batch = torch.cat(sampled_inputs)
output_images = self.img2img_model.stream(
input_batch.to(device=self.img2img_model.device, dtype=self.img2img_model.dtype)
)
# if start_time is not None:
# print('Generate Done: {}'.format(time.perf_counter() - start_time))
output_images = output_images.cpu()
# if start_time is not None:
# print('Move Done: {}'.format(time.perf_counter() - start_time))
return output_images
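# Generate a short clip from a single image with the PIA pipeline; returns the `.videos`
# tensor (batch, channels, frames, height, width) with values in [0, 1].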
def generate_video(self, image, text, height=None, width=None):
pos_prompt, neg_prompt = text
sim_range = self.sim_ranges[0]
print(f"using sim_range : {sim_range}")
self.i2v_config.validation_data.mask_sim_range = sim_range
sample = self.img2video_model(
image = image,
prompt = pos_prompt,
generator = self.img2video_generator,
video_length = self.i2v_config.generate.video_length,
height = height if height is not None else self.i2v_config.generate.sample_height,
width = width if width is not None else self.i2v_config.generate.sample_width,
negative_prompt = neg_prompt,
mask_sim_template_idx = self.i2v_config.validation_data.mask_sim_range,
**self.i2v_config.validation_data,
).videos
return sample
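# Invert a user image into StyleGAN latent space with PTI so DragGAN can edit it;
# returns the reconstructed image and a default all-ones mask.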
def prepare_drag_model(
self,
custom_image: Image,
latent_space = 'w+',
trunc_psi = 0.7,
trunc_cutoff = None,
seed = 0,
lr = 0.001,
generator_params = dnnlib.EasyDict(),
pretrained_weight = 'stylegan2_lions_512_pytorch',
):
self.drag_model.init_network(
generator_params, # res
pretrained_weight, # pkl
seed, # w0_seed,
None, # w_load
latent_space == 'w+', # w_plus
'const',
trunc_psi, # trunc_psi,
trunc_cutoff, # trunc_cutoff,
None, # input_transform
lr # lr,
)
if torch.cuda.is_available():
percept = PerceptualLoss(model="net-lin", net="vgg", use_gpu=True)
else:
percept = PerceptualLoss(model="net-lin", net="vgg", use_gpu=False)
pti = PTI(self.drag_model.G, percept, max_pti_step=400)
inversed_img, w_pivot = pti.train(custom_image, latent_space == 'w+')
inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
inversed_img = inversed_img.cpu().numpy()
inversed_img = Image.fromarray(inversed_img)
mask = np.ones((inversed_img.height, inversed_img.width),
dtype=np.uint8)
generator_params.image = inversed_img
generator_params.w = w_pivot.detach().cpu().numpy()
self.drag_model.set_latent(w_pivot, trunc_psi, trunc_cutoff)
del percept
del pti
print('inverse end')
return generator_params, mask
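# Perform a single DragGAN optimization step: collect valid start/target point pairs,
# swap them into (row, col) order, invert the UI mask, and return the updated image
# produced by the renderer.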
def drag_image(
self,
points,
mask,
motion_lambda = 20,
r1_in_pixels = 3,
r2_in_pixels = 12,
trunc_psi = 0.7,
draw_interval = 1,
generator_params = dnnlib.EasyDict(),
):
p_in_pixels = []
t_in_pixels = []
valid_points = []
# Transform the points into torch tensors
for key_point, point in points.items():
try:
p_start = point.get("start_temp", point["start"])
p_end = point["target"]
if p_start is None or p_end is None:
continue
except KeyError:
continue
p_in_pixels.append(p_start)
t_in_pixels.append(p_end)
valid_points.append(key_point)
mask = torch.tensor(mask).float()
drag_mask = 1 - mask
# reverse points order
p_to_opt = reverse_point_pairs(p_in_pixels)
t_to_opt = reverse_point_pairs(t_in_pixels)
step_idx = 0
self.drag_model._render_drag_impl(
generator_params,
p_to_opt, # point
t_to_opt, # target
drag_mask, # mask,
motion_lambda, # lambda_mask
reg = 0,
feature_idx = 5, # NOTE: do not support change for now
r1 = r1_in_pixels, # r1
r2 = r2_in_pixels, # r2
# random_seed = 0,
# noise_mode = 'const',
trunc_psi = trunc_psi,
# force_fp32 = False,
# layer_name = None,
# sel_channels = 3,
# base_channel = 0,
# img_scale_db = 0,
# img_normalize = False,
# untransform = False,
is_drag=True,
to_pil=True
)
points_upd = points
if step_idx % draw_interval == 0:
for key_point, p_i, t_i in zip(valid_points, p_to_opt,
t_to_opt):
points_upd[key_point]["start_temp"] = [
p_i[1],
p_i[0],
]
points_upd[key_point]["target"] = [
t_i[1],
t_i[0],
]
start_temp = points_upd[key_point][
"start_temp"]
image_result = generator_params['image']
return image_result
# ============================= Model Implementation ENd ===================================
parser = argparse.ArgumentParser()
parser.add_argument('--share', action='store_true', default=True)
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
parser.add_argument(
"--listen",
action="store_true",
help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests",
)
args = parser.parse_args()
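# gr.Image subclass (Gradio 3.x sketch tool) that synthesizes a blank placeholder mask
# when the incoming payload is a bare base64 image rather than an {'image', 'mask'} dict,
# so the sketch preprocessing never receives a missing mask.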
class CustomImageMask(gr.Image):
is_template = True
def __init__(
self,
source='upload',
tool='sketch',
elem_id="image_upload",
label='Generated Image',
type="pil",
mask_opacity=0.5,
brush_color='#FFFFFF',
height=400,
interactive=True,
**kwargs
):
super(CustomImageMask, self).__init__(
source=source,
tool=tool,
elem_id=elem_id,
label=label,
type=type,
mask_opacity=mask_opacity,
brush_color=brush_color,
height=height,
interactive=interactive,
**kwargs
)
def preprocess(self, x):
if x is None:
return x
if self.tool == 'sketch' and self.source in ['upload', 'webcam'] and not isinstance(x, dict):
decode_image = gr.processing_utils.decode_base64_to_image(x)
width, height = decode_image.size
mask = np.ones((height, width, 4), dtype=np.uint8)
mask[..., -1] = 255
mask = self.postprocess(mask)
x = {'image': x, 'mask': mask}
return super().preprocess(x)
draggan_ckpts = os.listdir('checkpoints/drag')
draggan_ckpts.sort()
generate_pipeline = GeneratePipeline()
generate_pipeline.init_model()
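# Build the Gradio UI. All per-session data (prompts, drag points, masks, the latest
# generated image, and DragGAN parameters) is kept in a single gr.State dictionary.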
with gr.Blocks() as demo:
global_state = gr.State(
{
'is_image_generation': True,
'is_image_text_prompt_up-to-date': True,
'is_show_mask': False,
'is_dragging': False,
'generated_image': None,
'generated_image_show': None,
'drag_markers': [
{
'points': {},
'mask': None
}
],
'generator_params': dnnlib.EasyDict(),
'default_image_text_prompts': ('', 'low quality, bad quality, blurry, low resolution'),
'default_video_text_prompts': ('', 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'),
'image_text_prompts': ('', 'low quality, bad quality, blurry, low resolution'),
'video_text_prompts': ('', 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'),
'params': {
'seed': 0,
'motion_lambda': 20,
'r1_in_pixels': 3,
'r2_in_pixels': 12,
'magnitude_direction_in_pixels': 1.0,
'latent_space': 'w+',
'trunc_psi': 0.7,
'trunc_cutoff': None,
'lr': 0.001,
},
'device': None, # device,
'draw_interval': 1,
'points': {},
'curr_point': None,
'curr_type_point': 'start',
'editing_state': 'add_points',
'pretrained_weight': draggan_ckpts[0],
'video_preview_resolution': '512 x 512',
'viewer_height': 300,
'viewer_width': 300
}
)
with gr.Column():
with gr.Row():
with gr.Column(scale=8, min_width=10):
with gr.Tab('Image Text Prompts'):
image_pos_text_prompt_editor = gr.Textbox(placeholder='Positive Prompts', label='Positive', min_width=10)
image_neg_text_prompt_editor = gr.Textbox(placeholder='Negative Prompts', label='Negative', min_width=10)
with gr.Tab('Video Text Prompts'):
video_pos_text_prompt_editor = gr.Textbox(placeholder='Positive Prompts', label='Positive', min_width=10)
video_neg_text_prompt_editor = gr.Textbox(placeholder='Negative Prompts', label='Negative', min_width=10)
with gr.Tab('Drag Image'):
with gr.Row():
with gr.Column(scale=1, min_width=10):
drag_mode_on_button = gr.Button('Drag Mode On', size='sm', min_width=10)
drag_mode_off_button = gr.Button('Drag Mode Off', size='sm', min_width=10)
drag_checkpoint_dropdown = gr.Dropdown(choices=draggan_ckpts, value=draggan_ckpts[0], label='checkpoint', min_width=10)
with gr.Column(scale=1, min_width=10):
with gr.Row():
drag_start_button = gr.Button('start', size='sm', min_width=10)
drag_stop_button = gr.Button('stop', size='sm', min_width=10)
with gr.Row():
add_point_button = gr.Button('add point', size='sm', min_width=10)
reset_point_button = gr.Button('reset point', size='sm', min_width=10)
with gr.Row():
steps_number = gr.Number(0, label='steps', interactive=False)
with gr.Column(scale=1, min_width=10):
with gr.Row():
draw_mask_button = gr.Button('draw mask', size='sm', min_width=10)
reset_mask_button = gr.Button('reset mask', size='sm', min_width=10)
with gr.Row():
show_mask_checkbox = gr.Checkbox(value=False, label='show mask', min_width=10, interactive=True)
with gr.Row():
motion_lambda_number = gr.Number(20, label='Motion Lambda', minimum=1, maximum=100, step=1, interactive=True)
with gr.Tab('More'):
with gr.Row():
with gr.Column(scale=2, min_width=10):
video_preview_resolution_dropdown = gr.Dropdown(choices=['256 x 256', '512 x 512'], value='512 x 512', label='Video Preview Resolution', min_width=10)
sample_image_dropdown = gr.Dropdown(choices=['samples/canvas.jpg'] + ['samples/sample{:>02d}.jpg'.format(i) for i in range(1, 8)], value=None, label='Choose A Sample Image', min_width=10)
with gr.Column(scale=1, min_width=10):
confirm_text_button = gr.Button('Confirm Text', size='sm', min_width=10)
generate_video_button = gr.Button('Generate Video', size='sm', min_width=10)
clear_video_button = gr.Button('Clear Video', size='sm', min_width=10)
with gr.Row():
captured_image_viewer = gr.Image(source='upload', tool='color-sketch', type='pil', label='Image Drawer', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=True, shape=(global_state.value['viewer_width'], global_state.value['viewer_height'])) #
generated_image_viewer = CustomImageMask(source='upload', tool='sketch', elem_id="image_upload", label='Generated Image', type="pil", mask_opacity=0.5, brush_color='#FFFFFF', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=True)
generated_video_viewer = gr.Video(source='upload', label='Generated Video', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=False)
gr.Markdown(
"""
## Quick Start
1. Select a sample image in the `More` tab.
2. Edit the sample image by drawing in the leftmost image viewer.
3. Click `Generate Video` and enjoy the result!
## Note
Due to limitations of the Gradio implementation, image-to-image generation may show noticeable latency even after the model has finished generating.
For a smoother experience, we recommend our local demo on [GitHub](https://github.com/invictus717/InteractiveVideo).
## Advanced Usage
1. **Try different text prompts.** Enter positive or negative prompts for image / video generation, and
click `Confirm Text` to apply them.
2. **Drag images.** Go to the `Drag Image` tab, choose a suitable checkpoint, and click `Drag Mode On`.
Preparation may take a minute. Add point pairs and, if needed, a mask, then click `start` to
begin dragging. Once the result looks good, click the `stop` button.
3. **Adjust the video resolution** in the `More` tab.
4. **Draw from scratch** by choosing `canvas.jpg` in the `More` tab and enjoy yourself!
"""
)
# ========================= Main Function Start =============================
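# Every edit on the drawing canvas triggers a streaming img2img refresh of the
# generated-image viewer.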
def on_captured_image_viewer_update(state, image):
if image is None:
return state, gr.Image.update(None)
if state['is_image_text_prompt_up-to-date']:
text_prompts = None
else:
text_prompts = state['image_text_prompts']
state['is_image_text_prompt_up-to-date'] = True
# start_time = time.perf_counter()
input_image = np.array(image).astype(np.float32)
input_image = (input_image / 255 - 0.5) * 2
input_image = torch.tensor(input_image).permute([2, 0, 1])
noisy_image = torch.randn_like(input_image)
# print('preprocess done: {}'.format(time.perf_counter() - start_time))
output_image = generate_pipeline.generate_image(
input_image,
text_prompts,
# start_time,
)[0]
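# The stream pipeline appears to carry one frame of latency (two denoising steps are
# queued), so a second call fed with noise is used to flush the result for the real
# input; see the TODO below.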
output_image = generate_pipeline.generate_image(
noisy_image,
None,
# start_time,
)[0] # TODO: is there more elegant way?
output_image = output_image.permute([1, 2, 0])
output_image = (output_image / 2 + 0.5).clamp(0, 1) * 255
output_image = output_image.to(torch.uint8).cpu().numpy()
output_image = Image.fromarray(output_image)
# print('postprocess done: {}'.format(time.perf_counter() - start_time))
# output_image = image
state['generated_image'] = output_image
output_image = update_state_image(state)
# print('draw done: {}'.format(time.perf_counter() - start_time))
return state, gr.Image.update(output_image, interactive=False)
captured_image_viewer.change(
fn=on_captured_image_viewer_update,
inputs=[global_state, captured_image_viewer],
outputs=[global_state, generated_image_viewer]
)
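# Sketch strokes on the generated-image viewer are converted into a binary drag mask
# (1 where the user painted).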
def on_generated_image_viewer_edit(state, data_dict):
mask = data_dict['mask']
state['drag_markers'][0]['mask'] = np.array(mask)[:, :, 0] // 255
image = update_state_image(state)
return state, image
generated_image_viewer.edit(
fn=on_generated_image_viewer_edit,
inputs=[global_state, generated_image_viewer],
outputs=[global_state, generated_image_viewer]
)
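# Feed the current generated image and the video prompts to PIA, then encode the
# preview clip to results/gradio_temp.mp4 with OpenCV.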
def on_generate_video_click(state):
input_image = np.array(state['generated_image'])
text_prompts = state['video_text_prompts']
video_preview_resolution = state['video_preview_resolution'].split('x')
height = int(video_preview_resolution[0].strip(' '))
width = int(video_preview_resolution[1].strip(' '))
output_video = generate_pipeline.generate_video(
input_image,
text_prompts,
height = height,
width = width
)[0]
output_video = output_video.clamp(0, 1) * 255
output_video = output_video.to(torch.uint8)
# 3 T H W
print('[video generation done]')
fps = 5 # frames per second
video_size = (width, height)  # cv2.VideoWriter expects (width, height)
fourcc = cv2.VideoWriter.fourcc(*'mp4v')
os.makedirs('results', exist_ok=True)
video_writer = cv2.VideoWriter('results/gradio_temp.mp4', fourcc, fps, video_size) # Create VideoWriter object
for i in range(output_video.shape[1]):
frame = output_video[:, i, :, :].permute([1, 2, 0]).cpu().numpy()
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # model output is RGB; OpenCV writes BGR
video_writer.write(frame)
video_writer.release()
return state, gr.Video.update('results/gradio_temp.mp4')
generate_video_button.click(
fn=on_generate_video_click,
inputs=[global_state],
outputs=[global_state, generated_video_viewer]
)
def on_clear_video_click(state):
return state, gr.Video.update(None)
clear_video_button.click(
fn=on_clear_video_click,
inputs=[global_state],
outputs=[global_state, generated_video_viewer]
)
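# Entering drag mode GAN-inverts the current image with PTI (this can take a while);
# the inverted reconstruction replaces the generated image so DragGAN can edit it.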
def on_drag_mode_on_click(state):
# prepare DragGAN for custom image
custom_image = state['generated_image']
current_ckpt_name = state['pretrained_weight']
generate_pipeline.prepare_drag_model(
custom_image,
generator_params = state['generator_params'],
pretrained_weight = os.path.join('checkpoints/drag/', current_ckpt_name),
)
state['generated_image'] = state['generator_params'].image
view_image = update_state_image(state)
return state, gr.Image.update(view_image, interactive=True)
drag_mode_on_button.click(
fn=on_drag_mode_on_click,
inputs=[global_state],
outputs=[global_state, generated_image_viewer]
)
def on_drag_mode_off_click(state, image):
return on_captured_image_viewer_update(state, image)
drag_mode_off_button.click(
fn=on_drag_mode_off_click,
inputs=[global_state, captured_image_viewer],
outputs=[global_state, generated_image_viewer]
)
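# Generator callback: keeps applying DragGAN steps and yielding the intermediate image
# and step count to the UI until `stop` sets state['is_dragging'] to False.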
def on_drag_start_click(state):
state['is_dragging'] = True
points = state['drag_markers'][0]['points']
if state['drag_markers'][0]['mask'] is None:
mask = np.ones((state['generator_params'].image.height, state['generator_params'].image.width), dtype=np.uint8)
else:
mask = state['drag_markers'][0]['mask']
cur_step = 0
while True:
if not state['is_dragging']:
break
generated_image = generate_pipeline.drag_image(
points,
mask,
motion_lambda = state['params']['motion_lambda'],
generator_params = state['generator_params']
)
state['drag_markers'] = [{'points': points, 'mask': mask}]
state['generated_image'] = generated_image
cur_step += 1
view_image = update_state_image(state)
if cur_step % 50 == 0:
print('[{} / {}]'.format(cur_step, 'inf'))
yield (
state,
gr.Image.update(view_image, interactive=False), # generated image viewer
gr.Number.update(cur_step), # step
)
view_image = update_state_image(state)
return (
state,
gr.Image.update(view_image, interactive=True),
gr.Number.update(cur_step),
)
drag_start_button.click(
fn=on_drag_start_click,
inputs=[global_state],
outputs=[global_state, generated_image_viewer, steps_number]
)
def on_drag_stop_click(state):
state['is_dragging'] = False
return state
drag_stop_button.click(
fn=on_drag_stop_click,
inputs=[global_state],
outputs=[global_state]
)
# ========================= Main Function End =============================
# ====================== Update Text Prompts Start ====================
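# Submitting an empty prompt falls back to the default prompt. Image-prompt submissions
# also flag the stream as needing re-preparation on the next generation.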
def on_image_pos_text_prompt_editor_submit(state, text):
if len(text) == 0:
temp = state['image_text_prompts']
state['image_text_prompts'] = (state['default_image_text_prompts'][0], temp[1])
else:
temp = state['image_text_prompts']
state['image_text_prompts'] = (text, temp[1])
state['is_image_text_prompt_up-to-date'] = False
return state
image_pos_text_prompt_editor.submit(
fn=on_image_pos_text_prompt_editor_submit,
inputs=[global_state, image_pos_text_prompt_editor],
outputs=None
)
def on_image_neg_text_prompt_editor_submit(state, text):
if len(text) == 0:
temp = state['image_text_prompts']
state['image_text_prompts'] = (temp[0], state['default_image_text_prompts'][1])
else:
temp = state['image_text_prompts']
state['image_text_prompts'] = (temp[0], text)
state['is_image_text_prompt_up-to-date'] = False
return state
image_neg_text_prompt_editor.submit(
fn=on_image_neg_text_prompt_editor_submit,
inputs=[global_state, image_neg_text_prompt_editor],
outputs=None
)
def on_video_pos_text_prompt_editor_submit(state, text):
if len(text) == 0:
temp = state['video_text_prompts']
state['video_text_prompts'] = (state['default_video_text_prompts'][0], temp[1])
else:
temp = state['video_text_prompts']
state['video_text_prompts'] = (text, temp[1])
return state
video_pos_text_prompt_editor.submit(
fn=on_video_pos_text_prompt_editor_submit,
inputs=[global_state, video_pos_text_prompt_editor],
outputs=None
)
def on_video_neg_text_prompt_editor_submit(state, text):
if len(text) == 0:
temp = state['video_text_prompts']
state['video_text_prompts'] = (temp[0], state['default_video_text_prompts'][1])
else:
temp = state['video_text_prompts']
state['video_text_prompts'] = (temp[0], text)
return state
video_neg_text_prompt_editor.submit(
fn=on_video_neg_text_prompt_editor_submit,
inputs=[global_state, video_neg_text_prompt_editor],
outputs=None
)
def on_confirm_text_click(state, image, img_pos_t, img_neg_t, vid_pos_t, vid_neg_t):
state = on_image_pos_text_prompt_editor_submit(state, img_pos_t)
state = on_image_neg_text_prompt_editor_submit(state, img_neg_t)
state = on_video_pos_text_prompt_editor_submit(state, vid_pos_t)
state = on_video_neg_text_prompt_editor_submit(state, vid_neg_t)
return on_captured_image_viewer_update(state, image)
confirm_text_button.click(
fn=on_confirm_text_click,
inputs=[global_state, captured_image_viewer, image_pos_text_prompt_editor, image_neg_text_prompt_editor,
video_pos_text_prompt_editor, video_neg_text_prompt_editor],
outputs=[global_state, generated_image_viewer]
)
# ====================== Update Text Prompts End ====================
# ======================= Drag Point Edit Start =========================
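# Clicks on the generated image alternate between placing a start point and its drag
# target; point pairs are stored under incrementing integer keys.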
def on_image_clicked(state, evt: gr.SelectData):
"""
This function only supports click-based point selection.
"""
pos_x, pos_y = evt.index
drag_markers = state['drag_markers']
key_points = list(drag_markers[0]['points'].keys())
key_points.sort(reverse=False)
if len(key_points) == 0: # no point pairs, add a new point pair
drag_markers[0]['points'][0] = {
'start_temp': [pos_x, pos_y],
'start': [pos_x, pos_y],
'target': None,
}
else:
largest_id = key_points[-1]
if drag_markers[0]['points'][largest_id]['target'] is None: # target is not set
drag_markers[0]['points'][largest_id]['target'] = [pos_x, pos_y]
else: # target is set, add a new point pair
drag_markers[0]['points'][largest_id + 1] = {
'start_temp': [pos_x, pos_y],
'start': [pos_x, pos_y],
'target': None,
}
state['drag_markers'] = drag_markers
image = update_state_image(state)
return state, gr.Image.update(image, interactive=False)
generated_image_viewer.select(
fn=on_image_clicked,
inputs=[global_state],
outputs=[global_state, generated_image_viewer],
)
def on_add_point_click(state):
return gr.Image.update(state['generated_image_show'], interactive=False)
add_point_button.click(
fn=on_add_point_click,
inputs=[global_state],
outputs=[generated_image_viewer]
)
def on_reset_point_click(state):
drag_markers = state['drag_markers']
drag_markers[0]['points'] = {}
state['drag_markers'] = drag_markers
image = update_state_image(state)
return state, gr.Image.update(image)
reset_point_button.click(
fn=on_reset_point_click,
inputs=[global_state],
outputs=[global_state, generated_image_viewer]
)
# ======================= Drag Point Edit End =========================
# ======================= Drag Mask Edit Start =========================
def on_draw_mask_click(state):
return gr.Image.update(state['generated_image_show'], interactive=True)
draw_mask_button.click(
fn=on_draw_mask_click,
inputs=[global_state],
outputs=[generated_image_viewer]
)
def on_reset_mask_click(state):
drag_markers = state['drag_markers']
drag_markers[0]['mask'] = np.ones_like(drag_markers[0]['mask'])
state['drag_markers'] = drag_markers
image = update_state_image(state)
return state, gr.Image.update(image)
reset_mask_button.click(
fn=on_reset_mask_click,
inputs=[global_state],
outputs=[global_state, generated_image_viewer]
)
def on_show_mask_click(state, evt: gr.SelectData):
state['is_show_mask'] = evt.selected
image = update_state_image(state)
return state, image
show_mask_checkbox.select(
fn=on_show_mask_click,
inputs=[global_state],
outputs=[global_state, generated_image_viewer]
)
# ======================= Drag Mask Edit End =========================
# ======================= Drag Setting Start =========================
def on_motion_lambda_change(state, number):
state['params']['motion_lambda'] = number
return state
motion_lambda_number.input(
fn=on_motion_lambda_change,
inputs=[global_state, motion_lambda_number],
outputs=[global_state]
)
def on_drag_checkpoint_change(state, checkpoint):
state['pretrained_weight'] = checkpoint
print(type(checkpoint), checkpoint)
return state
drag_checkpoint_dropdown.change(
fn=on_drag_checkpoint_change,
inputs=[global_state, drag_checkpoint_dropdown],
outputs=[global_state]
)
# ======================= Drag Setting End =========================
# ======================= General Setting Start =========================
def on_video_preview_resolution_change(state, resolution):
state['video_preview_resolution'] = resolution
return state
video_preview_resolution_dropdown.change(
fn=on_video_preview_resolution_change,
inputs=[global_state, video_preview_resolution_dropdown],
outputs=[global_state]
)
def on_sample_image_change(state, image):
return state, gr.Image.update(image)
sample_image_dropdown.change(
fn=on_sample_image_change,
inputs=[global_state, sample_image_dropdown],
outputs=[global_state, captured_image_viewer]
)
# ======================= General Setting End =========================
demo.queue(concurrency_count=3, max_size=20)
# demo.launch(share=False, server_name="0.0.0.0" if args.listen else "127.0.0.1")
demo.launch()