xunsong.li committed
Commit dd93214
Parent: 75c09e2

fix output length exceeding the number of pose frames

Files changed (2):
  1. .gitignore  +8 −1
  2. app.py      +12 −8
.gitignore CHANGED
@@ -1,4 +1,11 @@
 __pycache__/
 pretrained_weights/
 output/
-.venv/
+.venv/
+mlruns/
+data/
+
+*.pth
+*.pt
+*.pkl
+*.bin
app.py CHANGED
@@ -114,29 +114,33 @@ class AnimateController:
         src_fps = get_fps(pose_video_path)
 
         pose_list = []
-        pose_tensor_list = []
-        pose_transform = transforms.Compose(
-            [transforms.Resize((height, width)), transforms.ToTensor()]
-        )
-        for pose_image_pil in pose_images[:length]:
+        total_length = min(length, len(pose_images))
+        for pose_image_pil in pose_images[:total_length]:
             pose_list.append(pose_image_pil)
-            pose_tensor_list.append(pose_transform(pose_image_pil))
 
         video = self.pipeline(
             ref_image,
             pose_list,
             width=width,
             height=height,
-            video_length=length,
+            video_length=total_length,
             num_inference_steps=num_inference_steps,
             guidance_scale=cfg,
             generator=generator,
         ).videos
 
+        new_h, new_w = video.shape[-2:]
+        pose_transform = transforms.Compose(
+            [transforms.Resize((new_h, new_w)), transforms.ToTensor()]
+        )
+        pose_tensor_list = []
+        for pose_image_pil in pose_images[:total_length]:
+            pose_tensor_list.append(pose_transform(pose_image_pil))
+
         ref_image_tensor = pose_transform(ref_image)  # (c, h, w)
         ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
         ref_image_tensor = repeat(
-            ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=length
+            ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=total_length
         )
         pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
         pose_tensor = pose_tensor.transpose(0, 1)
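The second half of the hunk also rebuilds pose_transform from the decoded video's actual size (new_h, new_w) rather than the requested (height, width), presumably because the pipeline may return a slightly different resolution, and the reference and pose tensors must match video.shape downstream. A self-contained sketch of the resulting shapes, using torch and einops with hypothetical dimensions (shapes only, not the real pipeline):

import torch
from einops import repeat

total_length, new_h, new_w = 5, 64, 48   # hypothetical sizes

# Stand-ins for pose_transform(pose_image_pil): (c, h, w) tensors.
pose_tensor_list = [torch.rand(3, new_h, new_w) for _ in range(total_length)]

pose_tensor = torch.stack(pose_tensor_list, dim=0)   # (f, c, h, w)
pose_tensor = pose_tensor.transpose(0, 1)            # (c, f, h, w)

ref_image_tensor = torch.rand(3, new_h, new_w)       # pose_transform(ref_image)
ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
ref_image_tensor = repeat(
    ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=total_length
)

# Both now describe total_length frames at the video's true resolution.
assert ref_image_tensor.shape == (1, 3, total_length, new_h, new_w)
assert pose_tensor.unsqueeze(0).shape == ref_image_tensor.shape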