TDN-M committed on
Commit fb17518 · verified
1 parent: 3879c25

Update app.py

Files changed (1)
  1. app.py +41 -42
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import os
 import sys
-import argparse
 import random
 import time
 from omegaconf import OmegaConf
@@ -12,6 +11,7 @@ from huggingface_hub import hf_hub_download
 from einops import repeat
 import torchvision.transforms as transforms
 from utils.utils import instantiate_from_config
+
 sys.path.insert(0, "scripts/evaluation")
 from funcs import (
     batch_ddim_sampling,
@@ -23,13 +23,11 @@ from funcs import (
 def download_model():
     REPO_ID = 'Doubiiu/DynamiCrafter_512'
     filename_list = ['model.ckpt']
-    if not os.path.exists('./checkpoints/dynamicrafter_512_v1/'):
-        os.makedirs('./checkpoints/dynamicrafter_512_v1/')
+    os.makedirs('./checkpoints/dynamicrafter_512_v1/', exist_ok=True)
     for filename in filename_list:
         local_file = os.path.join('./checkpoints/dynamicrafter_512_v1/', filename)
         if not os.path.exists(local_file):
             hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_512_v1/', force_download=True)
-
 
 def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
     resolution = (320, 512)
@@ -38,7 +36,7 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
     config_file='configs/inference_512_v1.0.yaml'
     config = OmegaConf.load(config_file)
     model_config = config.pop("model", OmegaConf.create())
-    model_config['params']['unet_config']['params']['use_checkpoint']=False
+    model_config['params']['unet_config']['params']['use_checkpoint'] = False
     model = instantiate_from_config(model_config)
     assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
     model = load_model_checkpoint(model, ckpt_path)
@@ -50,14 +48,14 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
     transform = transforms.Compose([
         transforms.Resize(min(resolution)),
         transforms.CenterCrop(resolution),
-        ])
+    ])
     torch.cuda.empty_cache()
-    print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
+    print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
     start = time.time()
     if steps > 60:
         steps = 60
 
-    batch_size=1
+    batch_size = 1
     channels = model.model.diffusion_model.out_channels
     frames = model.temporal_length
     h, w = resolution[0] // 8, resolution[1] // 8
@@ -70,14 +68,14 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
     img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device)
     img_tensor = (img_tensor / 255. - 0.5) * 2
 
-    image_tensor_resized = transform(img_tensor) #3,256,256
-    videos = image_tensor_resized.unsqueeze(0) # bchw
+    image_tensor_resized = transform(img_tensor) # 3, 256, 256
+    videos = image_tensor_resized.unsqueeze(0) # b, c, h, w
 
-    z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw
+    z = get_latent_z(model, videos.unsqueeze(2)) # b, c, 1, h, w
 
     img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
 
-    cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc
+    cond_images = model.embedder(img_tensor.unsqueeze(0)) # b, l, c
     img_emb = model.image_proj_model(cond_images)
 
     imtext_cond = torch.cat([text_emb, img_emb], dim=1)
@@ -85,16 +83,15 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
     fs = torch.tensor([fs], dtype=torch.long, device=model.device)
     cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
 
-    ## inference
+    # inference
     batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
-    ## b,samples,c,t,h,w
+    # b, samples, c, t, h, w
 
     video_path = './output.mp4'
    save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
     model = model.cpu()
     return video_path
 
