import argparse
import time, os, sys

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
os.system('python scripts/download_models.py')

import gradio as gr
from PIL import Image
import numpy as np
import torch
from typing import List, Literal, Dict, Optional
from draw_utils import draw_points_on_image, draw_mask_on_image
import cv2

from models.streamdiffusion.wrapper import StreamDiffusionWrapper
from models.animatediff.pipelines import I2VPipeline
from omegaconf import OmegaConf
from models.draggan.viz.renderer import Renderer
from models.draggan.gan_inv.lpips.util import PerceptualLoss
import models.draggan.dnnlib as dnnlib
from models.draggan.gan_inv.inversion import PTI
import imageio
import torchvision
from einops import rearrange

# =========================== Model Implementation Start ===================================
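# The demo wires together three generative components (constructed in GeneratePipeline.init_model below):
#   * StreamDiffusionWrapper        - streaming image-to-image generation from the sketch canvas
#   * I2VPipeline (PIA/AnimateDiff) - image-to-video generation, configured by demo/configs/i2v_config.yaml
#   * DragGAN Renderer + PTI        - GAN inversion and point-based dragging of the generated image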
def save_videos_grid_255(videos: torch.Tensor, path: str, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        x = x.numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)


def reverse_point_pairs(points):
    new_points = []
    for p in points:
        new_points.append([p[1], p[0]])
    return new_points
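

# render_view_image overlays the drag points (and optionally the mask) on the image and
# appends a fully opaque alpha channel, presumably so the RGBA view composites cleanly
# in the gradio viewers.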
def render_view_image(img, drag_markers, show_mask=False):
    img = draw_points_on_image(img, drag_markers['points'])
    if show_mask:
        img = draw_mask_on_image(img, drag_markers['mask'])
    img = np.array(img).astype(np.uint8)
    img = np.concatenate([
        img,
        255 * np.ones((img.shape[0], img.shape[1], 1), dtype=img.dtype)
    ], axis=2)
    return Image.fromarray(img)


def update_state_image(state):
    state['generated_image_show'] = render_view_image(
        state['generated_image'],
        state['drag_markers'][0],
        state['is_show_mask'],
    )
    return state['generated_image_show']


class GeneratePipeline:
    def __init__(
        self,
        i2i_body_ckpt: str = "checkpoints/diffusion_body/kohaku-v2.1",
        # i2i_body_ckpt: str = "checkpoints/diffusion_body/stable-diffusion-v1-5",
        i2i_lora_dict: Optional[Dict[str, float]] = {'checkpoints/i2i/lora/lcm-lora-sdv1-5.safetensors': 1.0},
        prompt: str = "",
        negative_prompt: str = "low quality, bad quality, blurry, low resolution",
        frame_buffer_size: int = 1,
        width: int = 512,
        height: int = 512,
        acceleration: Literal["none", "xformers", "tensorrt"] = "xformers",
        use_denoising_batch: bool = True,
        seed: int = 2,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
        guidance_scale: float = 1.4,
        delta: float = 0.5,
        do_add_noise: bool = False,
        enable_similar_image_filter: bool = True,
        similar_image_filter_threshold: float = 0.99,
        similar_image_filter_max_skip_frame: float = 10,
    ):
        super(GeneratePipeline, self).__init__()
        if not torch.cuda.is_available():
            acceleration = None

        self.img2img_model = None
        self.img2video_model = None
        self.img2video_generator = None
        self.sim_ranges = None

        # set parameters
        self.i2i_body_ckpt = i2i_body_ckpt
        self.i2i_lora_dict = i2i_lora_dict
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.frame_buffer_size = frame_buffer_size
        self.width = width
        self.height = height
        self.acceleration = acceleration
        self.use_denoising_batch = use_denoising_batch
        self.seed = seed
        self.cfg_type = cfg_type
        self.guidance_scale = guidance_scale
        self.delta = delta
        self.do_add_noise = do_add_noise
        self.enable_similar_image_filter = enable_similar_image_filter
        self.similar_image_filter_threshold = similar_image_filter_threshold
        self.similar_image_filter_max_skip_frame = similar_image_filter_max_skip_frame

        self.i2v_config = OmegaConf.load('demo/configs/i2v_config.yaml')
        self.i2v_body_ckpt = self.i2v_config.pretrained_model_path
        self.i2v_unet_path = self.i2v_config.generate.model_path
        self.i2v_dreambooth_ckpt = self.i2v_config.generate.db_path
        self.lora_alpha = 0

        assert self.frame_buffer_size == 1
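
    # Heavy models are created lazily in init_model() rather than in __init__, so the
    # pipeline object itself is cheap to construct.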
    def init_model(self):
        # StreamDiffusion
        self.img2img_model = StreamDiffusionWrapper(
            model_id_or_path=self.i2i_body_ckpt,
            lora_dict=self.i2i_lora_dict,
            t_index_list=[32, 45],
            frame_buffer_size=self.frame_buffer_size,
            width=self.width,
            height=self.height,
            warmup=10,
            acceleration=self.acceleration,
            do_add_noise=self.do_add_noise,
            enable_similar_image_filter=self.enable_similar_image_filter,
            similar_image_filter_threshold=self.similar_image_filter_threshold,
            similar_image_filter_max_skip_frame=self.similar_image_filter_max_skip_frame,
            mode="img2img",
            use_denoising_batch=self.use_denoising_batch,
            cfg_type=self.cfg_type,
            seed=self.seed,
            use_lcm_lora=False,
        )
        self.img2img_model.prepare(
            prompt=self.prompt,
            negative_prompt=self.negative_prompt,
            num_inference_steps=50,
            guidance_scale=self.guidance_scale,
            delta=self.delta,
        )

        # PIA
        self.img2video_model = I2VPipeline.build_pipeline(
            self.i2v_config,
            self.i2v_body_ckpt,
            self.i2v_unet_path,
            self.i2v_dreambooth_ckpt,
            None,  # lora path
            self.lora_alpha,
        )
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'
        self.img2video_generator = torch.Generator(device=device)
        self.img2video_generator.manual_seed(self.i2v_config.generate.global_seed)
        self.sim_ranges = self.i2v_config.validation_data.mask_sim_range

        # DragGAN
        self.drag_model = Renderer(disable_timing=True)
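
    # generate_image pushes a single frame through the StreamDiffusion img2img stream.
    # When `text` is given, the stream is re-prepared with the new prompt pair first.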
    def generate_image(self, image, text, start_time=None):
        if text is not None:
            pos_prompt, neg_prompt = text
            self.img2img_model.prepare(
                prompt=pos_prompt,
                negative_prompt=neg_prompt,
                num_inference_steps=50,
                guidance_scale=self.guidance_scale,
                delta=self.delta,
            )
        sampled_inputs = [image]
        input_batch = torch.cat(sampled_inputs)
        output_images = self.img2img_model.stream(
            input_batch.to(device=self.img2img_model.device, dtype=self.img2img_model.dtype)
        )
        # if start_time is not None:
        #     print('Generate Done: {}'.format(time.perf_counter() - start_time))
        output_images = output_images.cpu()
        # if start_time is not None:
        #     print('Move Done: {}'.format(time.perf_counter() - start_time))
        return output_images
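
    # generate_video runs the PIA image-to-video pipeline on the generated image; the
    # mask similarity range is taken from the first entry of validation_data.mask_sim_range.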
    def generate_video(self, image, text, height=None, width=None):
        pos_prompt, neg_prompt = text
        sim_range = self.sim_ranges[0]
        print(f"using sim_range : {sim_range}")
        self.i2v_config.validation_data.mask_sim_range = sim_range
        sample = self.img2video_model(
            image=image,
            prompt=pos_prompt,
            generator=self.img2video_generator,
            video_length=self.i2v_config.generate.video_length,
            height=height if height is not None else self.i2v_config.generate.sample_height,
            width=width if width is not None else self.i2v_config.generate.sample_width,
            negative_prompt=neg_prompt,
            mask_sim_template_idx=self.i2v_config.validation_data.mask_sim_range,
            **self.i2v_config.validation_data,
        ).videos
        return sample
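
    # prepare_drag_model inverts the user's image into the StyleGAN2 latent space with PTI,
    # stores the pivot latent and inverted image in generator_params, and returns an
    # all-ones mask covering the whole image.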
    def prepare_drag_model(
        self,
        custom_image: Image.Image,
        latent_space='w+',
        trunc_psi=0.7,
        trunc_cutoff=None,
        seed=0,
        lr=0.001,
        generator_params=dnnlib.EasyDict(),
        pretrained_weight='stylegan2_lions_512_pytorch',
    ):
        self.drag_model.init_network(
            generator_params,      # res
            pretrained_weight,     # pkl
            seed,                  # w0_seed
            None,                  # w_load
            latent_space == 'w+',  # w_plus
            'const',
            trunc_psi,             # trunc_psi
            trunc_cutoff,          # trunc_cutoff
            None,                  # input_transform
            lr                     # lr
        )

        if torch.cuda.is_available():
            percept = PerceptualLoss(model="net-lin", net="vgg", use_gpu=True)
        else:
            percept = PerceptualLoss(model="net-lin", net="vgg", use_gpu=False)

        pti = PTI(self.drag_model.G, percept, max_pti_step=400)
        inversed_img, w_pivot = pti.train(custom_image, latent_space == 'w+')
        inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        inversed_img = inversed_img.cpu().numpy()
        inversed_img = Image.fromarray(inversed_img)

        mask = np.ones((inversed_img.height, inversed_img.width), dtype=np.uint8)

        generator_params.image = inversed_img
        generator_params.w = w_pivot.detach().cpu().numpy()
        self.drag_model.set_latent(w_pivot, trunc_psi, trunc_cutoff)

        del percept
        del pti
        print('inverse end')

        return generator_params, mask
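
    # drag_image runs a single DragGAN optimization step for the current handle/target
    # point pairs. Point coordinates arrive as (x, y) from the UI and are swapped to
    # (row, col) via reverse_point_pairs; the drawn mask is inverted (1 - mask) before
    # being passed to the renderer.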
    def drag_image(
        self,
        points,
        mask,
        motion_lambda=20,
        r1_in_pixels=3,
        r2_in_pixels=12,
        trunc_psi=0.7,
        draw_interval=1,
        generator_params=dnnlib.EasyDict(),
    ):
        p_in_pixels = []
        t_in_pixels = []
        valid_points = []
        # Transform the points into torch tensors
        for key_point, point in points.items():
            try:
                p_start = point.get("start_temp", point["start"])
                p_end = point["target"]
                if p_start is None or p_end is None:
                    continue
            except KeyError:
                continue
            p_in_pixels.append(p_start)
            t_in_pixels.append(p_end)
            valid_points.append(key_point)

        mask = torch.tensor(mask).float()
        drag_mask = 1 - mask

        # reverse points order
        p_to_opt = reverse_point_pairs(p_in_pixels)
        t_to_opt = reverse_point_pairs(t_in_pixels)

        step_idx = 0
        self.drag_model._render_drag_impl(
            generator_params,
            p_to_opt,        # point
            t_to_opt,        # target
            drag_mask,       # mask
            motion_lambda,   # lambda_mask
            reg=0,
            feature_idx=5,   # NOTE: do not support change for now
            r1=r1_in_pixels,  # r1
            r2=r2_in_pixels,  # r2
            # random_seed = 0,
            # noise_mode = 'const',
            trunc_psi=trunc_psi,
            # force_fp32 = False,
            # layer_name = None,
            # sel_channels = 3,
            # base_channel = 0,
            # img_scale_db = 0,
            # img_normalize = False,
            # untransform = False,
            is_drag=True,
            to_pil=True
        )

        points_upd = points
        if step_idx % draw_interval == 0:
            for key_point, p_i, t_i in zip(valid_points, p_to_opt, t_to_opt):
                points_upd[key_point]["start_temp"] = [
                    p_i[1],
                    p_i[0],
                ]
                points_upd[key_point]["target"] = [
                    t_i[1],
                    t_i[0],
                ]
                start_temp = points_upd[key_point]["start_temp"]

        image_result = generator_params['image']
        return image_result

# ============================= Model Implementation End ===================================

parser = argparse.ArgumentParser()
parser.add_argument('--share', action='store_true', default=True)
parser.add_argument('--cache-dir', type=str, default='./checkpoints')
parser.add_argument(
    "--listen",
    action="store_true",
    help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests",
)
args = parser.parse_args()
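

# CustomImageMask extends gr.Image so that, when the sketch tool receives a plain base64
# image instead of an {'image': ..., 'mask': ...} dict, a blank full-alpha mask is
# synthesized before the standard gr.Image preprocessing runs.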
class CustomImageMask(gr.Image):
    is_template = True

    def __init__(
        self,
        source='upload',
        tool='sketch',
        elem_id="image_upload",
        label='Generated Image',
        type="pil",
        mask_opacity=0.5,
        brush_color='#FFFFFF',
        height=400,
        interactive=True,
        **kwargs
    ):
        super(CustomImageMask, self).__init__(
            source=source,
            tool=tool,
            elem_id=elem_id,
            label=label,
            type=type,
            mask_opacity=mask_opacity,
            brush_color=brush_color,
            height=height,
            interactive=interactive,
            **kwargs
        )

    def preprocess(self, x):
        if x is None:
            return x
        if self.tool == 'sketch' and self.source in ['upload', 'webcam'] and type(x) != dict:
            decode_image = gr.processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            mask = np.ones((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}
        return super().preprocess(x)

draggan_ckpts = os.listdir('checkpoints/drag')
draggan_ckpts.sort()

generate_pipeline = GeneratePipeline()
generate_pipeline.init_model()
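

# All UI state for a session lives in a single gr.State dict (global_state); the event
# handlers registered below mutate it and return it so the viewers stay in sync.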
with gr.Blocks() as demo:
    global_state = gr.State(
        {
            'is_image_generation': True,
            'is_image_text_prompt_up-to-date': True,
            'is_show_mask': False,
            'is_dragging': False,
            'generated_image': None,
            'generated_image_show': None,
            'drag_markers': [
                {
                    'points': {},
                    'mask': None
                }
            ],
            'generator_params': dnnlib.EasyDict(),
            'default_image_text_prompts': ('', 'low quality, bad quality, blurry, low resolution'),
            'default_video_text_prompts': ('', 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'),
            'image_text_prompts': ('', 'low quality, bad quality, blurry, low resolution'),
            'video_text_prompts': ('', 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'),
            'params': {
                'seed': 0,
                'motion_lambda': 20,
                'r1_in_pixels': 3,
                'r2_in_pixels': 12,
                'magnitude_direction_in_pixels': 1.0,
                'latent_space': 'w+',
                'trunc_psi': 0.7,
                'trunc_cutoff': None,
                'lr': 0.001,
            },
            'device': None,  # device,
            'draw_interval': 1,
            'points': {},
            'curr_point': None,
            'curr_type_point': 'start',
            'editing_state': 'add_points',
            'pretrained_weight': draggan_ckpts[0],
            'video_preview_resolution': '512 x 512',
            'viewer_height': 300,
            'viewer_width': 300
        }
    )

    with gr.Column():
        with gr.Row():
            with gr.Column(scale=8, min_width=10):
                with gr.Tab('Image Text Prompts'):
                    image_pos_text_prompt_editor = gr.Textbox(placeholder='Positive Prompts', label='Positive', min_width=10)
                    image_neg_text_prompt_editor = gr.Textbox(placeholder='Negative Prompts', label='Negative', min_width=10)
                with gr.Tab('Video Text Prompts'):
                    video_pos_text_prompt_editor = gr.Textbox(placeholder='Positive Prompts', label='Positive', min_width=10)
                    video_neg_text_prompt_editor = gr.Textbox(placeholder='Negative Prompts', label='Negative', min_width=10)
                with gr.Tab('Drag Image'):
                    with gr.Row():
                        with gr.Column(scale=1, min_width=10):
                            drag_mode_on_button = gr.Button('Drag Mode On', size='sm', min_width=10)
                            drag_mode_off_button = gr.Button('Drag Mode Off', size='sm', min_width=10)
                            drag_checkpoint_dropdown = gr.Dropdown(choices=draggan_ckpts, value=draggan_ckpts[0], label='checkpoint', min_width=10)
                        with gr.Column(scale=1, min_width=10):
                            with gr.Row():
                                drag_start_button = gr.Button('start', size='sm', min_width=10)
                                drag_stop_button = gr.Button('stop', size='sm', min_width=10)
                            with gr.Row():
                                add_point_button = gr.Button('add point', size='sm', min_width=10)
                                reset_point_button = gr.Button('reset point', size='sm', min_width=10)
                            with gr.Row():
                                steps_number = gr.Number(0, label='steps', interactive=False)
                        with gr.Column(scale=1, min_width=10):
                            with gr.Row():
                                draw_mask_button = gr.Button('draw mask', size='sm', min_width=10)
                                reset_mask_button = gr.Button('reset mask', size='sm', min_width=10)
                            with gr.Row():
                                show_mask_checkbox = gr.Checkbox(value=False, label='show mask', min_width=10, interactive=True)
                            with gr.Row():
                                motion_lambda_number = gr.Number(20, label='Motion Lambda', minimum=1, maximum=100, step=1, interactive=True)
                with gr.Tab('More'):
                    with gr.Row():
                        with gr.Column(scale=2, min_width=10):
                            video_preview_resolution_dropdown = gr.Dropdown(choices=['256 x 256', '512 x 512'], value='512 x 512', label='Video Preview Resolution', min_width=10)
                            sample_image_dropdown = gr.Dropdown(choices=['samples/canvas.jpg'] + ['samples/sample{:>02d}.jpg'.format(i) for i in range(1, 8)], value=None, label='Choose A Sample Image', min_width=10)
                        with gr.Column(scale=1, min_width=10):
                            confirm_text_button = gr.Button('Confirm Text', size='sm', min_width=10)
                            generate_video_button = gr.Button('Generate Video', size='sm', min_width=10)
                            clear_video_button = gr.Button('Clear Video', size='sm', min_width=10)
        with gr.Row():
            captured_image_viewer = gr.Image(source='upload', tool='color-sketch', type='pil', label='Image Drawer', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=True, shape=(global_state.value['viewer_width'], global_state.value['viewer_height']))
            generated_image_viewer = CustomImageMask(source='upload', tool='sketch', elem_id="image_upload", label='Generated Image', type="pil", mask_opacity=0.5, brush_color='#FFFFFF', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=True)
            generated_video_viewer = gr.Video(source='upload', label='Generated Video', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=False)
        gr.Markdown(
            """
            ## Quick Start

            1. Select one sample image in the `More` tab.
            2. Draw on the sample image in the leftmost image viewer to edit it.
            3. Click `Generate Video` and enjoy it!

            ## Note

            Due to limitations of the gradio implementation, the image-to-image result may appear with noticeable latency after model generation is done.
            For a better experience, we recommend our local demo at [github](https://github.com/invictus717/InteractiveVideo).

            ## Advanced Usage

            1. **Try different text prompts.** Enter positive or negative prompts for image / video generation, and
               click `Confirm Text` to enable your prompts.
            2. **Drag images.** Go to the `Drag Image` tab, choose a suitable checkpoint and click `Drag Mode On`.
               It might take a minute to prepare. Add points and, if needed, a mask, then click `start` to
               start dragging. Once you are satisfied, click the `stop` button.
            3. **Adjust video resolution** in the `More` tab.
            4. **Draw from scratch** by choosing `canvas.jpg` in the `More` tab and enjoy yourself!
            """
        )

    # ========================= Main Function Start =============================
    def on_captured_image_viewer_update(state, image):
        if image is None:
            return state, gr.Image.update(None)

        if state['is_image_text_prompt_up-to-date']:
            text_prompts = None
        else:
            text_prompts = state['image_text_prompts']
            state['is_image_text_prompt_up-to-date'] = True

        # start_time = time.perf_counter()
        input_image = np.array(image).astype(np.float32)
        input_image = (input_image / 255 - 0.5) * 2
        input_image = torch.tensor(input_image).permute([2, 0, 1])
        noisy_image = torch.randn_like(input_image)
        # print('preprocess done: {}'.format(time.perf_counter() - start_time))
        output_image = generate_pipeline.generate_image(
            input_image,
            text_prompts,
            # start_time,
        )[0]
        # NOTE: the second pass with pure noise appears to be needed to flush
        # StreamDiffusion's pipelined denoising buffer, so that the returned frame
        # corresponds to the sketch fed in above.
        output_image = generate_pipeline.generate_image(
            noisy_image,
            None,
            # start_time,
        )[0]  # TODO: is there a more elegant way?
        output_image = output_image.permute([1, 2, 0])
        output_image = (output_image / 2 + 0.5).clamp(0, 1) * 255
        output_image = output_image.to(torch.uint8).cpu().numpy()
        output_image = Image.fromarray(output_image)
        # print('postprocess done: {}'.format(time.perf_counter() - start_time))
        # output_image = image
        state['generated_image'] = output_image
        output_image = update_state_image(state)
        # print('draw done: {}'.format(time.perf_counter() - start_time))
        return state, gr.Image.update(output_image, interactive=False)

    captured_image_viewer.change(
        fn=on_captured_image_viewer_update,
        inputs=[global_state, captured_image_viewer],
        outputs=[global_state, generated_image_viewer]
    )

    def on_generated_image_viewer_edit(state, data_dict):
        mask = data_dict['mask']
        state['drag_markers'][0]['mask'] = np.array(mask)[:, :, 0] // 255
        image = update_state_image(state)
        return state, image

    generated_image_viewer.edit(
        fn=on_generated_image_viewer_edit,
        inputs=[global_state, generated_image_viewer],
        outputs=[global_state, generated_image_viewer]
    )
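
    # Video generation: generate_video returns a (C, T, H, W) tensor in [0, 1]; frames are
    # scaled to uint8 and written to results/gradio_temp.mp4 with OpenCV before being shown
    # in the video viewer.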
    def on_generate_video_click(state):
        input_image = np.array(state['generated_image'])
        text_prompts = state['video_text_prompts']
        video_preview_resolution = state['video_preview_resolution'].split('x')
        height = int(video_preview_resolution[0].strip(' '))
        width = int(video_preview_resolution[1].strip(' '))
        output_video = generate_pipeline.generate_video(
            input_image,
            text_prompts,
            height=height,
            width=width
        )[0]
        output_video = output_video.clamp(0, 1) * 255
        output_video = output_video.to(torch.uint8)
        # 3 T H W
        print('[video generation done]')

        fps = 5  # frames per second
        video_size = (width, height)  # cv2.VideoWriter expects (width, height)
        fourcc = cv2.VideoWriter.fourcc(*'mp4v')
        if not os.access('results', os.F_OK):
            os.makedirs('results')
        video_writer = cv2.VideoWriter('results/gradio_temp.mp4', fourcc, fps, video_size)  # Create VideoWriter object
        for i in range(output_video.shape[1]):
            frame = output_video[:, i, :, :].permute([1, 2, 0]).cpu().numpy()
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            video_writer.write(frame)
        video_writer.release()
        return state, gr.Video.update('results/gradio_temp.mp4')

    generate_video_button.click(
        fn=on_generate_video_click,
        inputs=[global_state],
        outputs=[global_state, generated_video_viewer]
    )

    def on_clear_video_click(state):
        return state, gr.Video.update(None)

    clear_video_button.click(
        fn=on_clear_video_click,
        inputs=[global_state],
        outputs=[global_state, generated_video_viewer]
    )

    def on_drag_mode_on_click(state):
        # prepare DragGAN for the custom image
        custom_image = state['generated_image']
        current_ckpt_name = state['pretrained_weight']
        generate_pipeline.prepare_drag_model(
            custom_image,
            generator_params=state['generator_params'],
            pretrained_weight=os.path.join('checkpoints/drag/', current_ckpt_name),
        )
        state['generated_image'] = state['generator_params'].image
        view_image = update_state_image(state)
        return state, gr.Image.update(view_image, interactive=True)

    drag_mode_on_button.click(
        fn=on_drag_mode_on_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer]
    )

    def on_drag_mode_off_click(state, image):
        return on_captured_image_viewer_update(state, image)

    drag_mode_off_button.click(
        fn=on_drag_mode_off_click,
        inputs=[global_state, captured_image_viewer],
        outputs=[global_state, generated_image_viewer]
    )
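
    # Dragging runs as a gradio generator: each loop iteration performs one drag step and
    # yields the updated view and step counter, until the stop button flips 'is_dragging'.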
    def on_drag_start_click(state):
        state['is_dragging'] = True
        points = state['drag_markers'][0]['points']
        if state['drag_markers'][0]['mask'] is None:
            mask = np.ones((state['generator_params'].image.height, state['generator_params'].image.width), dtype=np.uint8)
        else:
            mask = state['drag_markers'][0]['mask']
        cur_step = 0
        while True:
            if not state['is_dragging']:
                break
            generated_image = generate_pipeline.drag_image(
                points,
                mask,
                motion_lambda=state['params']['motion_lambda'],
                generator_params=state['generator_params']
            )
            state['drag_markers'] = [{'points': points, 'mask': mask}]
            state['generated_image'] = generated_image
            cur_step += 1
            view_image = update_state_image(state)
            if cur_step % 50 == 0:
                print('[{} / {}]'.format(cur_step, 'inf'))
            yield (
                state,
                gr.Image.update(view_image, interactive=False),  # generated image viewer
                gr.Number.update(cur_step),  # step
            )
        view_image = update_state_image(state)
        return (
            state,
            gr.Image.update(view_image, interactive=True),
            gr.Number.update(cur_step),
        )

    drag_start_button.click(
        fn=on_drag_start_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer, steps_number]
    )

    def on_drag_stop_click(state):
        state['is_dragging'] = False
        return state

    drag_stop_button.click(
        fn=on_drag_stop_click,
        inputs=[global_state],
        outputs=[global_state]
    )
    # ========================= Main Function End =============================

    # ====================== Update Text Prompts Start ====================
    def on_image_pos_text_prompt_editor_submit(state, text):
        if len(text) == 0:
            temp = state['image_text_prompts']
            state['image_text_prompts'] = (state['default_image_text_prompts'][0], temp[1])
        else:
            temp = state['image_text_prompts']
            state['image_text_prompts'] = (text, temp[1])
        state['is_image_text_prompt_up-to-date'] = False
        return state

    image_pos_text_prompt_editor.submit(
        fn=on_image_pos_text_prompt_editor_submit,
        inputs=[global_state, image_pos_text_prompt_editor],
        outputs=None
    )

    def on_image_neg_text_prompt_editor_submit(state, text):
        if len(text) == 0:
            temp = state['image_text_prompts']
            state['image_text_prompts'] = (temp[0], state['default_image_text_prompts'][1])
        else:
            temp = state['image_text_prompts']
            state['image_text_prompts'] = (temp[0], text)
        state['is_image_text_prompt_up-to-date'] = False
        return state

    image_neg_text_prompt_editor.submit(
        fn=on_image_neg_text_prompt_editor_submit,
        inputs=[global_state, image_neg_text_prompt_editor],
        outputs=None
    )

    def on_video_pos_text_prompt_editor_submit(state, text):
        if len(text) == 0:
            temp = state['video_text_prompts']
            state['video_text_prompts'] = (state['default_video_text_prompts'][0], temp[1])
        else:
            temp = state['video_text_prompts']
            state['video_text_prompts'] = (text, temp[1])
        return state

    video_pos_text_prompt_editor.submit(
        fn=on_video_pos_text_prompt_editor_submit,
        inputs=[global_state, video_pos_text_prompt_editor],
        outputs=None
    )

    def on_video_neg_text_prompt_editor_submit(state, text):
        if len(text) == 0:
            temp = state['video_text_prompts']
            state['video_text_prompts'] = (temp[0], state['default_video_text_prompts'][1])
        else:
            temp = state['video_text_prompts']
            state['video_text_prompts'] = (temp[0], text)
        return state

    video_neg_text_prompt_editor.submit(
        fn=on_video_neg_text_prompt_editor_submit,
        inputs=[global_state, video_neg_text_prompt_editor],
        outputs=None
    )

    def on_confirm_text_click(state, image, img_pos_t, img_neg_t, vid_pos_t, vid_neg_t):
        state = on_image_pos_text_prompt_editor_submit(state, img_pos_t)
        state = on_image_neg_text_prompt_editor_submit(state, img_neg_t)
        state = on_video_pos_text_prompt_editor_submit(state, vid_pos_t)
        state = on_video_neg_text_prompt_editor_submit(state, vid_neg_t)
        return on_captured_image_viewer_update(state, image)

    confirm_text_button.click(
        fn=on_confirm_text_click,
        inputs=[global_state, captured_image_viewer, image_pos_text_prompt_editor, image_neg_text_prompt_editor,
                video_pos_text_prompt_editor, video_neg_text_prompt_editor],
        outputs=[global_state, generated_image_viewer]
    )
    # ====================== Update Text Prompts End ====================

    # ======================= Drag Point Edit Start =========================
    def on_image_clicked(state, evt: gr.SelectData):
        """
        This function only supports clicking for point selection.
        """
        pos_x, pos_y = evt.index
        drag_markers = state['drag_markers']
        key_points = list(drag_markers[0]['points'].keys())
        key_points.sort(reverse=False)
        if len(key_points) == 0:  # no point pairs, add a new point pair
            drag_markers[0]['points'][0] = {
                'start_temp': [pos_x, pos_y],
                'start': [pos_x, pos_y],
                'target': None,
            }
        else:
            largest_id = key_points[-1]
            if drag_markers[0]['points'][largest_id]['target'] is None:  # target is not set
                drag_markers[0]['points'][largest_id]['target'] = [pos_x, pos_y]
            else:  # target is set, add a new point pair
                drag_markers[0]['points'][largest_id + 1] = {
                    'start_temp': [pos_x, pos_y],
                    'start': [pos_x, pos_y],
                    'target': None,
                }
        state['drag_markers'] = drag_markers
        image = update_state_image(state)
        return state, gr.Image.update(image, interactive=False)

    generated_image_viewer.select(
        fn=on_image_clicked,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer],
    )

    def on_add_point_click(state):
        return gr.Image.update(state['generated_image_show'], interactive=False)

    add_point_button.click(
        fn=on_add_point_click,
        inputs=[global_state],
        outputs=[generated_image_viewer]
    )

    def on_reset_point_click(state):
        drag_markers = state['drag_markers']
        drag_markers[0]['points'] = {}
        state['drag_markers'] = drag_markers
        image = update_state_image(state)
        return state, gr.Image.update(image)

    reset_point_button.click(
        fn=on_reset_point_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer]
    )
    # ======================= Drag Point Edit End =========================

    # ======================= Drag Mask Edit Start =========================
    def on_draw_mask_click(state):
        return gr.Image.update(state['generated_image_show'], interactive=True)

    draw_mask_button.click(
        fn=on_draw_mask_click,
        inputs=[global_state],
        outputs=[generated_image_viewer]
    )

    def on_reset_mask_click(state):
        drag_markers = state['drag_markers']
        drag_markers[0]['mask'] = np.ones_like(drag_markers[0]['mask'])
        state['drag_markers'] = drag_markers
        image = update_state_image(state)
        return state, gr.Image.update(image)

    reset_mask_button.click(
        fn=on_reset_mask_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer]
    )

    def on_show_mask_click(state, evt: gr.SelectData):
        state['is_show_mask'] = evt.selected
        image = update_state_image(state)
        return state, image

    show_mask_checkbox.select(
        fn=on_show_mask_click,
        inputs=[global_state],
        outputs=[global_state, generated_image_viewer]
    )
    # ======================= Drag Mask Edit End =========================

    # ======================= Drag Setting Start =========================
    def on_motion_lambda_change(state, number):
        state['params']['motion_lambda'] = number
        return state

    motion_lambda_number.input(
        fn=on_motion_lambda_change,
        inputs=[global_state, motion_lambda_number],
        outputs=[global_state]
    )

    def on_drag_checkpoint_change(state, checkpoint):
        state['pretrained_weight'] = checkpoint
        print(type(checkpoint), checkpoint)
        return state

    drag_checkpoint_dropdown.change(
        fn=on_drag_checkpoint_change,
        inputs=[global_state, drag_checkpoint_dropdown],
        outputs=[global_state]
    )
    # ======================= Drag Setting End =========================

    # ======================= General Setting Start =========================
    def on_video_preview_resolution_change(state, resolution):
        state['video_preview_resolution'] = resolution
        return state

    video_preview_resolution_dropdown.change(
        fn=on_video_preview_resolution_change,
        inputs=[global_state, video_preview_resolution_dropdown],
        outputs=[global_state]
    )

    def on_sample_image_change(state, image):
        return state, gr.Image.update(image)

    sample_image_dropdown.change(
        fn=on_sample_image_change,
        inputs=[global_state, sample_image_dropdown],
        outputs=[global_state, captured_image_viewer]
    )
    # ======================= General Setting End =========================

demo.queue(concurrency_count=3, max_size=20)
# demo.launch(share=False, server_name="0.0.0.0" if args.listen else "127.0.0.1")
demo.launch()