"""
Flask API that wraps the Hallo talking-head inference pipeline: clients upload a
source image and a driving audio file to /run and receive the path of the
generated video.
"""

import argparse
import os
import tempfile
import uuid

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from flask import Flask, request, jsonify
from omegaconf import OmegaConf
from torch import nn

from hallo.animate.face_animate import FaceAnimatePipeline
from hallo.datasets.audio_processor import AudioProcessor
from hallo.datasets.image_processor import ImageProcessor
from hallo.models.audio_proj import AudioProjModel
from hallo.models.face_locator import FaceLocator
from hallo.models.image_proj import ImageProjModel
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.config import filter_non_none
from hallo.utils.util import tensor_to_video

app = Flask(__name__)
TEMP_DIR = None


class Net(nn.Module):
    """
    The Net class combines all the modules required for inference.

    Args:
        reference_unet (UNet2DConditionModel): UNet that encodes the reference image features.
        denoising_unet (UNet3DConditionModel): UNet that denoises the video latents, conditioned on the audio.
        face_locator (FaceLocator): Model that locates the face region in the source image.
        imageproj (nn.Module): Projects the image embedding into cross-attention context tokens.
        audioproj (nn.Module): Projects the audio embeddings into cross-attention context tokens.
    """

    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        imageproj,
        audioproj,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.imageproj = imageproj
        self.audioproj = audioproj

    def forward(self):
        """
        Empty method to override the abstract forward of nn.Module.
        """

    def get_modules(self):
        """
        Simple method to avoid the too-few-public-methods pylint error.
        """
        return {
            "reference_unet": self.reference_unet,
            "denoising_unet": self.denoising_unet,
            "face_locator": self.face_locator,
            "imageproj": self.imageproj,
            "audioproj": self.audioproj,
        }


class AnimationConfig:
    """
    Minimal stand-in for the argparse.Namespace that inference_process expects.
    """

    def __init__(self, driving_audio_path, source_image_path, result_folder):
        self.driven_audio = driving_audio_path
        self.source_image = source_image_path
        self.output = result_folder
        self.config = 'configs/inference/default.yaml'


def process_audio_emb(audio_emb):
    """
    Stack each audio-embedding frame with a window of its neighbouring frames.

    Parameters:
        audio_emb (torch.Tensor): The audio embedding tensor to process.

    Returns:
        audio_emb (torch.Tensor): Tensor in which every position holds a stack of
        5 consecutive frames (the frame itself plus two neighbours on each side).
    """
    concatenated_tensors = []

    for i in range(audio_emb.shape[0]):
        vectors_to_concat = [
            audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))

    audio_emb = torch.stack(concatenated_tensors, dim=0)

    return audio_emb

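# Illustration of process_audio_emb (the shapes below are hypothetical, chosen only
# to show the windowing; they are not taken from the actual wav2vec output): each
# position i is stacked with its neighbours i-2 .. i+2, clamped at the sequence
# boundaries, so a tensor of shape (L, ...) becomes (L, 5, ...).
#
#     emb = torch.randn(8, 12, 768)        # hypothetical (frames, blocks, channels)
#     windowed = process_audio_emb(emb)
#     assert windowed.shape == (8, 5, 12, 768)
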
def inference_process(args: argparse.Namespace):
    """
    Perform inference processing.

    Args:
        args (argparse.Namespace): Command-line arguments.

    This function initializes the configuration for the inference process. It sets up
    the necessary modules and variables to prepare for the upcoming inference steps.
    """
    # 1. init config
    cli_args = filter_non_none(vars(args))
    config = OmegaConf.load(args.config)
    config = OmegaConf.merge(config, cli_args)
    source_image_path = args.source_image
    driving_audio_path = args.driven_audio
    save_path = config.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]

    # 2. runtime variables
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif config.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif config.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        weight_dtype = torch.float32

    # 3. prepare inference data
    # 3.1 prepare source image, face mask, face embeddings
    img_size = (config.data.source_image.width,
                config.data.source_image.height)
    clip_length = config.data.n_sample_frames
    face_analysis_model_path = config.face_analysis.model_path
    with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
        source_image_pixels, \
            source_image_face_region, \
            source_image_face_emb, \
            source_image_full_mask, \
            source_image_face_mask, \
            source_image_lip_mask = image_processor.preprocess(
                source_image_path, save_path, config.face_expand_ratio)

    # 3.2 prepare audio embeddings
    sample_rate = config.data.driving_audio.sample_rate
    assert sample_rate == 16000, "audio sample rate must be 16000"
    fps = config.data.export_video.fps
    wav2vec_model_path = config.wav2vec.model_path
    wav2vec_only_last_features = config.wav2vec.features == "last"
    audio_separator_model_file = config.audio_separator.model_path
    with AudioProcessor(
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_only_last_features,
        os.path.dirname(audio_separator_model_file),
        os.path.basename(audio_separator_model_file),
        os.path.join(save_path, "audio_preprocess")
    ) as audio_processor:
        audio_emb, audio_length = audio_processor.preprocess(
            driving_audio_path, clip_length)

    # 4. build modules
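    # The 3D denoising UNet below is inflated from the 2D base model weights and
    # extended with the motion module; the DDIM scheduler uses zero-SNR rescaling
    # with v-prediction when `enable_zero_snr` is set in the config.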
    sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
    if config.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})

    vae = AutoencoderKL.from_pretrained(config.vae.model_path)
    reference_unet = UNet2DConditionModel.from_pretrained(
        config.base_model_path, subfolder="unet")
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            config.unet_additional_kwargs),
        use_landmark=False,
    )
    face_locator = FaceLocator(conditioning_embedding_channels=320)
    image_proj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    )

    audio_proj = AudioProjModel(
        seq_len=5,
        blocks=12,  # use 12 layers' hidden states of wav2vec
        channels=768,  # audio embedding channel
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    ).to(device=device, dtype=weight_dtype)

    audio_ckpt_dir = config.audio_ckpt_dir

    # Freeze
    vae.requires_grad_(False)
    image_proj.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    face_locator.requires_grad_(False)
    audio_proj.requires_grad_(False)

    reference_unet.enable_gradient_checkpointing()
    denoising_unet.enable_gradient_checkpointing()

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        image_proj,
        audio_proj,
    )

    m, u = net.load_state_dict(
        torch.load(
            os.path.join(audio_ckpt_dir, "net.pth"),
            map_location="cpu",
        ),
    )
    assert len(m) == 0 and len(u) == 0, "Failed to load the checkpoint correctly."
    print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))

    # 5. inference
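    # Generation runs clip by clip: the first clip is conditioned on repeated copies
    # of the reference image, and every later clip reuses the last n_motion_frames
    # frames of the previous result so that motion stays continuous across clips.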
    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=net.reference_unet,
        denoising_unet=net.denoising_unet,
        face_locator=net.face_locator,
        scheduler=val_noise_scheduler,
        image_proj=net.imageproj,
    )
    pipeline.to(device=device, dtype=weight_dtype)

    audio_emb = process_audio_emb(audio_emb)

    source_image_pixels = source_image_pixels.unsqueeze(0)
    source_image_face_region = source_image_face_region.unsqueeze(0)
    source_image_face_emb = source_image_face_emb.reshape(1, -1)
    source_image_face_emb = torch.tensor(source_image_face_emb)

    source_image_full_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_full_mask
    ]
    source_image_face_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_face_mask
    ]
    source_image_lip_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_lip_mask
    ]

    times = audio_emb.shape[0] // clip_length
    tensor_result = []
    generator = torch.manual_seed(42)
    for t in range(times):
        print(f"[{t + 1}/{times}]")

        if len(tensor_result) == 0:
            # The first iteration
            motion_zeros = source_image_pixels.repeat(
                config.data.n_motion_frames, 1, 1, 1)
            motion_zeros = motion_zeros.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_zeros], dim=0)  # concat the ref image and the first motion frames
        else:
            motion_frames = tensor_result[-1][0]
            motion_frames = motion_frames.permute(1, 0, 2, 3)
            motion_frames = motion_frames[0 - config.data.n_motion_frames:]
            motion_frames = motion_frames * 2.0 - 1.0
            motion_frames = motion_frames.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_frames], dim=0)  # concat the ref image and the motion frames

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

        audio_tensor = audio_emb[
            t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
        ]
        audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.to(
            device=net.audioproj.device, dtype=net.audioproj.dtype)
        audio_tensor = net.audioproj(audio_tensor)

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=img_size[0],
            height=img_size[1],
            video_length=clip_length,
            num_inference_steps=config.inference_steps,
            guidance_scale=config.cfg_scale,
            generator=generator,
            motion_scale=motion_scale,
        )
        tensor_result.append(pipeline_output.videos)

    tensor_result = torch.cat(tensor_result, dim=2)
    tensor_result = tensor_result.squeeze(0)
    tensor_result = tensor_result[:, :audio_length]
    output_file = config.output
    # save the result after all iterations
    tensor_to_video(tensor_result, output_file, driving_audio_path)
    return output_file


def create_temp_dir():
    return tempfile.TemporaryDirectory()


def save_uploaded_file(file, filename, TEMP_DIR):
    unique_filename = str(uuid.uuid4()) + "_" + filename
    file_path = os.path.join(TEMP_DIR.name, unique_filename)
    file.save(file_path)
    return file_path


@app.route('/run', methods=['POST'])
def generate_video():
    global TEMP_DIR
    TEMP_DIR = create_temp_dir()

    if request.method == 'POST':
        source_image = request.files['source_image']
        # text_prompt = request.form['text_prompt']

        # print('Input text prompt: ', text_prompt)
        # text_prompt = text_prompt.strip()
        # if not text_prompt:
        #     return jsonify({'error': 'Input text prompt cannot be blank'}), 400

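        # Persist both uploads into the request-scoped temporary directory so that
        # inference_process can read them from disk.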
        driving_audio = request.files['driving_audio']

        source_image_path = save_uploaded_file(
            source_image, 'source_image.png', TEMP_DIR)
        print(source_image_path)

        driving_audio_path = save_uploaded_file(
            driving_audio, 'driving_audio.wav', TEMP_DIR)
        print(driving_audio_path)

        result_folder = TEMP_DIR.name

        args = AnimationConfig(
            driving_audio_path=driving_audio_path,
            source_image_path=source_image_path,
            result_folder=result_folder)

        try:
            # Run the inference process
            output_file = inference_process(args)
            return jsonify({"message": "Inference completed successfully",
                            "output_file": os.path.abspath(output_file)})
        except Exception as e:
            return jsonify({"error": "Inference failed", "details": str(e)}), 500


@app.route("/health", methods=["GET"])
def health_status():
    response = {"online": "true"}
    return jsonify(response)


if __name__ == '__main__':
    app.run(debug=True)
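
# Example client call (a sketch, assuming the development server is reachable on
# Flask's default port 5000 and that the `requests` package is installed; the file
# names are placeholders):
#
#     import requests
#
#     with open("face.png", "rb") as img, open("speech.wav", "rb") as wav:
#         resp = requests.post(
#             "http://127.0.0.1:5000/run",
#             files={"source_image": img, "driving_audio": wav},
#         )
#     print(resp.json())  # {"message": ..., "output_file": ...} on success,
#                         # or {"error": ..., "details": ...} with status 500 on failure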