Spaces:
Build error
updated app for video inference
app.py
CHANGED
```diff
@@ -138,34 +138,35 @@ def image_inference(
                         out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
     return res[..., ::-1]
 
-def extract_frames(
-    [old lines 142-156 are blank in the rendered diff]
+def extract_frames(
+        driver_vid: gr.inputs.Video = None
+):
+    image_frames = []
+    vid = cv2.VideoCapture(driver_vid)  # path to mp4
+
+    while True:
+        success, img = vid.read()
+
+        if not success: break
+
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        pil_img = Image.fromarray(img)
+        image_frames.append(pil_img)
+
+    return image_frames
+
+def video_inference(
+        source_img: gr.inputs.Image = None,
+        driver_vid: gr.inputs.Video = None
+):
     image_frames = extract_frames(driver_vid)
 
     resulted_imgs = defaultdict(list)
 
-    video_folder = 'jenya_driver/'
-    image_frames = sorted(glob(f"{video_folder}/*", recursive=True), key=lambda x: int(x.split('/')[-1][:-4]))
-
     mask_hard_threshold = 0.5
-    N = len(image_frames)
-    for i in range(0, N, 4):
-        new_out = infer.evaluate(source_img,
-                                 source_information_for_reuse=out.get('source_information'))
+    N = len(image_frames)
+    for i in range(0, N, 4):  # frame limits
+        new_out = infer.evaluate(source_img, image_frames[i])
 
         mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
         mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
```
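The new `extract_frames` decodes the entire driver video into memory as PIL images, even though the inference loop that follows only touches every fourth frame. For long clips it can be cheaper to apply the stride at decode time. The sketch below is illustrative, not part of the commit (the name `extract_frames_strided` and its `stride` parameter are assumptions), and uses only standard OpenCV/PIL calls:

```python
import cv2
from PIL import Image

def extract_frames_strided(video_path: str, stride: int = 4):
    """Keep only every `stride`-th frame of a video as a PIL image.

    Illustrative variant of the app's extract_frames: the committed code
    stores every decoded frame and then skips three of every four in the
    inference loop; skipping at decode time bounds memory by N / stride.
    """
    vid = cv2.VideoCapture(video_path)
    frames = []
    index = 0
    while True:
        success, img = vid.read()  # success is False once the stream ends
        if not success:
            break
        if index % stride == 0:
            # OpenCV decodes to BGR; PIL expects RGB
            frames.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
        index += 1
    vid.release()
    return frames
```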
```diff
@@ -192,34 +193,41 @@ def video_inference(source_img, driver_vid):
         im.set_data(video[i,:,:,::-1])
         return im
 
-    anim = animation.FuncAnimation(fig, animate, init_func=init,
-        [continuation not shown in the rendered diff]
+    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=30)
+    anim.save("avatar.gif", dpi=300, writer = animation.PillowWriter(fps=24))
 
-    return
+    return "avatar.gif"
 
 with gr.Blocks() as demo:
     gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
-
+
     gr.Markdown(
         """
+        <img src='https://github.com/SamsungLabs/rome/blob/main/media/tease.gif'>
+
         <p style='text-align: center'>
         Create a personal avatar from just a single image using ROME.
         <br> <a href='https://arxiv.org/abs/2206.08343' target='_blank'>Paper</a> | <a href='https://samsunglabs.github.io/rome' target='_blank'>Project Page</a> | <a href='https://github.com/SamsungLabs/rome' target='_blank'>Github</a>
         </p>
+
+        <blockquote>
+        [The] system creates realistic mesh-based avatars from a single <strong>source</strong>
+        photo. These avatars are rigged, i.e., they can be driven by the animation parameters from a different <strong>driving</strong> frame.
+        </blockquote>
         """
     )
 
     with gr.Tab("Image Inference"):
         with gr.Row():
-            source_img = gr.Image(type="pil", label="…
-            driver_img = gr.Image(type="pil", label="…
-            image_output = gr.Image()
+            source_img = gr.Image(type="pil", label="Source image", show_label=True)
+            driver_img = gr.Image(type="pil", label="Driver image", show_label=True)
+            image_output = gr.Image("Rendered avatar")
         image_button = gr.Button("Predict")
     with gr.Tab("Video Inference"):
         with gr.Row():
             source_img2 = gr.Image(type="pil", label="source image", show_label=True)
             driver_vid = gr.Video(label="driver video")
-            video_output = gr.Image()
+            video_output = gr.Image(label="Rendered GIF avatar")
         video_button = gr.Button("Predict")
 
     gr.Examples(
```
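The GIF is produced by rendering every frame through a matplotlib figure (`FuncAnimation` plus `PillowWriter`), which is heavyweight at `dpi=300`; note also that `interval=30` only affects interactive playback, while the saved file's timing comes from the writer's `fps=24`. A leaner route, sketched here under the assumption that the frames are uint8 BGR arrays as elsewhere in the app (the helper name `save_gif` is hypothetical), writes them straight to a GIF with PIL:

```python
from PIL import Image

def save_gif(frames, path="avatar.gif", fps=24):
    """Write a list of uint8 BGR frames straight to a GIF with PIL.

    Hypothetical alternative to the FuncAnimation/PillowWriter route in
    the commit; avoids rasterizing each frame through a matplotlib figure.
    """
    pil_frames = [Image.fromarray(f[..., ::-1]) for f in frames]  # BGR -> RGB
    pil_frames[0].save(
        path,
        save_all=True,                # write a multi-frame file
        append_images=pil_frames[1:],
        duration=int(1000 / fps),     # per-frame display time in ms
        loop=0,                       # 0 = loop forever
    )
    return path
```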
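Neither hunk shows the event bindings that connect these components to the handlers; they presumably sit below the `gr.Examples(` call where the diff is cut off. For orientation, a typical Gradio Blocks wiring for the components above would look like the following sketch, which is assumed rather than taken from the commit:

```python
    # Assumed wiring (not shown in the diff), placed inside the
    # `with gr.Blocks() as demo:` context after the components above.
    image_button.click(
        fn=image_inference,
        inputs=[source_img, driver_img],
        outputs=image_output,
    )
    video_button.click(
        fn=video_inference,                # returns the path "avatar.gif"
        inputs=[source_img2, driver_vid],
        outputs=video_output,
    )

demo.launch()
```

Gradio passes the `inputs` components' values to `fn` and routes its return value to `outputs`; a filepath returned by `video_inference` is a valid value for a `gr.Image` output.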