import spaces import os # os.system("Xvfb :99 -ac &") # os.environ["DISPLAY"] = ":99" import OpenGL.GL as gl os.environ["PYOPENGL_PLATFORM"] = "egl" os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1" import gradio as gr import torch import numpy as np import soundfile as sf import librosa from torchvision.io import write_video from emage_utils.motion_io import beat_format_save from emage_utils import fast_render from emage_utils.npz2pose import render2d from models.camn_audio import CamnAudioModel from models.disco_audio import DiscoAudioModel from models.emage_audio import EmageAudioModel, EmageVQVAEConv, EmageVAEConv, EmageVQModel import torch.nn.functional as F device = torch.device("cuda" if torch.cuda.is_available() else "cpu") save_folder = "./gradio_results" os.makedirs(save_folder, exist_ok=True) print(device) if not os.path.exists("./emage_evaltools/smplx_models"): import subprocess subprocess.run(["git", "clone", "https://huggingface.co/H-Liu1997/emage_evaltools"]) model_camn = CamnAudioModel.from_pretrained("H-Liu1997/camn_audio").to(device).eval() model_disco = DiscoAudioModel.from_pretrained("H-Liu1997/disco_audio").to(device).eval() face_motion_vq = EmageVQVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/face").to(device).eval() upper_motion_vq = EmageVQVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/upper").to(device).eval() lower_motion_vq = EmageVQVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/lower").to(device).eval() hands_motion_vq = EmageVQVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/hands").to(device).eval() global_motion_ae = EmageVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/global").to(device).eval() emage_vq_model = EmageVQModel( face_model=face_motion_vq, upper_model=upper_motion_vq, lower_model=lower_motion_vq, hands_model=hands_motion_vq, global_model=global_motion_ae ).to(device).eval() model_emage = EmageAudioModel.from_pretrained("H-Liu1997/emage_audio").to(device).eval() def inference_camn(audio_path, sr_model, pose_fps, seed_frames): audio_loaded, _ = librosa.load(audio_path, sr=sr_model) audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device) sid = torch.zeros(1, 1).long().to(device) with torch.no_grad(): motion_pred = model_camn(audio_t, sid, seed_frames=seed_frames)["motion_axis_angle"] t = motion_pred.shape[1] motion_pred = motion_pred.cpu().numpy().reshape(t, -1) npz_path = os.path.join(save_folder, "camn_output.npz") beat_format_save(npz_path, motion_pred, upsample=30 // pose_fps) return npz_path def inference_disco(audio_path, sr_model, pose_fps, seed_frames): audio_loaded, _ = librosa.load(audio_path, sr=sr_model) audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device) sid = torch.zeros(1, 1).long().to(device) with torch.no_grad(): motion_pred = model_disco(audio_t, sid, seed_frames=seed_frames, seed_motion=None)["motion_axis_angle"] t = motion_pred.shape[1] motion_pred = motion_pred.cpu().numpy().reshape(t, -1) npz_path = os.path.join(save_folder, "disco_output.npz") beat_format_save(npz_path, motion_pred, upsample=30 // pose_fps) return npz_path def inference_emage(audio_path, sr_model, pose_fps): audio_loaded, _ = librosa.load(audio_path, sr=sr_model) audio_t = torch.from_numpy(audio_loaded).float().unsqueeze(0).to(device) sid = torch.zeros(1, 1).long().to(device) with torch.no_grad(): latent_dict = model_emage.inference(audio_t, sid, emage_vq_model, masked_motion=None, mask=None) face_latent = latent_dict["rec_face"] if model_emage.cfg.lf > 0 and model_emage.cfg.cf == 0 else None upper_latent = latent_dict["rec_upper"] if model_emage.cfg.lu > 0 and model_emage.cfg.cu == 0 else None hands_latent = latent_dict["rec_hands"] if model_emage.cfg.lh > 0 and model_emage.cfg.ch == 0 else None lower_latent = latent_dict["rec_lower"] if model_emage.cfg.ll > 0 and model_emage.cfg.cl == 0 else None face_index = torch.max(F.log_softmax(latent_dict["cls_face"], dim=2), dim=2)[1] if model_emage.cfg.cf > 0 else None upper_index = torch.max(F.log_softmax(latent_dict["cls_upper"], dim=2), dim=2)[1] if model_emage.cfg.cu > 0 else None hands_index = torch.max(F.log_softmax(latent_dict["cls_hands"], dim=2), dim=2)[1] if model_emage.cfg.ch > 0 else None lower_index = torch.max(F.log_softmax(latent_dict["cls_lower"], dim=2), dim=2)[1] if model_emage.cfg.cl > 0 else None ref_trans = torch.zeros(1, 1, 3).to(device) all_pred = emage_vq_model.decode( face_latent=face_latent, upper_latent=upper_latent, lower_latent=lower_latent, hands_latent=hands_latent, face_index=face_index, upper_index=upper_index, lower_index=lower_index, hands_index=hands_index, get_global_motion=True, ref_trans=ref_trans[:, 0] ) motion_pred = all_pred["motion_axis_angle"] t = motion_pred.shape[1] motion_pred = motion_pred.cpu().numpy().reshape(t, -1) face_pred = all_pred["expression"].cpu().numpy().reshape(t, -1) trans_pred = all_pred["trans"].cpu().numpy().reshape(t, -1) npz_path = os.path.join(save_folder, "emage_output.npz") beat_format_save(npz_path, motion_pred, upsample=30 // pose_fps, expressions=face_pred, trans=trans_pred) return npz_path def inference_app(audio, model_type, render_mesh=False, render_face=False, render_mesh_face=False): if audio is None: return [None, None, None, None, None] sr_in, audio_data = audio # --- TRUNCATE to 60 seconds if longer --- max_len = int(60 * sr_in) if len(audio_data) > max_len: audio_data = audio_data[:max_len] # ---------------------------------------- tmp_audio_path = os.path.join(save_folder, "tmp_input.wav") sf.write(tmp_audio_path, audio_data, sr_in) if model_type == "CaMN (Upper only)": sr_model, pose_fps, seed_frames = model_camn.cfg.audio_sr, model_camn.cfg.pose_fps, model_camn.cfg.seed_frames npz_path = inference_camn(tmp_audio_path, sr_model, pose_fps, seed_frames) elif model_type == "DisCo (Upper only)": sr_model, pose_fps, seed_frames = model_disco.cfg.audio_sr, model_disco.cfg.pose_fps, model_disco.cfg.seed_frames npz_path = inference_disco(tmp_audio_path, sr_model, pose_fps, seed_frames) else: sr_model, pose_fps = model_emage.cfg.audio_sr, model_emage.cfg.pose_fps npz_path = inference_emage(tmp_audio_path, sr_model, pose_fps) motion_dict = np.load(npz_path, allow_pickle=True) v2d_body = render2d(motion_dict, (720, 480), face_only=False, remove_global=True) out_2d_body = npz_path.replace(".npz", "_2dbody.mp4") write_video(out_2d_body, v2d_body.permute(0, 2, 3, 1), fps=30) final_2d_body = out_2d_body.replace(".mp4", "_audio.mp4") fast_render.add_audio_to_video(out_2d_body, tmp_audio_path, final_2d_body) final_mesh_video = None final_meshface_video = None if render_mesh: mesh_vid = fast_render.render_one_sequence_no_gt( npz_path, save_folder, tmp_audio_path, "./emage_evaltools/smplx_models/" ) final_mesh_video = mesh_vid if render_mesh_face and render_mesh: meshface_vid = fast_render.render_one_sequence_face_only( npz_path, save_folder, tmp_audio_path, "./emage_evaltools/smplx_models/" ) final_meshface_video = meshface_vid final_face_video = None if render_face: v2d_face = render2d(motion_dict, (720, 480), face_only=True, remove_global=True) out_2d_face = npz_path.replace(".npz", "_2dface.mp4") write_video(out_2d_face, v2d_face.permute(0, 2, 3, 1), fps=30) final_face_video = out_2d_face.replace(".mp4", "_audio.mp4") fast_render.add_audio_to_video(out_2d_face, tmp_audio_path, final_face_video) return [final_2d_body, final_mesh_video, final_face_video, final_meshface_video, npz_path] examples_data = [ ["./examples/audio/2_scott_0_103_103_10s.wav", "DisCo (Upper only)", True, True, True], ["./examples/audio/2_scott_0_103_103_10s.wav", "CaMN (Upper only)", True, True, True], ["./examples/audio/2_scott_0_103_103_10s.wav", "EMAGE (Full body + Face)", True, True, True], ] with gr.Blocks() as demo: with gr.Column(): gr.Markdown( """
""" ) with gr.Row(): input_audio = gr.Audio(type="numpy", label="Upload Audio") with gr.Column(): model_type = gr.Radio( choices=["DisCo (Upper only)", "CaMN (Upper only)", "EMAGE (Full body + Face)"], value="CaMN (Upper only)", label="Select Model: DisCo/CaMN for Upper, EMAGE for Full Body+Face" ) render_face = gr.Checkbox(value=False, label="Render 2D Face Landmark (Fast ~4s for 7s)") render_mesh = gr.Checkbox(value=False, label="Render Mesh Body (Slow ~1min for 7s)") render_mesh_face = gr.Checkbox(value=False, label="Render Mesh Face (Extra Slow)") btn = gr.Button("Run Inference") with gr.Row(): vid_body = gr.Video(label="2D Body Video") vid_mesh = gr.Video(label="Mesh Body Video (optional)") vid_face = gr.Video(label="2D Face Video (optional)") vid_meshface = gr.Video(label="Mesh Face Video (optional)") with gr.Column(): gr.Markdown("Download Motion NPZ, Use Our [Blender Add-on](https://huggingface.co/datasets/H-Liu1997/BEAT2_Tools/blob/main/smplx_blender_addon_20230921.zip) for Visualization. [Demo](https://github.com/PantoMatrix/PantoMatrix/issues/178) of how to install on blender.") file_npz = gr.File(label="Motion NPZ") btn.click( fn=inference_app, inputs=[input_audio, model_type, render_mesh, render_face, render_mesh_face], outputs=[vid_body, vid_mesh, vid_face, vid_meshface, file_npz] ) gr.Examples( examples=examples_data, inputs=[input_audio, model_type, render_mesh, render_face, render_mesh_face], outputs=[vid_body, vid_mesh, vid_face, vid_meshface, file_npz], fn=inference_app, cache_examples=True ) if __name__ == "__main__": demo.launch(share=True)