anchorxia committed
Commit 0133d2e · 1 Parent(s): 7f7d86b

add mmcm, musev

Files changed (6)
  1. .gitmodules +6 -0
  2. MMCM +1 -0
  3. MuseV +1 -0
  4. app_gradio_space.py +32 -2
  5. gradio_text2video.py +0 -949
  6. gradio_video2video.py +0 -1039
.gitmodules ADDED
@@ -0,0 +1,6 @@
+ [submodule "MMCM"]
+ path = MMCM
+ url = https://github.com/TMElyralab/MMCM.git
+ [submodule "MuseV"]
+ path = MuseV
+ url = https://github.com/TMElyralab/MuseV.git
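Note: MMCM and MuseV are pulled in as pinned git submodules, so a fresh checkout of this Space needs them initialized before the sys.path entries added below in app_gradio_space.py can resolve. A minimal bootstrap sketch, assuming git is on PATH and reusing the same subprocess pattern this commit adds:

import subprocess

# Fetch the pinned MMCM/MuseV submodule commits recorded in .gitmodules.
# Equivalent to cloning with `git clone --recurse-submodules`.
result = subprocess.run(
    ["git", "submodule", "update", "--init", "--recursive"],
    capture_output=True,
    text=True,
)
print(result.stdout)
if result.returncode != 0:
    raise RuntimeError(f"submodule init failed: {result.stderr}")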
MMCM ADDED
@@ -0,0 +1 @@
+ Subproject commit 3a6ef08762f8bd9b50c09eb20402a13b74e839e3
MuseV ADDED
@@ -0,0 +1 @@
+ Subproject commit ddba6a41725db2f75f7b6dd91185cacb0fed556f
app_gradio_space.py CHANGED
@@ -14,10 +14,40 @@ import subprocess
  
  ProjectDir = os.path.abspath(os.path.dirname(__file__))
  CheckpointsDir = os.path.join(ProjectDir, "checkpoints")
- 
- ignore_video2video = True
+ ignore_video2video = False
  max_image_edge = 960
  
+ sys.path.insert(0, ProjectDir)
+ sys.path.insert(0, f"{ProjectDir}/MMCM")
+ sys.path.insert(0, f"{ProjectDir}/diffusers/src")
+ sys.path.insert(0, f"{ProjectDir}/controlnet_aux/src")
+ sys.path.insert(0, f"{ProjectDir}/scripts/gradio")
+ 
+ result = subprocess.run(
+     ["pip", "install", "--no-cache-dir", "-U", "openmim"],
+     capture_output=True,
+     text=True,
+ )
+ print(result)
+ 
+ result = subprocess.run(["mim", "install", "mmengine"], capture_output=True, text=True)
+ print(result)
+ 
+ result = subprocess.run(
+     ["mim", "install", "mmcv>=2.0.1"], capture_output=True, text=True
+ )
+ print(result)
+ 
+ result = subprocess.run(
+     ["mim", "install", "mmdet>=3.1.0"], capture_output=True, text=True
+ )
+ print(result)
+ 
+ result = subprocess.run(
+     ["mim", "install", "mmpose>=1.1.0"], capture_output=True, text=True
+ )
+ print(result)
+ 
  
  def download_model():
      if not os.path.exists(CheckpointsDir):
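Note: the startup block added above runs five near-identical install commands and only prints each CompletedProcess, so a failed install would scroll past silently. A compact alternative sketch (the run_install helper is hypothetical, not part of this repo) loops over the same commands and fails fast on a non-zero exit:

import subprocess

def run_install(cmd):
    # Run one install command, echo its output, and raise on failure
    # instead of letting a bad exit code pass unnoticed.
    result = subprocess.run(cmd, capture_output=True, text=True)
    print(result.stdout)
    if result.returncode != 0:
        raise RuntimeError(f"{' '.join(cmd)} failed: {result.stderr}")

for cmd in (
    ["pip", "install", "--no-cache-dir", "-U", "openmim"],
    ["mim", "install", "mmengine"],
    ["mim", "install", "mmcv>=2.0.1"],
    ["mim", "install", "mmdet>=3.1.0"],
    ["mim", "install", "mmpose>=1.1.0"],
):
    run_install(cmd)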
gradio_text2video.py DELETED
@@ -1,949 +0,0 @@
- import argparse
- import copy
- import os
- from pathlib import Path
- import logging
- from collections import OrderedDict
- from pprint import pprint
- import random
- import gradio as gr
- from argparse import Namespace
- 
- import numpy as np
- from omegaconf import OmegaConf, SCMode
- import torch
- from einops import rearrange, repeat
- import cv2
- from PIL import Image
- from diffusers.models.autoencoder_kl import AutoencoderKL
- 
- from mmcm.utils.load_util import load_pyhon_obj
- from mmcm.utils.seed_util import set_all_seed
- from mmcm.utils.signature import get_signature_of_string
- from mmcm.utils.task_util import fiss_tasks, generate_tasks as generate_tasks_from_table
- from mmcm.vision.utils.data_type_util import is_video, is_image, read_image_as_5d
- from mmcm.utils.str_util import clean_str_for_save
- from mmcm.vision.data.video_dataset import DecordVideoDataset
- from musev.auto_prompt.util import generate_prompts
- 
- 
- from musev.models.facein_loader import load_facein_extractor_and_proj_by_name
- from musev.models.referencenet_loader import load_referencenet_by_name
- from musev.models.ip_adapter_loader import (
-     load_ip_adapter_vision_clip_encoder_by_name,
-     load_vision_clip_encoder_by_name,
-     load_ip_adapter_image_proj_by_name,
- )
- from musev.models.ip_adapter_face_loader import (
-     load_ip_adapter_face_extractor_and_proj_by_name,
- )
- from musev.pipelines.pipeline_controlnet_predictor import (
-     DiffusersPipelinePredictor,
- )
- from musev.models.referencenet import ReferenceNet2D
- from musev.models.unet_loader import load_unet_by_name
- from musev.utils.util import save_videos_grid_with_opencv
- from musev import logger
- 
- use_v2v_predictor = False
- if use_v2v_predictor:
-     from gradio_video2video import sd_predictor as video_sd_predictor
- 
- logger.setLevel("INFO")
- 
- file_dir = os.path.dirname(__file__)
- PROJECT_DIR = os.path.join(os.path.dirname(__file__), "./")
- DATA_DIR = os.path.join(PROJECT_DIR, "data")
- CACHE_PATH = "./t2v_input_image"
- 
- 
- # TODO: use groups to organize the arguments
- 
- 
- args_dict = {
-     "add_static_video_prompt": False,
-     "context_batch_size": 1,
-     "context_frames": 12,
-     "context_overlap": 4,
-     "context_schedule": "uniform_v2",
-     "context_stride": 1,
-     "cross_attention_dim": 768,
-     "face_image_path": None,
-     "facein_model_cfg_path": os.path.join(PROJECT_DIR, "./configs/model/facein.py"),
-     "facein_model_name": None,
-     "facein_scale": 1.0,
-     "fix_condition_images": False,
-     "fixed_ip_adapter_image": True,
-     "fixed_refer_face_image": True,
-     "fixed_refer_image": True,
-     "fps": 4,
-     "guidance_scale": 7.5,
-     "height": None,
-     "img_length_ratio": 1.0,
-     "img_weight": 0.001,
-     "interpolation_factor": 1,
-     "ip_adapter_face_model_cfg_path": os.path.join(
-         PROJECT_DIR, "./configs/model/ip_adapter.py"
-     ),
-     "ip_adapter_face_model_name": None,
-     "ip_adapter_face_scale": 1.0,
-     "ip_adapter_model_cfg_path": os.path.join(
-         PROJECT_DIR, "./configs/model/ip_adapter.py"
-     ),
-     "ip_adapter_model_name": "musev_referencenet",
-     "ip_adapter_scale": 1.0,
-     "ipadapter_image_path": None,
-     "lcm_model_cfg_path": os.path.join(PROJECT_DIR, "./configs/model/lcm_model.py"),
-     "lcm_model_name": None,
-     "log_level": "INFO",
-     "motion_speed": 8.0,
-     "n_batch": 1,
-     "n_cols": 3,
-     "n_repeat": 1,
-     "n_vision_condition": 1,
-     "need_hist_match": False,
-     "need_img_based_video_noise": True,
-     "need_redraw": False,
-     "negative_prompt": "V2",
-     "negprompt_cfg_path": os.path.join(
-         PROJECT_DIR, "./configs/model/negative_prompt.py"
-     ),
-     "noise_type": "video_fusion",
-     "num_inference_steps": 30,
-     "output_dir": "./results/",
-     "overwrite": False,
-     "prompt_only_use_image_prompt": False,
-     "record_mid_video_latents": False,
-     "record_mid_video_noises": False,
-     "redraw_condition_image": False,
-     "redraw_condition_image_with_facein": True,
-     "redraw_condition_image_with_ip_adapter_face": True,
-     "redraw_condition_image_with_ipdapter": True,
-     "redraw_condition_image_with_referencenet": True,
-     "referencenet_image_path": None,
-     "referencenet_model_cfg_path": os.path.join(
-         PROJECT_DIR, "./configs/model/referencenet.py"
-     ),
-     "referencenet_model_name": "musev_referencenet",
-     "save_filetype": "mp4",
-     "save_images": False,
-     "sd_model_cfg_path": os.path.join(PROJECT_DIR, "./configs/model/T2I_all_model.py"),
-     "sd_model_name": "majicmixRealv6Fp16",
-     "seed": None,
-     "strength": 0.8,
-     "target_datas": "boy_dance2",
-     "test_data_path": os.path.join(
-         PROJECT_DIR, "./configs/infer/testcase_video_famous.yaml"
-     ),
-     "time_size": 24,
-     "unet_model_cfg_path": os.path.join(PROJECT_DIR, "./configs/model/motion_model.py"),
-     "unet_model_name": "musev_referencenet",
-     "use_condition_image": True,
-     "use_video_redraw": True,
-     "vae_model_path": os.path.join(PROJECT_DIR, "./checkpoints/vae/sd-vae-ft-mse"),
-     "video_guidance_scale": 3.5,
-     "video_guidance_scale_end": None,
-     "video_guidance_scale_method": "linear",
-     "video_negative_prompt": "V2",
-     "video_num_inference_steps": 10,
-     "video_overlap": 1,
-     "vision_clip_extractor_class_name": "ImageClipVisionFeatureExtractor",
-     "vision_clip_model_path": os.path.join(
-         PROJECT_DIR, "./checkpoints/IP-Adapter/models/image_encoder"
-     ),
-     "w_ind_noise": 0.5,
-     "width": None,
-     "write_info": False,
- }
- args = Namespace(**args_dict)
- print("args")
- pprint(args)
- print("\n")
- 
- logger.setLevel(args.log_level)
- overwrite = args.overwrite
- cross_attention_dim = args.cross_attention_dim
- time_size = args.time_size  # number of frames generated per run
- n_batch = args.n_batch  # generate n_batch times at time_size frames each; total frames = time_size * n_batch
- fps = args.fps
- # need_redraw = args.need_redraw  # video redraw uses the video network
- # use_video_redraw = args.use_video_redraw  # video redraw uses the video network
- fix_condition_images = args.fix_condition_images
- use_condition_image = args.use_condition_image  # when test_data contains an image, use it as the initial image
- redraw_condition_image = args.redraw_condition_image  # whether the first frame for video generation uses the redrawn image
- need_img_based_video_noise = (
-     args.need_img_based_video_noise
- )  # whether to use the first-frame condition_images when noising the video
- img_weight = args.img_weight
- height = args.height  # default when the test data does not specify width/height
- width = args.width  # default when the test data does not specify width/height
- img_length_ratio = args.img_length_ratio  # default resize ratio when the test data does not specify one
- n_cols = args.n_cols
- noise_type = args.noise_type
- strength = args.strength  # redraw strength for the first frame
- video_guidance_scale = args.video_guidance_scale  # weight between video condition and uncond
- guidance_scale = args.guidance_scale  # weight between condition and uncond for temporal condition frames
- video_num_inference_steps = args.video_num_inference_steps  # number of video denoising steps
- num_inference_steps = args.num_inference_steps  # redraw steps for temporal condition frames
- seed = args.seed
- save_filetype = args.save_filetype
- save_images = args.save_images
- sd_model_cfg_path = args.sd_model_cfg_path
- sd_model_name = (
-     args.sd_model_name
-     if args.sd_model_name in ["all", "None"]
-     else args.sd_model_name.split(",")
- )
- unet_model_cfg_path = args.unet_model_cfg_path
- unet_model_name = args.unet_model_name
- test_data_path = args.test_data_path
- target_datas = (
-     args.target_datas if args.target_datas == "all" else args.target_datas.split(",")
- )
- device = "cuda" if torch.cuda.is_available() else "cpu"
- torch_dtype = torch.float16
- negprompt_cfg_path = args.negprompt_cfg_path
- video_negative_prompt = args.video_negative_prompt
- negative_prompt = args.negative_prompt
- motion_speed = args.motion_speed
- need_hist_match = args.need_hist_match
- video_guidance_scale_end = args.video_guidance_scale_end
- video_guidance_scale_method = args.video_guidance_scale_method
- add_static_video_prompt = args.add_static_video_prompt
- n_vision_condition = args.n_vision_condition
- lcm_model_cfg_path = args.lcm_model_cfg_path
- lcm_model_name = args.lcm_model_name
- referencenet_model_cfg_path = args.referencenet_model_cfg_path
- referencenet_model_name = args.referencenet_model_name
- ip_adapter_model_cfg_path = args.ip_adapter_model_cfg_path
- ip_adapter_model_name = args.ip_adapter_model_name
- vision_clip_model_path = args.vision_clip_model_path
- vision_clip_extractor_class_name = args.vision_clip_extractor_class_name
- facein_model_cfg_path = args.facein_model_cfg_path
- facein_model_name = args.facein_model_name
- ip_adapter_face_model_cfg_path = args.ip_adapter_face_model_cfg_path
- ip_adapter_face_model_name = args.ip_adapter_face_model_name
- 
- fixed_refer_image = args.fixed_refer_image
- fixed_ip_adapter_image = args.fixed_ip_adapter_image
- fixed_refer_face_image = args.fixed_refer_face_image
- redraw_condition_image_with_referencenet = args.redraw_condition_image_with_referencenet
- redraw_condition_image_with_ipdapter = args.redraw_condition_image_with_ipdapter
- redraw_condition_image_with_facein = args.redraw_condition_image_with_facein
- redraw_condition_image_with_ip_adapter_face = (
-     args.redraw_condition_image_with_ip_adapter_face
- )
- w_ind_noise = args.w_ind_noise
- ip_adapter_scale = args.ip_adapter_scale
- facein_scale = args.facein_scale
- ip_adapter_face_scale = args.ip_adapter_face_scale
- face_image_path = args.face_image_path
- ipadapter_image_path = args.ipadapter_image_path
- referencenet_image_path = args.referencenet_image_path
- vae_model_path = args.vae_model_path
- prompt_only_use_image_prompt = args.prompt_only_use_image_prompt
- # serial_denoise parameter start
- record_mid_video_noises = args.record_mid_video_noises
- record_mid_video_latents = args.record_mid_video_latents
- video_overlap = args.video_overlap
- # serial_denoise parameter end
- # parallel_denoise parameter start
- context_schedule = args.context_schedule
- context_frames = args.context_frames
- context_stride = args.context_stride
- context_overlap = args.context_overlap
- context_batch_size = args.context_batch_size
- interpolation_factor = args.interpolation_factor
- n_repeat = args.n_repeat
- 
- # parallel_denoise parameter end
- 
- b = 1
- negative_embedding = [
-     [os.path.join(PROJECT_DIR, "./checkpoints/embedding/badhandv4.pt"), "badhandv4"],
-     [
-         os.path.join(PROJECT_DIR, "./checkpoints/embedding/ng_deepnegative_v1_75t.pt"),
-         "ng_deepnegative_v1_75t",
-     ],
-     [
-         os.path.join(PROJECT_DIR, "./checkpoints/embedding/EasyNegativeV2.safetensors"),
-         "EasyNegativeV2",
-     ],
-     [
-         os.path.join(PROJECT_DIR, "./checkpoints/embedding/bad_prompt_version2-neg.pt"),
-         "bad_prompt_version2-neg",
-     ],
- ]
- prefix_prompt = ""
- suffix_prompt = ", beautiful, masterpiece, best quality"
- suffix_prompt = ""
- 
- 
- # sd model parameters
- 
- if sd_model_name != "None":
-     # use the sd_model_path from cfg_path
-     sd_model_params_dict_src = load_pyhon_obj(sd_model_cfg_path, "MODEL_CFG")
-     sd_model_params_dict = {
-         k: v
-         for k, v in sd_model_params_dict_src.items()
-         if sd_model_name == "all" or k in sd_model_name
-     }
- else:
-     # use the sd_model_path given on the command line; sd_model_name must be set to None separately
-     sd_model_name = os.path.basename(sd_model_cfg_path).split(".")[0]
-     sd_model_params_dict = {sd_model_name: {"sd": sd_model_cfg_path}}
-     sd_model_params_dict_src = sd_model_params_dict
- if len(sd_model_params_dict) == 0:
-     raise ValueError(
-         "has not target model, please set one of {}".format(
-             " ".join(list(sd_model_params_dict_src.keys()))
-         )
-     )
- print("running model, T2I SD")
- pprint(sd_model_params_dict)
- 
- # lcm
- if lcm_model_name is not None:
-     lcm_model_params_dict_src = load_pyhon_obj(lcm_model_cfg_path, "MODEL_CFG")
-     print("lcm_model_params_dict_src")
-     lcm_lora_dct = lcm_model_params_dict_src[lcm_model_name]
- else:
-     lcm_lora_dct = None
- print("lcm: ", lcm_model_name, lcm_lora_dct)
- 
- 
- # motion net parameters
- if os.path.isdir(unet_model_cfg_path):
-     unet_model_path = unet_model_cfg_path
- elif os.path.isfile(unet_model_cfg_path):
-     unet_model_params_dict_src = load_pyhon_obj(unet_model_cfg_path, "MODEL_CFG")
-     print("unet_model_params_dict_src", unet_model_params_dict_src.keys())
-     unet_model_path = unet_model_params_dict_src[unet_model_name]["unet"]
- else:
-     raise ValueError(f"expect dir or file, but given {unet_model_cfg_path}")
- print("unet: ", unet_model_name, unet_model_path)
- 
- 
- # referencenet
- if referencenet_model_name is not None:
-     if os.path.isdir(referencenet_model_cfg_path):
-         referencenet_model_path = referencenet_model_cfg_path
-     elif os.path.isfile(referencenet_model_cfg_path):
-         referencenet_model_params_dict_src = load_pyhon_obj(
-             referencenet_model_cfg_path, "MODEL_CFG"
-         )
-         print(
-             "referencenet_model_params_dict_src",
-             referencenet_model_params_dict_src.keys(),
-         )
-         referencenet_model_path = referencenet_model_params_dict_src[
-             referencenet_model_name
-         ]["net"]
-     else:
-         raise ValueError(f"expect dir or file, but given {referencenet_model_cfg_path}")
- else:
-     referencenet_model_path = None
- print("referencenet: ", referencenet_model_name, referencenet_model_path)
- 
- 
- # ip_adapter
- if ip_adapter_model_name is not None:
-     ip_adapter_model_params_dict_src = load_pyhon_obj(
-         ip_adapter_model_cfg_path, "MODEL_CFG"
-     )
-     print("ip_adapter_model_params_dict_src", ip_adapter_model_params_dict_src.keys())
-     ip_adapter_model_params_dict = ip_adapter_model_params_dict_src[
-         ip_adapter_model_name
-     ]
- else:
-     ip_adapter_model_params_dict = None
- print("ip_adapter: ", ip_adapter_model_name, ip_adapter_model_params_dict)
- 
- 
- # facein
- if facein_model_name is not None:
-     facein_model_params_dict_src = load_pyhon_obj(facein_model_cfg_path, "MODEL_CFG")
-     print("facein_model_params_dict_src", facein_model_params_dict_src.keys())
-     facein_model_params_dict = facein_model_params_dict_src[facein_model_name]
- else:
-     facein_model_params_dict = None
- print("facein: ", facein_model_name, facein_model_params_dict)
- 
- # ip_adapter_face
- if ip_adapter_face_model_name is not None:
-     ip_adapter_face_model_params_dict_src = load_pyhon_obj(
-         ip_adapter_face_model_cfg_path, "MODEL_CFG"
-     )
-     print(
-         "ip_adapter_face_model_params_dict_src",
-         ip_adapter_face_model_params_dict_src.keys(),
-     )
-     ip_adapter_face_model_params_dict = ip_adapter_face_model_params_dict_src[
-         ip_adapter_face_model_name
-     ]
- else:
-     ip_adapter_face_model_params_dict = None
- print(
-     "ip_adapter_face: ", ip_adapter_face_model_name, ip_adapter_face_model_params_dict
- )
- 
- 
- # negative_prompt
- def get_negative_prompt(negative_prompt, cfg_path=None, n: int = 10):
-     name = negative_prompt[:n]
-     if cfg_path is not None and cfg_path not in ["None", "none"]:
-         dct = load_pyhon_obj(cfg_path, "Negative_Prompt_CFG")
-         negative_prompt = dct[negative_prompt]["prompt"]
- 
-     return name, negative_prompt
- 
- 
- negtive_prompt_length = 10
- video_negative_prompt_name, video_negative_prompt = get_negative_prompt(
-     video_negative_prompt,
-     cfg_path=negprompt_cfg_path,
-     n=negtive_prompt_length,
- )
- negative_prompt_name, negative_prompt = get_negative_prompt(
-     negative_prompt,
-     cfg_path=negprompt_cfg_path,
-     n=negtive_prompt_length,
- )
- 
- print("video_negprompt", video_negative_prompt_name, video_negative_prompt)
- print("negprompt", negative_prompt_name, negative_prompt)
- 
- output_dir = args.output_dir
- os.makedirs(output_dir, exist_ok=True)
- 
- 
- # test_data_parameters
- def load_yaml(path):
-     tasks = OmegaConf.to_container(
-         OmegaConf.load(path), structured_config_mode=SCMode.INSTANTIATE, resolve=True
-     )
-     return tasks
- 
- 
- # if test_data_path.endswith(".yaml"):
- #     test_datas_src = load_yaml(test_data_path)
- # elif test_data_path.endswith(".csv"):
- #     test_datas_src = generate_tasks_from_table(test_data_path)
- # else:
- #     raise ValueError("expect yaml or csv, but given {}".format(test_data_path))
- 
- # test_datas = [
- #     test_data
- #     for test_data in test_datas_src
- #     if target_datas == "all" or test_data.get("name", None) in target_datas
- # ]
- 
- # test_datas = fiss_tasks(test_datas)
- # test_datas = generate_prompts(test_datas)
- 
- # n_test_datas = len(test_datas)
- # if n_test_datas == 0:
- #     raise ValueError(
- #         "n_test_datas == 0, set target_datas=None or set atleast one of {}".format(
- #             " ".join(list(d.get("name", "None") for d in test_datas_src))
- #         )
- #     )
- # print("n_test_datas", n_test_datas)
- # # pprint(test_datas)
- 
- 
- def read_image(path):
-     name = os.path.basename(path).split(".")[0]
-     image = read_image_as_5d(path)
-     return image, name
- 
- 
- def read_image_lst(path):
-     images_names = [read_image(x) for x in path]
-     images, names = zip(*images_names)
-     images = np.concatenate(images, axis=2)
-     name = "_".join(names)
-     return images, name
- 
- 
- def read_image_and_name(path):
-     if isinstance(path, str):
-         path = [path]
-     images, name = read_image_lst(path)
-     return images, name
- 
- 
- if referencenet_model_name is not None and not use_v2v_predictor:
-     referencenet = load_referencenet_by_name(
-         model_name=referencenet_model_name,
-         # sd_model=sd_model_path,
-         # sd_model=os.path.join(PROJECT_DIR, "./checkpoints//Moore-AnimateAnyone/AnimateAnyone/reference_unet.pth",
-         sd_referencenet_model=referencenet_model_path,
-         cross_attention_dim=cross_attention_dim,
-     )
- else:
-     referencenet = None
-     referencenet_model_name = "no"
- 
- if vision_clip_extractor_class_name is not None and not use_v2v_predictor:
-     vision_clip_extractor = load_vision_clip_encoder_by_name(
-         ip_image_encoder=vision_clip_model_path,
-         vision_clip_extractor_class_name=vision_clip_extractor_class_name,
-     )
-     logger.info(
-         f"vision_clip_extractor, name={vision_clip_extractor_class_name}, path={vision_clip_model_path}"
-     )
- else:
-     vision_clip_extractor = None
-     logger.info(f"vision_clip_extractor, None")
- 
- if ip_adapter_model_name is not None and not use_v2v_predictor:
-     ip_adapter_image_proj = load_ip_adapter_image_proj_by_name(
-         model_name=ip_adapter_model_name,
-         ip_image_encoder=ip_adapter_model_params_dict.get(
-             "ip_image_encoder", vision_clip_model_path
-         ),
-         ip_ckpt=ip_adapter_model_params_dict["ip_ckpt"],
-         cross_attention_dim=cross_attention_dim,
-         clip_embeddings_dim=ip_adapter_model_params_dict["clip_embeddings_dim"],
-         clip_extra_context_tokens=ip_adapter_model_params_dict[
-             "clip_extra_context_tokens"
-         ],
-         ip_scale=ip_adapter_model_params_dict["ip_scale"],
-         device=device,
-     )
- else:
-     ip_adapter_image_proj = None
-     ip_adapter_model_name = "no"
- 
- for model_name, sd_model_params in sd_model_params_dict.items():
-     lora_dict = sd_model_params.get("lora", None)
-     model_sex = sd_model_params.get("sex", None)
-     model_style = sd_model_params.get("style", None)
-     sd_model_path = sd_model_params["sd"]
-     test_model_vae_model_path = sd_model_params.get("vae", vae_model_path)
- 
-     unet = (
-         load_unet_by_name(
-             model_name=unet_model_name,
-             sd_unet_model=unet_model_path,
-             sd_model=sd_model_path,
-             # sd_model=os.path.join(PROJECT_DIR, "./checkpoints//Moore-AnimateAnyone/AnimateAnyone/denoising_unet.pth",
-             cross_attention_dim=cross_attention_dim,
-             need_t2i_facein=facein_model_name is not None,
-             # facein is not trained yet but is defined in the unet; loading its params would error, so strict controls it
-             strict=not (facein_model_name is not None),
-             need_t2i_ip_adapter_face=ip_adapter_face_model_name is not None,
-         )
-         if not use_v2v_predictor
-         else None
-     )
- 
-     if facein_model_name is not None and not use_v2v_predictor:
-         (
-             face_emb_extractor,
-             facein_image_proj,
-         ) = load_facein_extractor_and_proj_by_name(
-             model_name=facein_model_name,
-             ip_image_encoder=facein_model_params_dict["ip_image_encoder"],
-             ip_ckpt=facein_model_params_dict["ip_ckpt"],
-             cross_attention_dim=cross_attention_dim,
-             clip_embeddings_dim=facein_model_params_dict["clip_embeddings_dim"],
-             clip_extra_context_tokens=facein_model_params_dict[
-                 "clip_extra_context_tokens"
-             ],
-             ip_scale=facein_model_params_dict["ip_scale"],
-             device=device,
-             # facein is not trained in the unet yet; its params must be loaded separately
-             unet=unet,
-         )
-     else:
-         face_emb_extractor = None
-         facein_image_proj = None
- 
-     if ip_adapter_face_model_name is not None and not use_v2v_predictor:
-         (
-             ip_adapter_face_emb_extractor,
-             ip_adapter_face_image_proj,
-         ) = load_ip_adapter_face_extractor_and_proj_by_name(
-             model_name=ip_adapter_face_model_name,
-             ip_image_encoder=ip_adapter_face_model_params_dict["ip_image_encoder"],
-             ip_ckpt=ip_adapter_face_model_params_dict["ip_ckpt"],
-             cross_attention_dim=cross_attention_dim,
-             clip_embeddings_dim=ip_adapter_face_model_params_dict[
-                 "clip_embeddings_dim"
-             ],
-             clip_extra_context_tokens=ip_adapter_face_model_params_dict[
-                 "clip_extra_context_tokens"
-             ],
-             ip_scale=ip_adapter_face_model_params_dict["ip_scale"],
-             device=device,
-             unet=unet,  # ip_adapter_face is not trained in the unet yet; its params must be loaded separately
-         )
-     else:
-         ip_adapter_face_emb_extractor = None
-         ip_adapter_face_image_proj = None
- 
-     print("test_model_vae_model_path", test_model_vae_model_path)
- 
-     sd_predictor = (
-         DiffusersPipelinePredictor(
-             sd_model_path=sd_model_path,
-             unet=unet,
-             lora_dict=lora_dict,
-             lcm_lora_dct=lcm_lora_dct,
-             device=device,
-             dtype=torch_dtype,
-             negative_embedding=negative_embedding,
-             referencenet=referencenet,
-             ip_adapter_image_proj=ip_adapter_image_proj,
-             vision_clip_extractor=vision_clip_extractor,
-             facein_image_proj=facein_image_proj,
-             face_emb_extractor=face_emb_extractor,
-             vae_model=test_model_vae_model_path,
-             ip_adapter_face_emb_extractor=ip_adapter_face_emb_extractor,
-             ip_adapter_face_image_proj=ip_adapter_face_image_proj,
-         )
-         if not use_v2v_predictor
-         else video_sd_predictor
-     )
-     if use_v2v_predictor:
-         print(
-             "text2video use video_sd_predictor, sd_predictor type is ",
-             type(sd_predictor),
-         )
-     logger.debug(f"load sd_predictor"),
- 
- # TODO: change this part to gradio
- import cuid
- 
- 
- def generate_cuid():
-     return cuid.cuid()
- 
- 
- def online_t2v_inference(
-     prompt,
-     image_np,
-     seed,
-     fps,
-     w,
-     h,
-     video_len,
-     img_edge_ratio: float = 1.0,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     progress(0, desc="Starting...")
-     # Save the uploaded image to a specified path
-     if not os.path.exists(CACHE_PATH):
-         os.makedirs(CACHE_PATH)
-     image_cuid = generate_cuid()
- 
-     image_path = os.path.join(CACHE_PATH, f"{image_cuid}.jpg")
-     image = Image.fromarray(image_np)
-     image.save(image_path)
- 
-     time_size = int(video_len)
-     test_data = {
-         "name": image_cuid,
-         "prompt": prompt,
-         # 'video_path': None,
-         "condition_images": image_path,
-         "refer_image": image_path,
-         "ipadapter_image": image_path,
-         "height": h,
-         "width": w,
-         "img_length_ratio": img_edge_ratio,
-         # 'style': 'anime',
-         # 'sex': 'female'
-     }
-     batch = []
-     texts = []
-     print("\n test_data", test_data, model_name)
-     test_data_name = test_data.get("name", test_data)
-     prompt = test_data["prompt"]
-     prompt = prefix_prompt + prompt + suffix_prompt
-     prompt_hash = get_signature_of_string(prompt, length=5)
-     test_data["prompt_hash"] = prompt_hash
-     test_data_height = test_data.get("height", height)
-     test_data_width = test_data.get("width", width)
-     test_data_condition_images_path = test_data.get("condition_images", None)
-     test_data_condition_images_index = test_data.get("condition_images_index", None)
-     test_data_redraw_condition_image = test_data.get(
-         "redraw_condition_image", redraw_condition_image
-     )
-     # read condition_image
-     if (
-         test_data_condition_images_path is not None
-         and use_condition_image
-         and (
-             isinstance(test_data_condition_images_path, list)
-             or (
-                 isinstance(test_data_condition_images_path, str)
-                 and is_image(test_data_condition_images_path)
-             )
-         )
-     ):
-         (
-             test_data_condition_images,
-             test_data_condition_images_name,
-         ) = read_image_and_name(test_data_condition_images_path)
-         condition_image_height = test_data_condition_images.shape[3]
-         condition_image_width = test_data_condition_images.shape[4]
-         logger.debug(
-             f"test_data_condition_images use {test_data_condition_images_path}"
-         )
-     else:
-         test_data_condition_images = None
-         test_data_condition_images_name = "no"
-         condition_image_height = None
-         condition_image_width = None
-         logger.debug(f"test_data_condition_images is None")
- 
-     # when the output width/height are unspecified, use the input condition's size; prefer condition_image, fall back to video
-     if test_data_height in [None, -1]:
-         test_data_height = condition_image_height
- 
-     if test_data_width in [None, -1]:
-         test_data_width = condition_image_width
- 
-     test_data_img_length_ratio = float(
-         test_data.get("img_length_ratio", img_length_ratio)
-     )
-     # to stay aligned with video2video, use 64 instead of 8 as the minimum width/height granularity
-     # test_data_height = int(test_data_height * test_data_img_length_ratio // 8 * 8)
-     # test_data_width = int(test_data_width * test_data_img_length_ratio // 8 * 8)
-     test_data_height = int(test_data_height * test_data_img_length_ratio // 64 * 64)
-     test_data_width = int(test_data_width * test_data_img_length_ratio // 64 * 64)
-     pprint(test_data)
-     print(f"test_data_height={test_data_height}")
-     print(f"test_data_width={test_data_width}")
-     # continue
-     test_data_style = test_data.get("style", None)
-     test_data_sex = test_data.get("sex", None)
-     # fields set via | multi-parameter tasks are strings and must be cast to float explicitly
-     test_data_motion_speed = float(test_data.get("motion_speed", motion_speed))
-     test_data_w_ind_noise = float(test_data.get("w_ind_noise", w_ind_noise))
-     test_data_img_weight = float(test_data.get("img_weight", img_weight))
-     logger.debug(f"test_data_condition_images_path {test_data_condition_images_path}")
-     logger.debug(f"test_data_condition_images_index {test_data_condition_images_index}")
-     test_data_refer_image_path = test_data.get("refer_image", referencenet_image_path)
-     test_data_ipadapter_image_path = test_data.get(
-         "ipadapter_image", ipadapter_image_path
-     )
-     test_data_refer_face_image_path = test_data.get("face_image", face_image_path)
- 
-     if negprompt_cfg_path is not None:
-         if "video_negative_prompt" in test_data:
-             (
-                 test_data_video_negative_prompt_name,
-                 test_data_video_negative_prompt,
-             ) = get_negative_prompt(
-                 test_data.get(
-                     "video_negative_prompt",
-                 ),
-                 cfg_path=negprompt_cfg_path,
-                 n=negtive_prompt_length,
-             )
-         else:
-             test_data_video_negative_prompt_name = video_negative_prompt_name
-             test_data_video_negative_prompt = video_negative_prompt
-         if "negative_prompt" in test_data:
-             (
-                 test_data_negative_prompt_name,
-                 test_data_negative_prompt,
-             ) = get_negative_prompt(
-                 test_data.get(
-                     "negative_prompt",
-                 ),
-                 cfg_path=negprompt_cfg_path,
-                 n=negtive_prompt_length,
-             )
-         else:
-             test_data_negative_prompt_name = negative_prompt_name
-             test_data_negative_prompt = negative_prompt
-     else:
-         test_data_video_negative_prompt = test_data.get(
-             "video_negative_prompt", video_negative_prompt
-         )
-         test_data_video_negative_prompt_name = test_data_video_negative_prompt[
-             :negtive_prompt_length
-         ]
-         test_data_negative_prompt = test_data.get("negative_prompt", negative_prompt)
-         test_data_negative_prompt_name = test_data_negative_prompt[
-             :negtive_prompt_length
-         ]
- 
-     # prepare test_data_refer_image
-     if referencenet is not None:
-         if test_data_refer_image_path is None:
-             test_data_refer_image = test_data_condition_images
-             test_data_refer_image_name = test_data_condition_images_name
-             logger.debug(f"test_data_refer_image use test_data_condition_images")
-         else:
-             test_data_refer_image, test_data_refer_image_name = read_image_and_name(
-                 test_data_refer_image_path
-             )
-             logger.debug(f"test_data_refer_image use {test_data_refer_image_path}")
-     else:
-         test_data_refer_image = None
-         test_data_refer_image_name = "no"
-         logger.debug(f"test_data_refer_image is None")
- 
-     # prepare test_data_ipadapter_image
-     if vision_clip_extractor is not None:
-         if test_data_ipadapter_image_path is None:
-             test_data_ipadapter_image = test_data_condition_images
-             test_data_ipadapter_image_name = test_data_condition_images_name
- 
-             logger.debug(f"test_data_ipadapter_image use test_data_condition_images")
-         else:
-             (
-                 test_data_ipadapter_image,
-                 test_data_ipadapter_image_name,
-             ) = read_image_and_name(test_data_ipadapter_image_path)
-             logger.debug(
-                 f"test_data_ipadapter_image use f{test_data_ipadapter_image_path}"
-             )
-     else:
-         test_data_ipadapter_image = None
-         test_data_ipadapter_image_name = "no"
-         logger.debug(f"test_data_ipadapter_image is None")
- 
-     # prepare test_data_refer_face_image
-     if facein_image_proj is not None or ip_adapter_face_image_proj is not None:
-         if test_data_refer_face_image_path is None:
-             test_data_refer_face_image = test_data_condition_images
-             test_data_refer_face_image_name = test_data_condition_images_name
- 
-             logger.debug(f"test_data_refer_face_image use test_data_condition_images")
-         else:
-             (
-                 test_data_refer_face_image,
-                 test_data_refer_face_image_name,
-             ) = read_image_and_name(test_data_refer_face_image_path)
-             logger.debug(
-                 f"test_data_refer_face_image use f{test_data_refer_face_image_path}"
-             )
-     else:
-         test_data_refer_face_image = None
-         test_data_refer_face_image_name = "no"
-         logger.debug(f"test_data_refer_face_image is None")
- 
-     # # skip this test case when the model's sex/style and the test_data's both exist but differ
-     # if (
-     #     model_sex is not None
-     #     and test_data_sex is not None
-     #     and model_sex != test_data_sex
-     # ) or (
-     #     model_style is not None
-     #     and test_data_style is not None
-     #     and model_style != test_data_style
-     # ):
-     #     print("model doesnt match test_data")
-     #     print("model name: ", model_name)
-     #     print("test_data: ", test_data)
-     #     continue
-     if add_static_video_prompt:
-         test_data_video_negative_prompt = "static video, {}".format(
-             test_data_video_negative_prompt
-         )
-     for i_num in range(n_repeat):
-         test_data_seed = random.randint(0, 1e8) if seed in [None, -1] else seed
-         cpu_generator, gpu_generator = set_all_seed(int(test_data_seed))
-         save_file_name = (
-             f"m={model_name}_rm={referencenet_model_name}_case={test_data_name}"
-             f"_w={test_data_width}_h={test_data_height}_t={time_size}_nb={n_batch}"
-             f"_s={test_data_seed}_p={prompt_hash}"
-             f"_w={test_data_img_weight}"
-             f"_ms={test_data_motion_speed}"
-             f"_s={strength}_g={video_guidance_scale}"
-             f"_c-i={test_data_condition_images_name[:5]}_r-c={test_data_redraw_condition_image}"
-             f"_w={test_data_w_ind_noise}_{test_data_video_negative_prompt_name}"
-             f"_r={test_data_refer_image_name[:3]}_ip={test_data_refer_image_name[:3]}_f={test_data_refer_face_image_name[:3]}"
-         )
- 
-         save_file_name = clean_str_for_save(save_file_name)
-         output_path = os.path.join(
-             output_dir,
-             f"{save_file_name}.{save_filetype}",
-         )
-         if os.path.exists(output_path) and not overwrite:
-             print("existed", output_path)
-             continue
- 
-         print("output_path", output_path)
-         out_videos = sd_predictor.run_pipe_text2video(
-             video_length=time_size,
-             prompt=prompt,
-             width=test_data_width,
-             height=test_data_height,
-             generator=gpu_generator,
-             noise_type=noise_type,
-             negative_prompt=test_data_negative_prompt,
-             video_negative_prompt=test_data_video_negative_prompt,
-             max_batch_num=n_batch,
-             strength=strength,
-             need_img_based_video_noise=need_img_based_video_noise,
-             video_num_inference_steps=video_num_inference_steps,
-             condition_images=test_data_condition_images,
-             fix_condition_images=fix_condition_images,
-             video_guidance_scale=video_guidance_scale,
-             guidance_scale=guidance_scale,
-             num_inference_steps=num_inference_steps,
-             redraw_condition_image=test_data_redraw_condition_image,
-             img_weight=test_data_img_weight,
-             w_ind_noise=test_data_w_ind_noise,
-             n_vision_condition=n_vision_condition,
-             motion_speed=test_data_motion_speed,
-             need_hist_match=need_hist_match,
-             video_guidance_scale_end=video_guidance_scale_end,
-             video_guidance_scale_method=video_guidance_scale_method,
-             vision_condition_latent_index=test_data_condition_images_index,
-             refer_image=test_data_refer_image,
-             fixed_refer_image=fixed_refer_image,
-             redraw_condition_image_with_referencenet=redraw_condition_image_with_referencenet,
-             ip_adapter_image=test_data_ipadapter_image,
-             refer_face_image=test_data_refer_face_image,
-             fixed_refer_face_image=fixed_refer_face_image,
-             facein_scale=facein_scale,
-             redraw_condition_image_with_facein=redraw_condition_image_with_facein,
-             ip_adapter_face_scale=ip_adapter_face_scale,
-             redraw_condition_image_with_ip_adapter_face=redraw_condition_image_with_ip_adapter_face,
-             fixed_ip_adapter_image=fixed_ip_adapter_image,
-             ip_adapter_scale=ip_adapter_scale,
-             redraw_condition_image_with_ipdapter=redraw_condition_image_with_ipdapter,
-             prompt_only_use_image_prompt=prompt_only_use_image_prompt,
-             # need_redraw=need_redraw,
-             # use_video_redraw=use_video_redraw,
-             # serial_denoise parameter start
-             record_mid_video_noises=record_mid_video_noises,
-             record_mid_video_latents=record_mid_video_latents,
-             video_overlap=video_overlap,
-             # serial_denoise parameter end
-             # parallel_denoise parameter start
-             context_schedule=context_schedule,
-             context_frames=context_frames,
-             context_stride=context_stride,
-             context_overlap=context_overlap,
-             context_batch_size=context_batch_size,
-             interpolation_factor=interpolation_factor,
-             # parallel_denoise parameter end
-         )
-         out = np.concatenate([out_videos], axis=0)
-         texts = ["out"]
-         save_videos_grid_with_opencv(
-             out,
-             output_path,
-             texts=texts,
-             fps=fps,
-             tensor_order="b c t h w",
-             n_cols=n_cols,
-             write_info=args.write_info,
-             save_filetype=save_filetype,
-             save_images=save_images,
-         )
-         print("Save to", output_path)
-         print("\n" * 2)
-     return output_path
gradio_video2video.py DELETED
@@ -1,1039 +0,0 @@
1
- import argparse
2
- import copy
3
- import os
4
- from pathlib import Path
5
- import logging
6
- from collections import OrderedDict
7
- from pprint import pprint
8
- import random
9
- import gradio as gr
10
-
11
- import numpy as np
12
- from omegaconf import OmegaConf, SCMode
13
- import torch
14
- from einops import rearrange, repeat
15
- import cv2
16
- from PIL import Image
17
- from diffusers.models.autoencoder_kl import AutoencoderKL
18
-
19
- from mmcm.utils.load_util import load_pyhon_obj
20
- from mmcm.utils.seed_util import set_all_seed
21
- from mmcm.utils.signature import get_signature_of_string
22
- from mmcm.utils.task_util import fiss_tasks, generate_tasks as generate_tasks_from_table
23
- from mmcm.vision.utils.data_type_util import is_video, is_image, read_image_as_5d
24
- from mmcm.utils.str_util import clean_str_for_save
25
- from mmcm.vision.data.video_dataset import DecordVideoDataset
26
- from musev.auto_prompt.util import generate_prompts
27
-
28
- from musev.models.controlnet import PoseGuider
29
- from musev.models.facein_loader import load_facein_extractor_and_proj_by_name
30
- from musev.models.referencenet_loader import load_referencenet_by_name
31
- from musev.models.ip_adapter_loader import (
32
- load_ip_adapter_vision_clip_encoder_by_name,
33
- load_vision_clip_encoder_by_name,
34
- load_ip_adapter_image_proj_by_name,
35
- )
36
- from musev.models.ip_adapter_face_loader import (
37
- load_ip_adapter_face_extractor_and_proj_by_name,
38
- )
39
- from musev.pipelines.pipeline_controlnet_predictor import (
40
- DiffusersPipelinePredictor,
41
- )
42
- from musev.models.referencenet import ReferenceNet2D
43
- from musev.models.unet_loader import load_unet_by_name
44
- from musev.utils.util import save_videos_grid_with_opencv
45
- from musev import logger
46
-
47
- logger.setLevel("INFO")
48
-
49
- file_dir = os.path.dirname(__file__)
50
- PROJECT_DIR = os.path.join(os.path.dirname(__file__), "./")
51
- DATA_DIR = os.path.join(PROJECT_DIR, "data")
52
- CACHE_PATH = "./t2v_input_image"
53
-
54
-
55
- # TODO:use group to group arguments
56
- args_dict = {
57
- "add_static_video_prompt": False,
58
- "context_batch_size": 1,
59
- "context_frames": 12,
60
- "context_overlap": 4,
61
- "context_schedule": "uniform_v2",
62
- "context_stride": 1,
63
- "controlnet_conditioning_scale": 1.0,
64
- "controlnet_name": "dwpose_body_hand",
65
- "cross_attention_dim": 768,
66
- "enable_zero_snr": False,
67
- "end_to_end": True,
68
- "face_image_path": None,
69
- "facein_model_cfg_path": os.path.join(PROJECT_DIR, "./configs/model/facein.py"),
70
- "facein_model_name": None,
71
- "facein_scale": 1.0,
72
- "fix_condition_images": False,
73
- "fixed_ip_adapter_image": True,
74
- "fixed_refer_face_image": True,
75
- "fixed_refer_image": True,
76
- "fps": 4,
77
- "guidance_scale": 7.5,
78
- "height": None,
79
- "img_length_ratio": 1.0,
80
- "img_weight": 0.001,
81
- "interpolation_factor": 1,
82
- "ip_adapter_face_model_cfg_path": os.path.join(
83
- PROJECT_DIR, "./configs/model/ip_adapter.py"
84
- ),
85
- "ip_adapter_face_model_name": None,
86
- "ip_adapter_face_scale": 1.0,
87
- "ip_adapter_model_cfg_path": os.path.join(
88
- PROJECT_DIR, "./configs/model/ip_adapter.py"
89
- ),
90
- "ip_adapter_model_name": "musev_referencenet",
91
- "ip_adapter_scale": 1.0,
92
- "ipadapter_image_path": None,
93
- "lcm_model_cfg_path": os.path.join(PROJECT_DIR, "./configs/model/lcm_model.py"),
94
- "lcm_model_name": None,
95
- "log_level": "INFO",
96
- "motion_speed": 8.0,
97
- "n_batch": 1,
98
- "n_cols": 3,
99
- "n_repeat": 1,
100
- "n_vision_condition": 1,
101
- "need_hist_match": False,
102
- "need_img_based_video_noise": True,
103
- "need_return_condition": False,
104
- "need_return_videos": False,
105
- "need_video2video": False,
106
- "negative_prompt": "V2",
107
- "negprompt_cfg_path": os.path.join(
108
- PROJECT_DIR, "./configs/model/negative_prompt.py"
109
- ),
110
- "noise_type": "video_fusion",
111
- "num_inference_steps": 30,
112
- "output_dir": "./results/",
113
- "overwrite": False,
114
- "pose_guider_model_path": None,
115
- "prompt_only_use_image_prompt": False,
116
- "record_mid_video_latents": False,
117
- "record_mid_video_noises": False,
118
- "redraw_condition_image": False,
119
- "redraw_condition_image_with_facein": True,
120
- "redraw_condition_image_with_ip_adapter_face": True,
121
- "redraw_condition_image_with_ipdapter": True,
122
- "redraw_condition_image_with_referencenet": True,
123
- "referencenet_image_path": None,
124
- "referencenet_model_cfg_path": os.path.join(
125
- PROJECT_DIR, "./configs/model/referencenet.py"
126
- ),
127
- "referencenet_model_name": "musev_referencenet",
128
- "sample_rate": 1,
129
- "save_filetype": "mp4",
130
- "save_images": False,
131
- "sd_model_cfg_path": os.path.join(PROJECT_DIR, "./configs/model/T2I_all_model.py"),
132
- "sd_model_name": "majicmixRealv6Fp16",
133
- "seed": None,
134
- "strength": 0.8,
135
- "target_datas": "boy_dance2",
136
- "test_data_path": os.path.join(
137
- PROJECT_DIR, "./configs/infer/testcase_video_famous.yaml"
138
- ),
139
- "time_size": 24,
140
- "unet_model_cfg_path": os.path.join(PROJECT_DIR, "./configs/model/motion_model.py"),
141
- "unet_model_name": "musev_referencenet",
142
- "use_condition_image": True,
143
- "use_video_redraw": True,
144
- "vae_model_path": os.path.join(PROJECT_DIR, "./checkpoints/vae/sd-vae-ft-mse"),
145
- "video_guidance_scale": 3.5,
146
- "video_guidance_scale_end": None,
147
- "video_guidance_scale_method": "linear",
148
- "video_has_condition": True,
149
- "video_is_middle": False,
150
- "video_negative_prompt": "V2",
151
- "video_num_inference_steps": 10,
152
- "video_overlap": 1,
153
- "video_strength": 1.0,
154
- "vision_clip_extractor_class_name": "ImageClipVisionFeatureExtractor",
155
- "vision_clip_model_path": os.path.join(
156
- PROJECT_DIR, "./checkpoints/IP-Adapter/models/image_encoder"
157
- ),
158
- "w_ind_noise": 0.5,
159
- "which2video": "video_middle",
160
- "width": None,
161
- "write_info": False,
162
- }
163
- args = argparse.Namespace(**args_dict)
164
- print("args")
165
- pprint(args.__dict__)
166
- print("\n")
167
-
168
- logger.setLevel(args.log_level)
169
- overwrite = args.overwrite
170
- cross_attention_dim = args.cross_attention_dim
171
- time_size = args.time_size # 一次视频生成的帧数
172
- n_batch = args.n_batch # 按照time_size的尺寸 生成n_batch次,总帧数 = time_size * n_batch
173
- fps = args.fps
174
- fix_condition_images = args.fix_condition_images
175
- use_condition_image = args.use_condition_image # 当 test_data 中有图像时,作为初始图像
176
- redraw_condition_image = args.redraw_condition_image # 用于视频生成的首帧是否使用重绘后的
177
- need_img_based_video_noise = (
178
- args.need_img_based_video_noise
179
- ) # 视频加噪过程中是否使用首帧 condition_images
180
- img_weight = args.img_weight
181
- height = args.height # 如果测试数据中没有单独指定宽高,则默认这里
182
- width = args.width # 如果测试数据中没有单独指定宽高,则默认这里
183
- img_length_ratio = args.img_length_ratio # 如果测试数据中没有单独指定图像宽高比resize比例,则默认这里
184
- n_cols = args.n_cols
185
- noise_type = args.noise_type
186
- strength = args.strength # 首帧重绘程度参数
187
- video_guidance_scale = args.video_guidance_scale # 视频 condition与 uncond的权重参数
188
- guidance_scale = args.guidance_scale # 时序条件帧 condition与uncond的权重参数
189
- video_num_inference_steps = args.video_num_inference_steps # 视频迭代次数
190
- num_inference_steps = args.num_inference_steps # 时序条件帧 重绘参数
191
- seed = args.seed
192
- save_filetype = args.save_filetype
193
- save_images = args.save_images
194
- sd_model_cfg_path = args.sd_model_cfg_path
195
- sd_model_name = (
196
- args.sd_model_name if args.sd_model_name == "all" else args.sd_model_name.split(",")
197
- )
198
- unet_model_cfg_path = args.unet_model_cfg_path
199
- unet_model_name = args.unet_model_name
200
- test_data_path = args.test_data_path
201
- target_datas = (
202
- args.target_datas if args.target_datas == "all" else args.target_datas.split(",")
203
- )
204
- device = "cuda" if torch.cuda.is_available() else "cpu"
205
- torch_dtype = torch.float16
206
- controlnet_name = args.controlnet_name
207
- controlnet_name_str = controlnet_name
208
- if controlnet_name is not None:
209
- controlnet_name = controlnet_name.split(",")
210
- if len(controlnet_name) == 1:
211
- controlnet_name = controlnet_name[0]
212
-
213
- video_strength = args.video_strength # 视频重绘程度参数
214
- sample_rate = args.sample_rate
215
- controlnet_conditioning_scale = args.controlnet_conditioning_scale
216
-
217
- end_to_end = args.end_to_end # 是否首尾相连生成长视频
218
- control_guidance_start = 0.0
219
- control_guidance_end = 0.5
220
- control_guidance_end = 1.0
221
- negprompt_cfg_path = args.negprompt_cfg_path
222
- video_negative_prompt = args.video_negative_prompt
223
- negative_prompt = args.negative_prompt
224
- motion_speed = args.motion_speed
225
- need_hist_match = args.need_hist_match
226
- video_guidance_scale_end = args.video_guidance_scale_end
227
- video_guidance_scale_method = args.video_guidance_scale_method
228
- add_static_video_prompt = args.add_static_video_prompt
229
- n_vision_condition = args.n_vision_condition
230
- lcm_model_cfg_path = args.lcm_model_cfg_path
231
- lcm_model_name = args.lcm_model_name
232
- referencenet_model_cfg_path = args.referencenet_model_cfg_path
233
- referencenet_model_name = args.referencenet_model_name
234
- ip_adapter_model_cfg_path = args.ip_adapter_model_cfg_path
235
- ip_adapter_model_name = args.ip_adapter_model_name
236
- vision_clip_model_path = args.vision_clip_model_path
237
- vision_clip_extractor_class_name = args.vision_clip_extractor_class_name
238
- facein_model_cfg_path = args.facein_model_cfg_path
239
- facein_model_name = args.facein_model_name
240
- ip_adapter_face_model_cfg_path = args.ip_adapter_face_model_cfg_path
241
- ip_adapter_face_model_name = args.ip_adapter_face_model_name
242
-
243
- fixed_refer_image = args.fixed_refer_image
244
- fixed_ip_adapter_image = args.fixed_ip_adapter_image
245
- fixed_refer_face_image = args.fixed_refer_face_image
246
- redraw_condition_image_with_referencenet = args.redraw_condition_image_with_referencenet
247
- redraw_condition_image_with_ipdapter = args.redraw_condition_image_with_ipdapter
248
- redraw_condition_image_with_facein = args.redraw_condition_image_with_facein
249
- redraw_condition_image_with_ip_adapter_face = (
250
- args.redraw_condition_image_with_ip_adapter_face
251
- )
252
- w_ind_noise = args.w_ind_noise
253
- ip_adapter_scale = args.ip_adapter_scale
254
- facein_scale = args.facein_scale
255
- ip_adapter_face_scale = args.ip_adapter_face_scale
256
- face_image_path = args.face_image_path
257
- ipadapter_image_path = args.ipadapter_image_path
258
- referencenet_image_path = args.referencenet_image_path
259
- vae_model_path = args.vae_model_path
260
- prompt_only_use_image_prompt = args.prompt_only_use_image_prompt
261
- pose_guider_model_path = args.pose_guider_model_path
262
- need_video2video = args.need_video2video
263
- # serial_denoise parameter start
264
- record_mid_video_noises = args.record_mid_video_noises
265
- record_mid_video_latents = args.record_mid_video_latents
266
- video_overlap = args.video_overlap
267
- # serial_denoise parameter end
268
- # parallel_denoise parameter start
269
- context_schedule = args.context_schedule
270
- context_frames = args.context_frames
271
- context_stride = args.context_stride
272
- context_overlap = args.context_overlap
273
- context_batch_size = args.context_batch_size
274
- interpolation_factor = args.interpolation_factor
275
- n_repeat = args.n_repeat
276
-
277
- video_is_middle = args.video_is_middle
278
- video_has_condition = args.video_has_condition
279
- need_return_videos = args.need_return_videos
280
- need_return_condition = args.need_return_condition
281
- # parallel_denoise parameter end
282
- need_controlnet = controlnet_name is not None
283
-
284
- which2video = args.which2video
285
- if which2video == "video":
286
- which2video_name = "v2v"
287
- elif which2video == "video_middle":
288
- which2video_name = "vm2v"
289
- else:
290
- raise ValueError(
291
- "which2video only support video, video_middle, but given {which2video}"
292
- )
293
- b = 1
294
- negative_embedding = [
295
- [os.path.join(PROJECT_DIR, "./checkpoints/embedding/badhandv4.pt"), "badhandv4"],
296
- [
297
- os.path.join(PROJECT_DIR, "./checkpoints/embedding/ng_deepnegative_v1_75t.pt"),
298
- "ng_deepnegative_v1_75t",
299
- ],
300
- [
301
- os.path.join(PROJECT_DIR, "./checkpoints/embedding/EasyNegativeV2.safetensors"),
302
- "EasyNegativeV2",
303
- ],
304
- [
305
- os.path.join(PROJECT_DIR, "./checkpoints/embedding/bad_prompt_version2-neg.pt"),
306
- "bad_prompt_version2-neg",
307
- ],
308
- ]
309
- prefix_prompt = ""
310
- suffix_prompt = ", beautiful, masterpiece, best quality"
311
- suffix_prompt = ""
312
-
313
- if sd_model_name != "None":
314
- # 使用 cfg_path 里的sd_model_path
315
- sd_model_params_dict_src = load_pyhon_obj(sd_model_cfg_path, "MODEL_CFG")
316
- sd_model_params_dict = {
317
- k: v
318
- for k, v in sd_model_params_dict_src.items()
319
- if sd_model_name == "all" or k in sd_model_name
320
- }
321
- else:
322
- # 使用命令行给的sd_model_path, 需要单独设置 sd_model_name 为None,
323
- sd_model_name = os.path.basename(sd_model_cfg_path).split(".")[0]
324
- sd_model_params_dict = {sd_model_name: {"sd": sd_model_cfg_path}}
325
- sd_model_params_dict_src = sd_model_params_dict
326
- if len(sd_model_params_dict) == 0:
327
- raise ValueError(
328
- "has not target model, please set one of {}".format(
329
- " ".join(list(sd_model_params_dict_src.keys()))
330
- )
331
- )
332
- print("running model, T2I SD")
333
- pprint(sd_model_params_dict)
334
-
335
- # lcm
336
- if lcm_model_name is not None:
337
- lcm_model_params_dict_src = load_pyhon_obj(lcm_model_cfg_path, "MODEL_CFG")
338
- print("lcm_model_params_dict_src")
339
- lcm_lora_dct = lcm_model_params_dict_src[lcm_model_name]
340
- else:
341
- lcm_lora_dct = None
342
- print("lcm: ", lcm_model_name, lcm_lora_dct)
343
-
344
-
345
- # motion net parameters
346
- if os.path.isdir(unet_model_cfg_path):
347
- unet_model_path = unet_model_cfg_path
348
- elif os.path.isfile(unet_model_cfg_path):
349
- unet_model_params_dict_src = load_pyhon_obj(unet_model_cfg_path, "MODEL_CFG")
350
- print("unet_model_params_dict_src", unet_model_params_dict_src.keys())
351
- unet_model_path = unet_model_params_dict_src[unet_model_name]["unet"]
352
- else:
353
- raise ValueError(f"expect dir or file, but given {unet_model_cfg_path}")
354
- print("unet: ", unet_model_name, unet_model_path)
355
-
356
-
357
- # referencenet
358
- if referencenet_model_name is not None:
359
- if os.path.isdir(referencenet_model_cfg_path):
360
- referencenet_model_path = referencenet_model_cfg_path
361
- elif os.path.isfile(referencenet_model_cfg_path):
362
- referencenet_model_params_dict_src = load_pyhon_obj(
363
- referencenet_model_cfg_path, "MODEL_CFG"
364
- )
365
- print(
366
- "referencenet_model_params_dict_src",
367
- referencenet_model_params_dict_src.keys(),
368
- )
369
- referencenet_model_path = referencenet_model_params_dict_src[
370
- referencenet_model_name
371
- ]["net"]
372
- else:
373
- raise ValueError(f"expect dir or file, but given {referencenet_model_cfg_path}")
374
- else:
375
- referencenet_model_path = None
376
- print("referencenet: ", referencenet_model_name, referencenet_model_path)
377
-
378
-
379
- # ip_adapter
380
- if ip_adapter_model_name is not None:
381
- ip_adapter_model_params_dict_src = load_pyhon_obj(
382
- ip_adapter_model_cfg_path, "MODEL_CFG"
383
- )
384
- print("ip_adapter_model_params_dict_src", ip_adapter_model_params_dict_src.keys())
385
- ip_adapter_model_params_dict = ip_adapter_model_params_dict_src[
386
- ip_adapter_model_name
387
- ]
388
- else:
389
- ip_adapter_model_params_dict = None
390
- print("ip_adapter: ", ip_adapter_model_name, ip_adapter_model_params_dict)
391
-
392
-
393
- # facein
- if facein_model_name is not None:
-     facein_model_params_dict_src = load_pyhon_obj(facein_model_cfg_path, "MODEL_CFG")
-     print("facein_model_params_dict_src", facein_model_params_dict_src.keys())
-     facein_model_params_dict = facein_model_params_dict_src[facein_model_name]
- else:
-     facein_model_params_dict = None
- print("facein: ", facein_model_name, facein_model_params_dict)
-
- # ip_adapter_face
- if ip_adapter_face_model_name is not None:
-     ip_adapter_face_model_params_dict_src = load_pyhon_obj(
-         ip_adapter_face_model_cfg_path, "MODEL_CFG"
-     )
-     print(
-         "ip_adapter_face_model_params_dict_src",
-         ip_adapter_face_model_params_dict_src.keys(),
-     )
-     ip_adapter_face_model_params_dict = ip_adapter_face_model_params_dict_src[
-         ip_adapter_face_model_name
-     ]
- else:
-     ip_adapter_face_model_params_dict = None
- print(
-     "ip_adapter_face: ", ip_adapter_face_model_name, ip_adapter_face_model_params_dict
- )
-
-
- # negative_prompt
- def get_negative_prompt(negative_prompt, cfg_path=None, n: int = 10):
-     name = negative_prompt[:n]
-     if cfg_path is not None and cfg_path not in ["None", "none"]:
-         dct = load_pyhon_obj(cfg_path, "Negative_Prompt_CFG")
-         negative_prompt = dct[negative_prompt]["prompt"]
-
-     return name, negative_prompt
-
-
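- # get_negative_prompt returns a short tag (the first n characters) used only in
- # output file names, plus the full prompt text resolved from Negative_Prompt_CFG
- # when a config path is given.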
- negtive_prompt_length = 10
- video_negative_prompt_name, video_negative_prompt = get_negative_prompt(
-     video_negative_prompt,
-     cfg_path=negprompt_cfg_path,
-     n=negtive_prompt_length,
- )
- negative_prompt_name, negative_prompt = get_negative_prompt(
-     negative_prompt,
-     cfg_path=negprompt_cfg_path,
-     n=negtive_prompt_length,
- )
-
- print("video_negprompt", video_negative_prompt_name, video_negative_prompt)
- print("negprompt", negative_prompt_name, negative_prompt)
-
- output_dir = args.output_dir
- os.makedirs(output_dir, exist_ok=True)
-
-
- # test_data_parameters
- def load_yaml(path):
-     tasks = OmegaConf.to_container(
-         OmegaConf.load(path), structured_config_mode=SCMode.INSTANTIATE, resolve=True
-     )
-     return tasks
-
-
- # if test_data_path.endswith(".yaml"):
- #     test_datas_src = load_yaml(test_data_path)
- # elif test_data_path.endswith(".csv"):
- #     test_datas_src = generate_tasks_from_table(test_data_path)
- # else:
- #     raise ValueError("expect yaml or csv, but given {}".format(test_data_path))
-
- # test_datas = [
- #     test_data
- #     for test_data in test_datas_src
- #     if target_datas == "all" or test_data.get("name", None) in target_datas
- # ]
-
- # test_datas = fiss_tasks(test_datas)
- # test_datas = generate_prompts(test_datas)
-
- # n_test_datas = len(test_datas)
- # if n_test_datas == 0:
- #     raise ValueError(
- #         "n_test_datas == 0, set target_datas=None or set at least one of {}".format(
- #             " ".join(list(d.get("name", "None") for d in test_datas_src))
- #         )
- #     )
- # print("n_test_datas", n_test_datas)
- # # pprint(test_datas)
-
-
- def read_image(path):
-     name = os.path.basename(path).split(".")[0]
-     image = read_image_as_5d(path)
-     return image, name
-
-
- def read_image_lst(path):
-     images_names = [read_image(x) for x in path]
-     images, names = zip(*images_names)
-     images = np.concatenate(images, axis=2)
-     name = "_".join(names)
-     return images, name
-
-
- def read_image_and_name(path):
-     if isinstance(path, str):
-         path = [path]
-     images, name = read_image_lst(path)
-     return images, name
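- # read_image_as_5d yields 5D arrays (here treated as b c t h w), so read_image_lst
- # concatenates multiple images along axis=2, stacking them as frames of one clip,
- # and joins their basenames into a single composite name.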
-
-
- if referencenet_model_name is not None:
-     referencenet = load_referencenet_by_name(
-         model_name=referencenet_model_name,
-         # sd_model=sd_model_path,
-         # sd_model="../../checkpoints/Moore-AnimateAnyone/AnimateAnyone/reference_unet.pth",
-         sd_referencenet_model=referencenet_model_path,
-         cross_attention_dim=cross_attention_dim,
-     )
- else:
-     referencenet = None
-     referencenet_model_name = "no"
-
- if vision_clip_extractor_class_name is not None:
-     vision_clip_extractor = load_vision_clip_encoder_by_name(
-         ip_image_encoder=vision_clip_model_path,
-         vision_clip_extractor_class_name=vision_clip_extractor_class_name,
-     )
-     logger.info(
-         f"vision_clip_extractor, name={vision_clip_extractor_class_name}, path={vision_clip_model_path}"
-     )
- else:
-     vision_clip_extractor = None
-     logger.info("vision_clip_extractor, None")
-
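- # the CLIP vision encoder above supplies image embeddings; the IP-Adapter image
- # projection loaded next maps them into extra cross-attention tokens for the unet.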
- if ip_adapter_model_name is not None:
-     ip_adapter_image_proj = load_ip_adapter_image_proj_by_name(
-         model_name=ip_adapter_model_name,
-         ip_image_encoder=ip_adapter_model_params_dict.get(
-             "ip_image_encoder", vision_clip_model_path
-         ),
-         ip_ckpt=ip_adapter_model_params_dict["ip_ckpt"],
-         cross_attention_dim=cross_attention_dim,
-         clip_embeddings_dim=ip_adapter_model_params_dict["clip_embeddings_dim"],
-         clip_extra_context_tokens=ip_adapter_model_params_dict[
-             "clip_extra_context_tokens"
-         ],
-         ip_scale=ip_adapter_model_params_dict["ip_scale"],
-         device=device,
-     )
- else:
-     ip_adapter_image_proj = None
-     ip_adapter_model_name = "no"
-
- if pose_guider_model_path is not None:
-     logger.info(f"PoseGuider ={pose_guider_model_path}")
-     pose_guider = PoseGuider.from_pretrained(
-         pose_guider_model_path,
-         conditioning_embedding_channels=320,
-         block_out_channels=(16, 32, 96, 256),
-     )
- else:
-     pose_guider = None
-
- for model_name, sd_model_params in sd_model_params_dict.items():
-     lora_dict = sd_model_params.get("lora", None)
-     model_sex = sd_model_params.get("sex", None)
-     model_style = sd_model_params.get("style", None)
-     sd_model_path = sd_model_params["sd"]
-     test_model_vae_model_path = sd_model_params.get("vae", vae_model_path)
-
-     unet = load_unet_by_name(
-         model_name=unet_model_name,
-         sd_unet_model=unet_model_path,
-         sd_model=sd_model_path,
-         # sd_model="../../checkpoints/Moore-AnimateAnyone/AnimateAnyone/denoising_unet.pth",
-         cross_attention_dim=cross_attention_dim,
-         need_t2i_facein=facein_model_name is not None,
-         # facein is not trained yet but is already declared inside the unet, so a
-         # strict load would fail on its parameters; strict is relaxed to allow this
-         strict=not (facein_model_name is not None),
-         need_t2i_ip_adapter_face=ip_adapter_face_model_name is not None,
-     )
-
-     if facein_model_name is not None:
-         (
-             face_emb_extractor,
-             facein_image_proj,
-         ) = load_facein_extractor_and_proj_by_name(
-             model_name=facein_model_name,
-             ip_image_encoder=facein_model_params_dict["ip_image_encoder"],
-             ip_ckpt=facein_model_params_dict["ip_ckpt"],
-             cross_attention_dim=cross_attention_dim,
-             clip_embeddings_dim=facein_model_params_dict["clip_embeddings_dim"],
-             clip_extra_context_tokens=facein_model_params_dict[
-                 "clip_extra_context_tokens"
-             ],
-             ip_scale=facein_model_params_dict["ip_scale"],
-             device=device,
-             # facein does not take part in the unet training, so its parameters
-             # have to be loaded separately
-             unet=unet,
-         )
-     else:
-         face_emb_extractor = None
-         facein_image_proj = None
-
-     if ip_adapter_face_model_name is not None:
-         (
-             ip_adapter_face_emb_extractor,
-             ip_adapter_face_image_proj,
-         ) = load_ip_adapter_face_extractor_and_proj_by_name(
-             model_name=ip_adapter_face_model_name,
-             ip_image_encoder=ip_adapter_face_model_params_dict["ip_image_encoder"],
-             ip_ckpt=ip_adapter_face_model_params_dict["ip_ckpt"],
-             cross_attention_dim=cross_attention_dim,
-             clip_embeddings_dim=ip_adapter_face_model_params_dict[
-                 "clip_embeddings_dim"
-             ],
-             clip_extra_context_tokens=ip_adapter_face_model_params_dict[
-                 "clip_extra_context_tokens"
-             ],
-             ip_scale=ip_adapter_face_model_params_dict["ip_scale"],
-             device=device,
-             unet=unet,  # ip_adapter_face does not take part in the unet training, so its parameters are loaded separately
-         )
-     else:
-         ip_adapter_face_emb_extractor = None
-         ip_adapter_face_image_proj = None
-
-     print("test_model_vae_model_path", test_model_vae_model_path)
-
-     sd_predictor = DiffusersPipelinePredictor(
-         sd_model_path=sd_model_path,
-         unet=unet,
-         lora_dict=lora_dict,
-         lcm_lora_dct=lcm_lora_dct,
-         device=device,
-         dtype=torch_dtype,
-         negative_embedding=negative_embedding,
-         referencenet=referencenet,
-         ip_adapter_image_proj=ip_adapter_image_proj,
-         vision_clip_extractor=vision_clip_extractor,
-         facein_image_proj=facein_image_proj,
-         face_emb_extractor=face_emb_extractor,
-         vae_model=test_model_vae_model_path,
-         ip_adapter_face_emb_extractor=ip_adapter_face_emb_extractor,
-         ip_adapter_face_image_proj=ip_adapter_face_image_proj,
-         pose_guider=pose_guider,
-         controlnet_name=controlnet_name,
-         # TODO: deprecated parameters, to be removed
-         include_body=True,
-         include_face=False,
-         include_hand=True,
-         enable_zero_snr=args.enable_zero_snr,
-     )
-     logger.debug("load referencenet")
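-     # sd_predictor bundles the T2I backbone, motion unet, referencenet, vision-CLIP
-     # and face projections, pose guider and controlnet behind one video pipeline
-     # interface; the gradio handler below only talks to this object.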
-
- # TODO: rework this part for the gradio demo
- import cuid
-
-
- def generate_cuid():
-     return cuid.cuid()
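- # cuid() yields short collision-resistant ids; they name the cached upload and the
- # generated outputs, so concurrent gradio sessions should not overwrite each other.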
-
-
- def online_v2v_inference(
-     prompt,
-     image_np,
-     video,
-     processor,
-     seed,
-     fps,
-     w,
-     h,
-     video_length,
-     img_edge_ratio: float = 1.0,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     progress(0, desc="Starting...")
-     # Save the uploaded image to a specified path
-     if not os.path.exists(CACHE_PATH):
-         os.makedirs(CACHE_PATH)
-     image_cuid = generate_cuid()
-
-     image_path = os.path.join(CACHE_PATH, f"{image_cuid}.jpg")
-     image = Image.fromarray(image_np)
-     image.save(image_path)
-     time_size = int(video_length)
-     test_data = {
-         "name": image_cuid,
-         "prompt": prompt,
-         "video_path": video,
-         "condition_images": image_path,
-         "refer_image": image_path,
-         "ipadapter_image": image_path,
-         "height": h,
-         "width": w,
-         "img_length_ratio": img_edge_ratio,
-         # 'style': 'anime',
-         # 'sex': 'female'
-     }
-     batch = []
-     texts = []
-     video_path = test_data.get("video_path")
-     video_reader = DecordVideoDataset(
-         video_path,
-         time_size=int(video_length),
-         step=time_size,
-         sample_rate=sample_rate,
-         device="cpu",
-         data_type="rgb",
-         channels_order="c t h w",
-         drop_last=True,
-     )
-     video_height = video_reader.height
-     video_width = video_reader.width
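-     # with step == time_size the reader yields non-overlapping clips of
-     # video_length frames; drop_last=True drops a trailing clip shorter than
-     # time_size.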
-
-     print("\n i_test_data", test_data, model_name)
-     test_data_name = test_data.get("name", test_data)
-     prompt = test_data["prompt"]
-     prompt = prefix_prompt + prompt + suffix_prompt
-     prompt_hash = get_signature_of_string(prompt, length=5)
-     test_data["prompt_hash"] = prompt_hash
-     test_data_height = test_data.get("height", height)
-     test_data_width = test_data.get("width", width)
-     test_data_condition_images_path = test_data.get("condition_images", None)
-     test_data_condition_images_index = test_data.get("condition_images_index", None)
-     test_data_redraw_condition_image = test_data.get(
-         "redraw_condition_image", redraw_condition_image
-     )
-     # read condition_image
-     if (
-         test_data_condition_images_path is not None
-         and use_condition_image
-         and (
-             isinstance(test_data_condition_images_path, list)
-             or (
-                 isinstance(test_data_condition_images_path, str)
-                 and is_image(test_data_condition_images_path)
-             )
-         )
-     ):
-         (
-             test_data_condition_images,
-             test_data_condition_images_name,
-         ) = read_image_and_name(test_data_condition_images_path)
-         condition_image_height = test_data_condition_images.shape[3]
-         condition_image_width = test_data_condition_images.shape[4]
-         logger.debug(
-             f"test_data_condition_images use {test_data_condition_images_path}"
-         )
-     else:
-         test_data_condition_images = None
-         test_data_condition_images_name = "no"
-         condition_image_height = None
-         condition_image_width = None
-         logger.debug("test_data_condition_images is None")
-
-     # when the output height/width are not specified, fall back to the inputs:
-     # prefer the condition_image size, then the video size
-     if test_data_height in [None, -1]:
-         test_data_height = condition_image_height
-
-     if test_data_width in [None, -1]:
-         test_data_width = condition_image_width
-
-     test_data_img_length_ratio = float(
-         test_data.get("img_length_ratio", img_length_ratio)
-     )
-
-     test_data_height = int(test_data_height * test_data_img_length_ratio // 64 * 64)
-     test_data_width = int(test_data_width * test_data_img_length_ratio // 64 * 64)
-     pprint(test_data)
-     print(f"test_data_height={test_data_height}")
-     print(f"test_data_width={test_data_width}")
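-     # note: the sizes above are scaled by img_length_ratio and floored to multiples
-     # of 64, which keeps the latent resolution compatible with the unet's
-     # downsampling stack.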
-     # continue
-     test_data_style = test_data.get("style", None)
-     test_data_sex = test_data.get("sex", None)
-     # fields configured via "|" multi-parameter task tables arrive as strings and
-     # must be cast to float explicitly.
-     test_data_motion_speed = float(test_data.get("motion_speed", motion_speed))
-     test_data_w_ind_noise = float(test_data.get("w_ind_noise", w_ind_noise))
-     test_data_img_weight = float(test_data.get("img_weight", img_weight))
-     logger.debug(f"test_data_condition_images_path {test_data_condition_images_path}")
-     logger.debug(f"test_data_condition_images_index {test_data_condition_images_index}")
-     test_data_refer_image_path = test_data.get("refer_image", referencenet_image_path)
-     test_data_ipadapter_image_path = test_data.get(
-         "ipadapter_image", ipadapter_image_path
-     )
-     test_data_refer_face_image_path = test_data.get("face_image", face_image_path)
-     test_data_video_is_middle = test_data.get("video_is_middle", video_is_middle)
-     test_data_video_has_condition = test_data.get(
-         "video_has_condition", video_has_condition
-     )
-
-     controlnet_processor_params = {
-         "detect_resolution": min(test_data_height, test_data_width),
-         "image_resolution": min(test_data_height, test_data_width),
-     }
-     if negprompt_cfg_path is not None:
-         if "video_negative_prompt" in test_data:
-             (
-                 test_data_video_negative_prompt_name,
-                 test_data_video_negative_prompt,
-             ) = get_negative_prompt(
-                 test_data.get(
-                     "video_negative_prompt",
-                 ),
-                 cfg_path=negprompt_cfg_path,
-                 n=negtive_prompt_length,
-             )
-         else:
-             test_data_video_negative_prompt_name = video_negative_prompt_name
-             test_data_video_negative_prompt = video_negative_prompt
-         if "negative_prompt" in test_data:
-             (
-                 test_data_negative_prompt_name,
-                 test_data_negative_prompt,
-             ) = get_negative_prompt(
-                 test_data.get(
-                     "negative_prompt",
-                 ),
-                 cfg_path=negprompt_cfg_path,
-                 n=negtive_prompt_length,
-             )
-         else:
-             test_data_negative_prompt_name = negative_prompt_name
-             test_data_negative_prompt = negative_prompt
-     else:
-         test_data_video_negative_prompt = test_data.get(
-             "video_negative_prompt", video_negative_prompt
-         )
-         test_data_video_negative_prompt_name = test_data_video_negative_prompt[
-             :negtive_prompt_length
-         ]
-         test_data_negative_prompt = test_data.get("negative_prompt", negative_prompt)
-         test_data_negative_prompt_name = test_data_negative_prompt[
-             :negtive_prompt_length
-         ]
-
-     # prepare test_data_refer_image
-     if referencenet is not None:
-         if test_data_refer_image_path is None:
-             test_data_refer_image = test_data_condition_images
-             test_data_refer_image_name = test_data_condition_images_name
-             logger.debug("test_data_refer_image use test_data_condition_images")
-         else:
-             test_data_refer_image, test_data_refer_image_name = read_image_and_name(
-                 test_data_refer_image_path
-             )
-             logger.debug(f"test_data_refer_image use {test_data_refer_image_path}")
-     else:
-         test_data_refer_image = None
-         test_data_refer_image_name = "no"
-         logger.debug("test_data_refer_image is None")
-
-     # prepare test_data_ipadapter_image
-     if vision_clip_extractor is not None:
-         if test_data_ipadapter_image_path is None:
-             test_data_ipadapter_image = test_data_condition_images
-             test_data_ipadapter_image_name = test_data_condition_images_name
-
-             logger.debug("test_data_ipadapter_image use test_data_condition_images")
-         else:
-             (
-                 test_data_ipadapter_image,
-                 test_data_ipadapter_image_name,
-             ) = read_image_and_name(test_data_ipadapter_image_path)
-             logger.debug(
-                 f"test_data_ipadapter_image use {test_data_ipadapter_image_path}"
-             )
-     else:
-         test_data_ipadapter_image = None
-         test_data_ipadapter_image_name = "no"
-         logger.debug("test_data_ipadapter_image is None")
-
-     # prepare test_data_refer_face_image
-     if facein_image_proj is not None or ip_adapter_face_image_proj is not None:
-         if test_data_refer_face_image_path is None:
-             test_data_refer_face_image = test_data_condition_images
-             test_data_refer_face_image_name = test_data_condition_images_name
-
-             logger.debug("test_data_refer_face_image use test_data_condition_images")
-         else:
-             (
-                 test_data_refer_face_image,
-                 test_data_refer_face_image_name,
-             ) = read_image_and_name(test_data_refer_face_image_path)
-             logger.debug(
-                 f"test_data_refer_face_image use {test_data_refer_face_image_path}"
-             )
-     else:
-         test_data_refer_face_image = None
-         test_data_refer_face_image_name = "no"
-         logger.debug("test_data_refer_face_image is None")
-
-     # # skip this test case when the model's sex/style and test_data's sex/style
-     # # are both set but do not match
-     # if (
-     #     model_sex is not None
-     #     and test_data_sex is not None
-     #     and model_sex != test_data_sex
-     # ) or (
-     #     model_style is not None
-     #     and test_data_style is not None
-     #     and model_style != test_data_style
-     # ):
-     #     print("model doesnt match test_data")
-     #     print("model name: ", model_name)
-     #     print("test_data: ", test_data)
-     #     continue
-     # video
-     filename = os.path.basename(video_path).split(".")[0]
-     for i_num in range(n_repeat):
-         # random.randint requires integer endpoints, hence int(1e8)
-         test_data_seed = random.randint(0, int(1e8)) if seed in [None, -1] else seed
-         cpu_generator, gpu_generator = set_all_seed(int(test_data_seed))
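-         # the file name below encodes the model, resolution, step counts, seed and
-         # guidance settings, so repeated runs stay distinguishable and traceable.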
-
-         save_file_name = (
-             f"{which2video_name}_m={model_name}_rm={referencenet_model_name}_c={test_data_name}"
-             f"_w={test_data_width}_h={test_data_height}_t={time_size}_n={n_batch}"
-             f"_vn={video_num_inference_steps}"
-             f"_w={test_data_img_weight}_w={test_data_w_ind_noise}"
-             f"_s={test_data_seed}_n={controlnet_name_str}"
-             f"_s={strength}_g={guidance_scale}_vs={video_strength}_vg={video_guidance_scale}"
-             f"_p={prompt_hash}_{test_data_video_negative_prompt_name[:10]}"
-             f"_r={test_data_refer_image_name[:3]}_ip={test_data_refer_image_name[:3]}_f={test_data_refer_face_image_name[:3]}"
-         )
-         save_file_name = clean_str_for_save(save_file_name)
-         output_path = os.path.join(
-             output_dir,
-             f"{save_file_name}.{save_filetype}",
-         )
-         if os.path.exists(output_path) and not overwrite:
-             print("existed", output_path)
-             continue
-
-         if which2video in ["video", "video_middle"]:
-             need_video2video = False
-             if which2video == "video":
-                 need_video2video = True
-
-             (
-                 out_videos,
-                 out_condition,
-                 videos,
-             ) = sd_predictor.run_pipe_video2video(
-                 video=video_path,
-                 time_size=time_size,
-                 step=time_size,
-                 sample_rate=sample_rate,
-                 need_return_videos=need_return_videos,
-                 need_return_condition=need_return_condition,
-                 controlnet_conditioning_scale=controlnet_conditioning_scale,
-                 control_guidance_start=control_guidance_start,
-                 control_guidance_end=control_guidance_end,
-                 end_to_end=end_to_end,
-                 need_video2video=need_video2video,
-                 video_strength=video_strength,
-                 prompt=prompt,
-                 width=test_data_width,
-                 height=test_data_height,
-                 generator=gpu_generator,
-                 noise_type=noise_type,
-                 negative_prompt=test_data_negative_prompt,
-                 video_negative_prompt=test_data_video_negative_prompt,
-                 max_batch_num=n_batch,
-                 strength=strength,
-                 need_img_based_video_noise=need_img_based_video_noise,
-                 video_num_inference_steps=video_num_inference_steps,
-                 condition_images=test_data_condition_images,
-                 fix_condition_images=fix_condition_images,
-                 video_guidance_scale=video_guidance_scale,
-                 guidance_scale=guidance_scale,
-                 num_inference_steps=num_inference_steps,
-                 redraw_condition_image=test_data_redraw_condition_image,
-                 img_weight=test_data_img_weight,
-                 w_ind_noise=test_data_w_ind_noise,
-                 n_vision_condition=n_vision_condition,
-                 motion_speed=test_data_motion_speed,
-                 need_hist_match=need_hist_match,
-                 video_guidance_scale_end=video_guidance_scale_end,
-                 video_guidance_scale_method=video_guidance_scale_method,
-                 vision_condition_latent_index=test_data_condition_images_index,
-                 refer_image=test_data_refer_image,
-                 fixed_refer_image=fixed_refer_image,
-                 redraw_condition_image_with_referencenet=redraw_condition_image_with_referencenet,
-                 ip_adapter_image=test_data_ipadapter_image,
-                 refer_face_image=test_data_refer_face_image,
-                 fixed_refer_face_image=fixed_refer_face_image,
-                 facein_scale=facein_scale,
-                 redraw_condition_image_with_facein=redraw_condition_image_with_facein,
-                 ip_adapter_face_scale=ip_adapter_face_scale,
-                 redraw_condition_image_with_ip_adapter_face=redraw_condition_image_with_ip_adapter_face,
-                 fixed_ip_adapter_image=fixed_ip_adapter_image,
-                 ip_adapter_scale=ip_adapter_scale,
-                 redraw_condition_image_with_ipdapter=redraw_condition_image_with_ipdapter,
-                 prompt_only_use_image_prompt=prompt_only_use_image_prompt,
-                 controlnet_processor_params=controlnet_processor_params,
-                 # serial_denoise parameter start
-                 record_mid_video_noises=record_mid_video_noises,
-                 record_mid_video_latents=record_mid_video_latents,
-                 video_overlap=video_overlap,
-                 # serial_denoise parameter end
-                 # parallel_denoise parameter start
-                 context_schedule=context_schedule,
-                 context_frames=context_frames,
-                 context_stride=context_stride,
-                 context_overlap=context_overlap,
-                 context_batch_size=context_batch_size,
-                 interpolation_factor=interpolation_factor,
-                 # parallel_denoise parameter end
-                 video_is_middle=test_data_video_is_middle,
-                 video_has_condition=test_data_video_has_condition,
-             )
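-             # out_videos holds the generated frames; videos (when requested) holds
-             # the source clip and out_condition the rendered controlnet condition
-             # maps, both uint8, hence the / 255.0 normalization below.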
-         else:
-             raise ValueError(
-                 f"only support video, video_middle, but given {which2video_name}"
-             )
- print("out_videos.shape", out_videos.shape)
1011
- batch = [out_videos]
1012
- texts = ["out"]
1013
- if videos is not None:
1014
- print("videos.shape", videos.shape)
1015
- batch.insert(0, videos / 255.0)
1016
- texts.insert(0, "videos")
1017
- if need_controlnet and out_condition is not None:
1018
- if not isinstance(out_condition, list):
1019
- print("out_condition", out_condition.shape)
1020
- batch.append(out_condition / 255.0)
1021
- texts.append(controlnet_name)
1022
- else:
1023
- batch.extend([x / 255.0 for x in out_condition])
1024
- texts.extend(controlnet_name)
1025
- out = np.concatenate(batch, axis=0)
1026
- save_videos_grid_with_opencv(
1027
- out,
1028
- output_path,
1029
- texts=texts,
1030
- fps=fps,
1031
- tensor_order="b c t h w",
1032
- n_cols=n_cols,
1033
- write_info=args.write_info,
1034
- save_filetype=save_filetype,
1035
- save_images=save_images,
1036
- )
1037
- print("Save to", output_path)
1038
- print("\n" * 2)
1039
- return output_path