yerang commited on
Commit
9d6f8ad
ยท
verified ยท
1 Parent(s): b47480a

Create stf_utils.py

Browse files
Files changed (1) hide show
  1. stf_utils.py +206 -0
stf_utils.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from pydub import AudioSegment
5
+ import cv2
6
+ from pathlib import Path
7
+ import subprocess
8
+ from pathlib import Path
9
+ import av
10
+ import imageio
11
+ import numpy as np
12
+ from rich.progress import track
13
+ from tqdm import tqdm
14
+
15
+ import stf_alternative
16
+
17
+ import spaces
18
+
19
+
20
def exec_cmd(cmd):
    """Run *cmd* through the shell and raise CalledProcessError on failure.

    stdout and stderr are captured (merged into stdout).  The finished
    ``subprocess.CompletedProcess`` is returned so callers can inspect the
    output — a backward-compatible addition; the original implicitly
    returned None, which made capturing the output pointless.

    NOTE(review): ``shell=True`` means *cmd* must be trusted — never pass
    user-controlled strings without quoting (see ``shlex.quote``).

    Args:
        cmd: shell command line to execute.

    Returns:
        subprocess.CompletedProcess: the completed process.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    return subprocess.run(
        cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    )
24
+
25
+
26
def images2video(images, wfp, **kwargs):
    """Encode a sequence of frames into the video file *wfp*.

    Optional keyword settings: ``fps`` (24), ``format`` ("mp4"), ``codec``
    ("libx264"), ``quality`` (imageio default when omitted), ``pixelformat``
    ("yuv420p"), ``image_mode`` ("rgb" or "bgr"), ``macro_block_size`` (2)
    and ``crf`` (18).
    """
    writer = imageio.get_writer(
        wfp,
        fps=kwargs.get("fps", 24),
        format=kwargs.get("format", "mp4"),          # default is mp4 format
        codec=kwargs.get("codec", "libx264"),        # default is libx264 encoding
        quality=kwargs.get("quality"),               # video quality
        ffmpeg_params=["-crf", str(kwargs.get("crf", 18))],
        pixelformat=kwargs.get("pixelformat", "yuv420p"),  # video pixel format
        macro_block_size=kwargs.get("macro_block_size", 2),
    )

    # BGR frames (e.g. straight from OpenCV) are flipped to RGB on write.
    swap_channels = kwargs.get("image_mode", "rgb").lower() == "bgr"
    for i in track(range(len(images)), description="writing", transient=True):
        frame = images[i]
        writer.append_data(frame[..., ::-1] if swap_channels else frame)

    writer.close()

    # print(f':smiley: Dump to {wfp}\n', style="bold green")
    print(f"Dump to {wfp}\n")
58
+
59
+
60
def merge_audio_video(video_fp, audio_fp, wfp):
    """Mux *video_fp* (video stream copied) with *audio_fp* (re-encoded to
    AAC) into *wfp* using ffmpeg.

    When either input file is missing, prints a message and does nothing.

    Bug fix: the original called ``osp.exists`` but ``osp`` was never
    imported (only ``os``), so every call raised NameError — replaced with
    ``os.path.exists``.  Paths are now also shell-quoted, since the command
    is executed through the shell by ``exec_cmd``.
    """
    import shlex  # local import: quoting guard for shell-interpolated paths

    if os.path.exists(video_fp) and os.path.exists(audio_fp):
        cmd = (
            f"ffmpeg -i {shlex.quote(video_fp)} -i {shlex.quote(audio_fp)} "
            f"-c:v copy -c:a aac {shlex.quote(wfp)} -y"
        )
        exec_cmd(cmd)
        print(f"merge {video_fp} and {audio_fp} to {wfp}")
    else:
        print(f"video_fp: {video_fp} or audio_fp: {audio_fp} not exists!")
67
+
68
+
69
+
70
+
71
class STFPipeline:
    """Speech-to-face lip-sync pipeline built on ``stf_alternative``.

    Loads the model and the template video eagerly at construction time;
    ``execute`` then turns an audio file into a lip-synced video written
    under ``dubbing/``.
    """

    def __init__(
        self,
        stf_path: str = "/home/user/app/stf/",
        template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
        config_path: str = "front_config.json",
        checkpoint_path: str = "089.pth",
        root_path: str = "works",
        wavlm_path: str = "microsoft/wavlm-large",
        device: str = "cuda:0"
    ):
        self.device = device
        self.stf_path = stf_path
        # config/checkpoint/work paths are resolved relative to stf_path
        self.config_path = os.path.join(stf_path, config_path)
        self.checkpoint_path = os.path.join(stf_path, checkpoint_path)
        self.work_root_path = os.path.join(stf_path, root_path)
        self.wavlm_path = wavlm_path
        self.template_video_path = template_video_path

        # Model and template are loaded synchronously here (the original
        # comment claimed asynchronous loading, but both calls block).
        self.model = self.load_model()
        self.template = self.create_template()

    @spaces.GPU(duration=240)
    def load_model(self):
        """Create the stf_alternative model and place it on ``self.device``."""
        model = stf_alternative.create_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            work_root_path=self.work_root_path,
            device=self.device,
            wavlm_path=self.wavlm_path
        )
        return model

    @spaces.GPU(duration=240)
    def create_template(self):
        """Build the lip-sync template from the template video."""
        template = stf_alternative.Template(
            model=self.model,
            config_path=self.config_path,
            template_video_path=self.template_video_path
        )
        return template

    def execute(self, audio: str) -> str:
        """Generate a lip-synced video from *audio* and return its path.

        Args:
            audio: path to the input audio file (any format pydub can read).

        Returns:
            Path of the generated mp4 under ``dubbing/``.
        """
        # Ensure the output folder exists.
        Path("dubbing").mkdir(exist_ok=True)
        save_path = os.path.join("dubbing", Path(audio).stem + "--lip.mp4")

        reader = iter(self.template._get_reader(num_skip_frames=0))
        audio_segment = AudioSegment.from_file(audio)
        results = []

        # Concurrent inference over audio chunks; template frames are
        # consumed from the reader in lockstep with the predictions.
        with ThreadPoolExecutor(max_workers=4) as executor:
            try:
                gen_infer = self.template.gen_infer_concurrent(
                    executor, audio_segment, 0
                )
                for idx, (it, _) in enumerate(gen_infer):
                    frame = next(reader)
                    # compose() is invoked for its side effects only; its
                    # return value was never used (removed the dead
                    # `composed =` assignment).  Only the raw prediction
                    # is written out, matching the original behavior.
                    self.template.compose(idx, frame, it)
                    results.append(it["pred"])
            except StopIteration:
                # Reader exhausted before the audio ran out — stop cleanly.
                pass

        self.images_to_video(results, save_path)
        return save_path

    @staticmethod
    def images_to_video(images, output_path, fps=24):
        """Write *images* to *output_path* as an H.264 mp4 at *fps*."""
        writer = imageio.get_writer(output_path, fps=fps, format="mp4", codec="libx264")
        # NOTE(review): the two Korean strings below were mojibake in the
        # scraped source; reconstructed as the intended Korean — confirm
        # against the original file.
        for i in track(range(len(images)), description="비디오 생성 중"):
            writer.append_data(images[i])
        writer.close()
        print(f"비디오 저장 완료: {output_path}")
150
+
151
+ # class STFPipeline:
152
+ # def __init__(self,
153
+ # stf_path: str = "/home/user/app/stf/",
154
+ # device: str = "cuda:0",
155
+ # template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
156
+ # config_path: str = "front_config.json",
157
+ # checkpoint_path: str = "089.pth",
158
+ # root_path: str = "works"
159
+
160
+ # ):
161
+
162
+ # config_path = os.path.join(stf_path, config_path)
163
+ # checkpoint_path = os.path.join(stf_path, checkpoint_path)
164
+ # work_root_path = os.path.join(stf_path, root_path)
165
+
166
+ # model = stf_alternative.create_model(
167
+ # config_path=config_path,
168
+ # checkpoint_path=checkpoint_path,
169
+ # work_root_path=work_root_path,
170
+ # device=device,
171
+ # wavlm_path="microsoft/wavlm-large",
172
+ # )
173
+ # self.template = stf_alternative.Template(
174
+ # model=model,
175
+ # config_path=config_path,
176
+ # template_video_path=template_video_path,
177
+ # )
178
+
179
+
180
+ # def execute(self, audio: str):
181
+ # Path("dubbing").mkdir(exist_ok=True)
182
+ # save_path = os.path.join("dubbing", Path(audio).stem+"--lip.mp4")
183
+ # reader = iter(self.template._get_reader(num_skip_frames=0))
184
+ # audio_segment = AudioSegment.from_file(audio)
185
+ # pivot = 0
186
+ # results = []
187
+ # with ThreadPoolExecutor(4) as p:
188
+ # try:
189
+
190
+ # gen_infer = self.template.gen_infer_concurrent(
191
+ # p,
192
+ # audio_segment,
193
+ # pivot,
194
+ # )
195
+ # for idx, (it, chunk) in enumerate(gen_infer, pivot):
196
+ # frame = next(reader)
197
+ # composed = self.template.compose(idx, frame, it)
198
+ # frame_name = f"{idx}".zfill(5)+".jpg"
199
+ # results.append(it['pred'])
200
+ # pivot = idx + 1
201
+ # except StopIteration as e:
202
+ # pass
203
+
204
+ # images2video(results, save_path)
205
+
206
+ # return save_path