-
 i2v_examples = [
     ['prompts/512/bloom01.png', 'time-lapse of a blooming flower with leaves and a stem', 50, 7.5, 1.0, 24, 123],
     ['prompts/512/campfire.png', 'a bonfire is lit in the middle of a field', 50, 7.5, 1.0, 24, 123],
@@ -104,31 +101,32 @@ i2v_examples = [
     ['prompts/512/zreal_penguin.png', 'a group of penguins walking on a beach', 50, 7.5, 1.0, 20, 123],
 ]
 
-
-
-
 css = """#input_img {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px}"""
 
 with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
-    gr.Markdown("<div align='center'> <h1> DynamiCrafter: Animating Open-domain Images with Video Diffusion Priors </span> </h1> \
-                <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
-                <a href='https://doubiiu.github.io/'>Jinbo Xing</a>, \
-                <a href='https://menghanxia.github.io/'>Menghan Xia</a>, <a href='https://yzhang2016.github.io/'>Yong Zhang</a>, \
-                <a href=''>Haoxin Chen</a>, <a href=''> Wangbo Yu</a>,\
-                <a href='https://github.com/hyliu'>Hanyuan Liu</a>, <a href='https://xinntao.github.io/'>Xintao Wang</a>,\
-                <a href='https://www.cse.cuhk.edu.hk/~ttwong/myself.html'>Tien-Tsin Wong</a>,\
-                <a href='https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=zh-CN'>Ying Shan</a>\
-                </h2> \
-                <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2310.12190'> [ArXiv] </a>\
-                <a style='font-size:18px;color: #000000' href='https://doubiiu.github.io/projects/DynamiCrafter/'> [Project Page] </a> \
-                <a style='font-size:18px;color: #000000' href='https://github.com/Doubiiu/DynamiCrafter'> [Github] </a> </div>")
-
+    gr.Markdown("""
+        <div align='center'>
+        <h1> DynamiCrafter: Animating Open-domain Images with Video Diffusion Priors </h1>
+        <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>
+        <a href='https://doubiiu.github.io/'>Jinbo Xing</a>,
+        <a href='https://menghanxia.github.io/'>Menghan Xia</a>, <a href='https://yzhang2016.github.io/'>Yong Zhang</a>,
+        <a href=''>Haoxin Chen</a>, <a href=''> Wangbo Yu</a>,
+        <a href='https://github.com/hyliu'>Hanyuan Liu</a>, <a href='https://xinntao.github.io/'>Xintao Wang</a>,
+        <a href='https://www.cse.cuhk.edu.hk/~ttwong/myself.html'>Tien-Tsin Wong</a>,
+        <a href='https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=zh-CN'>Ying Shan</a>
+        </h2>
+        <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2310.12190'> [ArXiv] </a>
+        <a style='font-size:18px;color: #000000' href='https://doubiiu.github.io/projects/DynamiCrafter/'> [Project Page] </a>
+        <a style='font-size:18px;color: #000000' href='https://github.com/Doubiiu/DynamiCrafter'> [Github] </a>
+        </div>
+        """)
+
     with gr.Tab(label='ImageAnimation_320x512'):
         with gr.Column():
             with gr.Row():
                 with gr.Column():
                     with gr.Row():
-                        i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
+                        i2v_input_image = gr.Image(label="Input Image", elem_id="input_img")
                     with gr.Row():
                         i2v_input_text = gr.Text(label='Prompts')
                     with gr.Row():
@@ -139,18 +137,19 @@ with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
                         i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
                         i2v_motion = gr.Slider(minimum=15, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=24)
                         i2v_end_btn = gr.Button("Generate")
-                    # with gr.Tab(label='Result'):
                     with gr.Row():
-                        i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
+                        i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True, show_share_button=True)
 
-            gr.Examples(examples=i2v_examples,
-                        inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
-                        outputs=[i2v_output_video],
-                        fn = infer,
+            gr.Examples(
+                examples=i2v_examples,
+                inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
+                outputs=[i2v_output_video],
+                fn=infer,
             )
-    i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
-                      outputs=[i2v_output_video],
-                      fn = infer
+    i2v_end_btn.click(
+        inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed],
+        outputs=[i2v_output_video],
+        fn=infer
     )
 
-dynamicrafter_iface.queue(max_size=12).launch(share=True, show_api=True)
+dynamicrafter_iface.queue(max_size=12).launch(share=True, show_api=True)