import os import random from datetime import datetime import gradio as gr import numpy as np import torch from diffusers import AutoencoderKL, DDIMScheduler from einops import repeat from huggingface_hub import hf_hub_download, snapshot_download from omegaconf import OmegaConf from PIL import Image from torchvision import transforms from transformers import CLIPVisionModelWithProjection from src.models.pose_guider import PoseGuider from src.models.unet_2d_condition import UNet2DConditionModel from src.models.unet_3d import UNet3DConditionModel from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline from src.utils.download_models import prepare_base_model, prepare_image_encoder from src.utils.util import get_fps, read_frames, save_videos_grid # Partial download prepare_base_model() prepare_image_encoder() snapshot_download( repo_id="stabilityai/sd-vae-ft-mse", local_dir="./pretrained_weights/sd-vae-ft-mse" ) snapshot_download( repo_id="patrolli/AnimateAnyone", local_dir="./pretrained_weights", ) class AnimateController: def __init__( self, config_path="./configs/prompts/animation.yaml", weight_dtype=torch.float16, ): # Read pretrained weights path from config self.config = OmegaConf.load(config_path) self.pipeline = None self.weight_dtype = weight_dtype def animate( self, ref_image, pose_video_path, width=512, height=768, length=24, num_inference_steps=25, cfg=3.5, seed=123, ): generator = torch.manual_seed(seed) if isinstance(ref_image, np.ndarray): ref_image = Image.fromarray(ref_image) if self.pipeline is None: vae = AutoencoderKL.from_pretrained( self.config.pretrained_vae_path, ).to("cuda", dtype=self.weight_dtype) reference_unet = UNet2DConditionModel.from_pretrained( self.config.pretrained_base_model_path, subfolder="unet", ).to(dtype=self.weight_dtype, device="cuda") inference_config_path = self.config.inference_config infer_config = OmegaConf.load(inference_config_path) denoising_unet = UNet3DConditionModel.from_pretrained_2d( self.config.pretrained_base_model_path, self.config.motion_module_path, subfolder="unet", unet_additional_kwargs=infer_config.unet_additional_kwargs, ).to(dtype=self.weight_dtype, device="cuda") pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to( dtype=self.weight_dtype, device="cuda" ) image_enc = CLIPVisionModelWithProjection.from_pretrained( self.config.image_encoder_path ).to(dtype=self.weight_dtype, device="cuda") sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) scheduler = DDIMScheduler(**sched_kwargs) # load pretrained weights denoising_unet.load_state_dict( torch.load(self.config.denoising_unet_path, map_location="cpu"), strict=False, ) reference_unet.load_state_dict( torch.load(self.config.reference_unet_path, map_location="cpu"), ) pose_guider.load_state_dict( torch.load(self.config.pose_guider_path, map_location="cpu"), ) pipe = Pose2VideoPipeline( vae=vae, image_encoder=image_enc, reference_unet=reference_unet, denoising_unet=denoising_unet, pose_guider=pose_guider, scheduler=scheduler, ) pipe = pipe.to("cuda", dtype=self.weight_dtype) self.pipeline = pipe pose_images = read_frames(pose_video_path) src_fps = get_fps(pose_video_path) pose_list = [] pose_tensor_list = [] pose_transform = transforms.Compose( [transforms.Resize((height, width)), transforms.ToTensor()] ) for pose_image_pil in pose_images[:length]: pose_list.append(pose_image_pil) pose_tensor_list.append(pose_transform(pose_image_pil)) video = self.pipeline( ref_image, pose_list, width=width, height=height, video_length=length, num_inference_steps=num_inference_steps, guidance_scale=cfg, generator=generator, ).videos ref_image_tensor = pose_transform(ref_image) # (c, h, w) ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w) ref_image_tensor = repeat( ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=length ) pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) pose_tensor = pose_tensor.transpose(0, 1) pose_tensor = pose_tensor.unsqueeze(0) video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0) save_dir = f"./output/gradio" if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) date_str = datetime.now().strftime("%Y%m%d") time_str = datetime.now().strftime("%H%M") out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4") save_videos_grid( video, out_path, n_rows=3, fps=src_fps, ) torch.cuda.empty_cache() return out_path controller = AnimateController() def ui(): with gr.Blocks() as demo: gr.HTML( """
This is a quick preview demo of Moore-AnimateAnyone. We appreciate the assistance provided by the HuggingFace team in setting up this demo.
If you like this project, please consider giving a star on our GitHub repo 🤗.