|
import copy
import glob
import json
import os
import pickle
import queue
import shutil
import threading
import time
from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
import torch
import wget
from loguru import logger
from tqdm import tqdm

from ...web_configs import WEB_CONFIGS
from .musetalk.utils.blending import get_image_blending, get_image_prepare_material, init_face_parsing_model
from .musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
from .musetalk.utils.utils import datagen, load_all_model |
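
# End-to-end flow implemented below: template video/frames -> DWPose landmarks
# and face boxes -> VAE latents per face crop -> Whisper audio features ->
# MuseTalk UNet predicts lip-synced face latents -> VAE decode -> mask
# blending back into the source frames -> ffmpeg encodes and muxes the audio.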
|
|
def setup_ffmpeg_env(model_dir):
    """Download a static ffmpeg build (if not already present) and prepend it to PATH."""
    ffmpeg_file_name = "ffmpeg-release-amd64-static"
    ffmpeg_root = Path(model_dir).joinpath("drivers").absolute()
    ffmpeg_root.mkdir(exist_ok=True, parents=True)

    # Reuse an already extracted ffmpeg directory if one exists.
    ffmpeg_dir = None
    for ffmpeg_dir_path in ffmpeg_root.iterdir():
        if not ffmpeg_dir_path.is_dir():
            continue
        ffmpeg_dir = str(ffmpeg_dir_path)

    if ffmpeg_dir is None:
        # Download and unpack the static build.
        os.system(
            f"cd {str(ffmpeg_root)} && wget https://johnvansickle.com/ffmpeg/releases/{ffmpeg_file_name}.tar.xz "
            f"&& xz -d {ffmpeg_file_name}.tar.xz && tar -xvf {ffmpeg_file_name}.tar"
        )

        for ffmpeg_dir_path in ffmpeg_root.iterdir():
            if not ffmpeg_dir_path.is_dir():
                continue
            ffmpeg_dir = str(ffmpeg_dir_path)
            break

    logger.info(f"setting ffmpeg dir: {ffmpeg_dir}")
    if str(ffmpeg_dir) not in os.environ.get("PATH", ""):
        logger.info(f"add ffmpeg to path: {str(ffmpeg_dir)}")
        os.environ["PATH"] = f"{str(ffmpeg_dir)}:{os.environ['PATH']}"
|
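# Model weights below are fetched via the hf-mirror.com endpoint, a mirror of
# huggingface.co for networks where the latter is unreachable.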
|
def init_digital_model(model_dir, use_float16=False):
    """Download (if needed) and load the MuseTalk models: Whisper audio encoder, VAE, UNet and positional encoding."""
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    from huggingface_hub import snapshot_download

    muse_talk_model_path = snapshot_download(repo_id="TMElyralab/MuseTalk", local_dir=model_dir)
    sd_model_path = snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir=Path(model_dir).joinpath("sd-vae-ft-mse"))

    whisper_pth_path = Path(model_dir).joinpath("whisper/tiny.pt")
    whisper_pth_path.parent.mkdir(parents=True, exist_ok=True)
    if not whisper_pth_path.exists():
        # Fetch the Whisper tiny checkpoint used as the audio feature extractor.
        wget.download(
            url="https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
            out=str(whisper_pth_path),
        )

    logger.info("Loading models...")
    audio_processor, vae, unet, pe = load_all_model(
        audio2feature_model_path=str(whisper_pth_path),
        vae_model_path=sd_model_path,
        unet_model_dict={
            "unet_config": str(Path(muse_talk_model_path).joinpath("musetalk", "musetalk.json")),
            "model_path": str(Path(muse_talk_model_path).joinpath("musetalk", "pytorch_model.bin")),
        },
    )

    if use_float16:
        # Half precision roughly halves GPU memory use.
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    logger.info("Loading models done!")
    return audio_processor, vae, unet, pe
|
|
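# Note: DWPose is only needed while building an avatar's cache (landmark
# extraction in Avatar.prepare_material, which frees it right after use).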
def load_pose_model(model_dir):
    """Download and initialise the DWPose model used for face landmark and bbox extraction."""
    from mmpose.apis import init_model

    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    from huggingface_hub import hf_hub_download

    dw_pose_path = hf_hub_download(
        repo_id="yzd-v/DWPose",
        filename="dw-ll_ucoco_384.pth",
        local_dir=Path(model_dir).joinpath("dwpose"),
    )

    # Config path is relative to the repository root.
    config_file = r"./server/digital_human/modules/musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py"
    pose_model = init_model(config_file, dw_pose_path, device="cuda")

    return pose_model
|
|
def load_face_parsing_model(model_dir):
    """Download and initialise the BiSeNet face-parsing model used for blending masks."""
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    from huggingface_hub import hf_hub_download

    model_dir = Path(model_dir).joinpath("face-parse-bisent")
    model_dir.mkdir(parents=True, exist_ok=True)

    resnet_path = Path(model_dir).joinpath("resnet18-5c106cde.pth")
    if not resnet_path.exists():
        # ResNet-18 backbone weights for BiSeNet.
        wget.download(
            url="https://download.pytorch.org/models/resnet18-5c106cde.pth",
            out=str(resnet_path),
        )

    # Face-parsing head weights.
    _ = hf_hub_download(
        repo_id="ManyOtherFunctions/face-parse-bisent",
        filename="79999_iter.pth",
        local_dir=str(model_dir),
    )

    face_parsing_model = init_face_parsing_model(
        resnet_path=str(resnet_path),
        face_model_pth=Path(model_dir).joinpath("79999_iter.pth"),
    )
    return face_parsing_model
|
|
def video2imgs(vid_path, save_path, cut_frame=10000000):
    """Split a video into zero-padded PNG frames under save_path, stopping after cut_frame frames."""
    cap = cv2.VideoCapture(vid_path)
    count = 0
    while count <= cut_frame:
        ret, frame = cap.read()
        if not ret:
            break
        cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
        count += 1
    cap.release()
|
|
def osmakedirs(path_list):
    for path in path_list:
        os.makedirs(path, exist_ok=True)
|
|
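# Layout of an avatar's preprocessed cache under {work_dir}/{avatar_id}:
#   full_imgs/        source + mirrored frames as 8-digit PNGs
#   coords.pkl        face bounding boxes per frame
#   latents.pt        VAE latents of the cropped faces
#   mask/             blending masks per frame
#   mask_coords.pkl   crop boxes for the masks
#   avator_info.json  id / video path / bbox_shift used to build the cache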
class Avatar:
    """MuseTalk digital human avatar: builds per-character caches and renders lip-synced video."""

    def __init__(self, avatar_id, work_dir, model_dir, video_path, bbox_shift, batch_size, fps, preparation_force):
        self.avatar_id = str(avatar_id)
        self.video_path = video_path

        self.bbox_shift = bbox_shift
        self.model_dir = model_dir
        self.work_dir = work_dir
        self.preparation_force = preparation_force
        self.batch_size = batch_size
        self.idx = 0
        self.fps = fps

        self.frame_list_cycle = []
        self.coord_list_cycle = []
        self.input_latent_list_cycle = []
        self.mask_coords_list_cycle = []
        self.mask_list_cycle = []

        # Load the models once; they are reused across avatar switches.
        self.face_parsing_model = load_face_parsing_model(self.model_dir)
        self.audio_processor, self.vae, self.unet, self.pe = init_digital_model(self.model_dir, use_float16=False)
        self.pe = self.pe.half()
        self.vae.vae = self.vae.vae.half()
        self.unet.model = self.unet.model.half()

        self.change_character(avatar_id)
|
    @torch.no_grad()
    def change_character(self, avatar_id, video_path=""):
        """Switch to another avatar id (and optionally another template video), preparing materials if needed."""
        if video_path != "":
            logger.info(f"Switch video from {self.video_path} to {video_path}")
            self.video_path = video_path

        self.avatar_id = str(avatar_id)
        self.avatar_path = f"{self.work_dir}/{self.avatar_id}"
        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        self.coords_path = f"{self.avatar_path}/coords.pkl"
        self.latents_out_path = f"{self.avatar_path}/latents.pt"
        self.mask_out_path = f"{self.avatar_path}/mask"
        self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl"
        self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
        self.avatar_info = {"avatar_id": self.avatar_id, "video_path": self.video_path, "bbox_shift": self.bbox_shift}

        self.init(vae_model=self.vae, face_parsing_model=self.face_parsing_model)
|
    def init(self, vae_model, face_parsing_model):
        """Decide whether cached materials can be reused; (re)generate them if not, then load them."""
        need_to_prepare = False

        if self.preparation_force and os.path.exists(self.avatar_path):
            # Forced regeneration: drop the stale cache.
            shutil.rmtree(self.avatar_path)
            need_to_prepare = True
        elif not os.path.exists(self.avatar_path):
            # First run for this avatar.
            need_to_prepare = True
        else:
            # Cache exists: rebuild if it was generated with a different bbox_shift.
            if not os.path.exists(self.avatar_info_path):
                need_to_prepare = True
                shutil.rmtree(self.avatar_path)
            else:
                with open(self.avatar_info_path, "r") as f:
                    avatar_info = json.load(f)
                if avatar_info["bbox_shift"] != self.avatar_info["bbox_shift"]:
                    need_to_prepare = True
                    shutil.rmtree(self.avatar_path)

        if not need_to_prepare:
            # Make sure every cached artifact is actually present.
            for prepare_file in [
                self.full_imgs_path,
                self.coords_path,
                self.latents_out_path,
                self.mask_out_path,
                self.mask_coords_path,
                self.avatar_info_path,
            ]:
                if not os.path.exists(prepare_file):
                    logger.info(f"Missing file {prepare_file}, will run preparation...")
                    need_to_prepare = True
                    shutil.rmtree(self.avatar_path)
                    break

        if need_to_prepare:
            logger.info("*********************************")
            logger.info(f"  creating avatar: {self.avatar_id}")
            logger.info("*********************************")
            osmakedirs([self.avatar_path, self.full_imgs_path, self.mask_out_path])
            self.prepare_material(vae_model=vae_model, face_parsing_model=face_parsing_model)

        # Load the cached materials into memory.
        self.input_latent_list_cycle = torch.load(self.latents_out_path)
        with open(self.coords_path, "rb") as f:
            self.coord_list_cycle = pickle.load(f)
        input_img_list = glob.glob(os.path.join(self.full_imgs_path, "*.[jpJP][pnPN]*[gG]"))
        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        self.frame_list_cycle = read_imgs(input_img_list)
        with open(self.mask_coords_path, "rb") as f:
            self.mask_coords_list_cycle = pickle.load(f)
        input_mask_list = glob.glob(os.path.join(self.mask_out_path, "*.[jpJP][pnPN]*[gG]"))
        input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        self.mask_list_cycle = read_imgs(input_mask_list)
|
    def prepare_material(self, vae_model, face_parsing_model):
        logger.info("preparing data materials ... ...")
        with open(self.avatar_info_path, "w") as f:
            json.dump(self.avatar_info, f)

        if os.path.isfile(self.video_path):
            video2imgs(self.video_path, self.full_imgs_path)
        else:
            # video_path is a directory of pre-extracted png frames.
            logger.info(f"copy files in {self.video_path}")
            files = os.listdir(self.video_path)
            files.sort()
            files = [file for file in files if file.split(".")[-1] == "png"]
            for filename in files:
                shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}")
        input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, "*.[jpJP][pnPN]*[gG]")))

        logger.info("extracting landmarks...")
        pose_model = load_pose_model(self.model_dir)
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, pose_model, self.bbox_shift)
        # The pose model is only needed here; free its GPU memory immediately.
        del pose_model
        torch.cuda.empty_cache()

        # Encode each detected face crop to VAE latents for the UNet.
        input_latent_list = []
        coord_placeholder = (0.0, 0.0, 0.0, 0.0)
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                # No face detected in this frame.
                continue
            x1, y1, x2, y2 = bbox
            crop_frame = frame[y1:y2, x1:x2]
            resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            latents = vae_model.get_latents_for_unet(resized_crop_frame)
            input_latent_list.append(latents)

        # Append the reversed sequence so the cycle loops back and forth seamlessly.
        self.frame_list_cycle = frame_list + frame_list[::-1]
        self.coord_list_cycle = coord_list + coord_list[::-1]
        self.input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        self.mask_coords_list_cycle = []
        self.mask_list_cycle = []

        for i, frame in enumerate(tqdm(self.frame_list_cycle)):
            cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png", frame)

            face_box = self.coord_list_cycle[i]
            mask, crop_box = get_image_prepare_material(frame, face_box, face_parsing_model)
            cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png", mask)
            self.mask_coords_list_cycle.append(crop_box)
            self.mask_list_cycle.append(mask)

        with open(self.mask_coords_path, "wb") as f:
            pickle.dump(self.mask_coords_list_cycle, f)

        with open(self.coords_path, "wb") as f:
            pickle.dump(self.coord_list_cycle, f)

        torch.save(self.input_latent_list_cycle, self.latents_out_path)
|
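    # Consumer side of the producer/consumer pipeline: the GPU loop in
    # inference() pushes decoded face crops into res_frame_queue, while this
    # thread blends them back into the full frames and writes PNGs to disk.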
    def process_frames(self, res_frame_queue, video_len, skip_save_images, save_dir_name):
        logger.info(video_len)
        while True:
            if self.idx >= video_len - 1:
                break
            try:
                res_frame = res_frame_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            # Paste the generated face back into the corresponding source frame.
            bbox = self.coord_list_cycle[self.idx % (len(self.coord_list_cycle))]
            ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx % (len(self.frame_list_cycle))])
            x1, y1, x2, y2 = bbox
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception:
                # Invalid bbox (e.g. zero-sized); skip this frame.
                continue
            mask = self.mask_list_cycle[self.idx % (len(self.mask_list_cycle))]
            mask_crop_box = self.mask_coords_list_cycle[self.idx % (len(self.mask_coords_list_cycle))]

            combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)

            if not skip_save_images:
                cv2.imwrite(f"{self.avatar_path}/{save_dir_name}/{str(self.idx).zfill(8)}.png", combine_frame)
            self.idx = self.idx + 1
|
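    # inference() pipeline: audio -> Whisper features -> per-frame chunks ->
    # batched UNet lip-sync generation -> queue -> blended PNG frames ->
    # ffmpeg encodes the frames and muxes in the driving audio.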
    @torch.no_grad()
    def inference(self, audio_path, output_vid, fps, skip_save_images=False):
        """Generate a lip-synced video for audio_path and write it to output_vid."""
        tmp_tag = "tmp_" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        os.makedirs(self.avatar_path + f"/{tmp_tag}", exist_ok=True)
        logger.info("start inference")

        # Extract Whisper features and split them into per-frame chunks.
        start_time = time.time()
        whisper_feature = self.audio_processor.audio2feat(audio_path)
        whisper_chunks = self.audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)
        logger.info(f"processing audio: {audio_path} costs {(time.time() - start_time) * 1000:.2f}ms")

        video_num = len(whisper_chunks)
        res_frame_queue = queue.Queue()
        self.idx = 0

        # Frame blending and saving runs in a consumer thread while the GPU generates.
        process_thread = threading.Thread(
            target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images, tmp_tag)
        )
        process_thread.start()

        gen = datagen(whisper_chunks, self.input_latent_list_cycle, self.batch_size)
        start_time = time.time()

        for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))):
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=self.unet.device, dtype=self.unet.model.dtype)
            audio_feature_batch = self.pe(audio_feature_batch)
            latent_batch = latent_batch.to(dtype=self.unet.model.dtype)

            # Single denoising step conditioned on the audio features.
            timesteps = torch.tensor([0], device="cuda")
            pred_latents = self.unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
            recon = self.vae.decode_latents(pred_latents)
            for res_frame in recon:
                res_frame_queue.put(res_frame)

        logger.info("waiting for all queue...")
        process_thread.join()

        logger.info("Total process time of {} frames including saving images = {}s".format(video_num, time.time() - start_time))

        # Encode the PNG frames into an H.264 video.
        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/{tmp_tag}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/{tmp_tag}.mp4"
        logger.info(cmd_img2video)
        os.system(cmd_img2video)

        # Mux the driving audio into the final output file.
        cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {self.avatar_path}/{tmp_tag}.mp4 {output_vid}"
        logger.info(cmd_combine_audio)
        os.system(cmd_combine_audio)

        logger.info("Remove tmp files...")
        os.remove(f"{self.avatar_path}/{tmp_tag}.mp4")
        shutil.rmtree(f"{self.avatar_path}/{tmp_tag}")

        # Write an empty .txt file alongside the video as a completion marker.
        with open(Path(output_vid).with_suffix(".txt"), "w") as f:
            f.write("")

        logger.info(f"result is saved to {output_vid}")
        return str(output_vid)
|
|
def digital_human_preprocess(model_dir, use_float16, video_path, work_dir, fps, bbox_shift):
    # Make sure ffmpeg is on PATH before any video is rendered.
    setup_ffmpeg_env(model_dir)

    # Note: use_float16 is accepted for API symmetry; Avatar currently
    # manages model precision itself (see Avatar.__init__).
    avatar = Avatar(
        avatar_id="1",
        work_dir=work_dir,
        model_dir=model_dir,
        video_path=video_path,
        bbox_shift=bbox_shift,
        batch_size=32,
        fps=fps,
        preparation_force=False,
    )

    return avatar
|
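# A minimal usage sketch (paths here are illustrative, not shipped defaults):
#   avatar = digital_human_preprocess(model_dir="./weights", use_float16=False,
#                                     video_path="./data/template_video.mp4",
#                                     work_dir="./work_dirs/digital_human",
#                                     fps=25, bbox_shift=5)
#   gen_digital_human_video(avatar, stream_id="1", audio_path="./tts.wav",
#                           work_dir="./work_dirs/digital_human",
#                           video_path="out.mp4", fps=25)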
|
@torch.no_grad()
def gen_digital_human_video(
    avatar_handler: Avatar,
    stream_id,
    audio_path,
    work_dir,
    video_path,
    fps,
):
    if not Path(work_dir).exists():
        Path(work_dir).mkdir(exist_ok=True, parents=True)

    # Switch avatar if the request targets a different character.
    if avatar_handler.avatar_id != str(stream_id):
        logger.info(f"Change digital human avatar from {avatar_handler.avatar_id} to {stream_id}")
        avatar_handler.change_character(str(stream_id))

    output_vid_file_path = Path(work_dir).joinpath(video_path)
    output_vid = avatar_handler.inference(
        audio_path=audio_path,
        output_vid=str(output_vid_file_path),
        fps=fps,
        skip_save_images=False,
    )

    return output_vid
|
|
@torch.no_grad()
def gen_digital_human_preprocess(avatar_handler: Avatar, stream_id, work_dir, video_path):
    """Switch to a new digital human character and run its preprocessing."""
    if not Path(work_dir).exists():
        Path(work_dir).mkdir(exist_ok=True, parents=True)

    old_id = avatar_handler.avatar_id
    old_video_path = avatar_handler.video_path
    avatar_handler.preparation_force = True

    # Build the cache for the new character.
    logger.info(f"Processing for id: {stream_id}")
    avatar_handler.change_character(str(stream_id), video_path)

    # Restore the previously active character.
    avatar_handler.preparation_force = False
    avatar_handler.change_character(old_id, old_video_path)

    return True
|
|
if WEB_CONFIGS.ENABLE_DIGITAL_HUMAN:
    DIGITAL_HUMAN_HANDLER = digital_human_preprocess(
        model_dir=WEB_CONFIGS.DIGITAL_HUMAN_MODEL_DIR,
        use_float16=False,
        video_path=WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_PATH,
        work_dir=WEB_CONFIGS.DIGITAL_HUMAN_GEN_PATH,
        fps=WEB_CONFIGS.DIGITAL_HUMAN_FPS,
        bbox_shift=WEB_CONFIGS.DIGITAL_HUMAN_BBOX_SHIFT,
    )
else:
    DIGITAL_HUMAN_HANDLER = None
|
|
if __name__ == "__main__":
    # Standalone smoke test; adjust the paths below to your local setup.
    data_preparation = False
    video_path = WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_PATH
    bbox_shift = 5
    avatar = Avatar(
        avatar_id="lelemiao",
        work_dir=WEB_CONFIGS.DIGITAL_HUMAN_GEN_PATH,
        model_dir=WEB_CONFIGS.DIGITAL_HUMAN_MODEL_DIR,
        video_path=video_path,
        bbox_shift=bbox_shift,
        batch_size=4,
        fps=WEB_CONFIGS.DIGITAL_HUMAN_FPS,
        preparation_force=data_preparation,
    )

    avatar.inference(
        audio_path=r"./work_dirs/tts_wavs/2024-06-05-20-48-53.wav",
        output_vid="avatar_1.mp4",
        fps=25,
        skip_save_images=False,
    )
|
|