import errno
import gc
import os
import sys
from pathlib import Path

import torch

# from .s2f_dir.src.speech_encoder.WavLM import WavLM, WavLMConfig
from transformers import Wav2Vec2FeatureExtractor, WavLMModel

from .s2f_dir.src import autoencoder as ae
from .util import *

# True once fix_seed() has been called, so the seed is only fixed on the first
# create_model() call in this process.
g_fix_seed = False
# WavLM feature extractor / encoder. Loaded lazily by create_model() for
# "stf_v3" models and then shared by every ModelInfo in this process.
g_audio_processor = None
g_audio_encoder = None

# model_type assumed when a config leaves it empty (the setting was added to
# the config format later, so older configs may not set it).
_DEFAULT_MODEL_TYPE = "stf_v1"


class ModelInfo:
    """Bundle a loaded model with the objects and paths used to create it.

    Attributes:
        model: the Speech2Face model (possibly wrapped in DataParallel), or
            None for "remote" configs.
        audio_processor / audio_encoder: the shared WavLM objects for
            "stf_v3" models; otherwise whatever the module globals held
            (None if never loaded).
        args: the config object returned by read_config().
        device: the device string passed to create_model().
        work_root_path, config_path, checkpoint_path: kept for debugging.
        verbose: print diagnostics when True.
    """

    def __init__(
        self,
        model,
        audio_processor,
        audio_encoder,
        args,
        device,
        work_root_path,
        config_path,
        checkpoint_path,
        verbose=False,
    ):
        self.model = model
        self.audio_processor = audio_processor
        self.audio_encoder = audio_encoder
        self.args = args
        self.device = device
        # snow: the fields below are kept only for debugging purposes.
        self.work_root_path = work_root_path
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.verbose = verbose

    def __del__(self):
        # Drop this object's references to the model objects. Note that the
        # audio encoder/processor are also held by module globals, so deleting
        # the attributes here does not by itself free them.
        if self.verbose:
            print("del model , gc:", sys.getrefcount(self.model))
        del self.model
        if self.args.model_type == "stf_v3":
            del self.audio_encoder
            del self.audio_processor


def __init_fix_seed(random_seed, verbose=False):
    """Call ``fix_seed(random_seed)`` once per process; later calls are no-ops."""
    global g_fix_seed
    if g_fix_seed:
        return
    if verbose:
        print("fix seed")
    fix_seed(random_seed)  # provided by the .util star import
    g_fix_seed = True


def create_model(
    config_path, checkpoint_path, work_root_path, device, verbose=False, wavlm_path=None
):
    """Load a Speech2Face model described by a config file and checkpoint.

    Args:
        config_path: path to the config file read by ``read_config``.
        checkpoint_path: path to a ``torch.save``'d checkpoint, either a raw
            state dict or a dict containing a ``"state_dict"`` entry. Not
            required to exist for "remote" configs.
        work_root_path: stored on the returned ModelInfo (debugging only).
        device: device the model is moved to; ``"cuda"`` with more than one
            visible GPU enables DataParallel over all of them.
        verbose: print progress messages.
        wavlm_path: directory of the HuggingFace WavLM model used by
            "stf_v3"; defaults to ``<parent of this package>/hf_wavlm``.

    Returns:
        ModelInfo. For "remote" configs, ``model`` and the audio objects are None.

    Raises:
        FileNotFoundError: if ``config_path`` is missing, or if
            ``checkpoint_path`` is missing for a non-remote config.
    """
    __init_fix_seed(random_seed=1234, verbose=verbose)

    global g_audio_encoder
    global g_audio_processor

    if verbose:
        print("load model")

    if not os.path.exists(config_path):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_path)

    args = read_config(config_path)

    # snow: model_type was added to the config format later, so it may be
    # empty; such configs describe the original "stf_v1" model. Normalizing
    # up front (instead of calling len() afterwards) also handles None.
    if not args.model_type:
        args.model_type = _DEFAULT_MODEL_TYPE

    if args.model_type == "remote":
        # No local weights: return only the config/path metadata.
        return ModelInfo(
            model=None,
            audio_processor=None,
            audio_encoder=None,
            args=args,
            device=device,
            work_root_path=work_root_path,
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            verbose=verbose,
        )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path
        )

    # NOTE(review): positional arguments are passed exactly as before; their
    # meaning (presumably channel count and image / mel input shapes) is
    # defined by ae.Speech2Face -- confirm there.
    model = ae.Speech2Face(
        3,
        (3, args.img_size, args.img_size),
        (1, 96, args.mel_step_size),
        args.model_type,
    )

    # The WavLM encoder is loaded once per process and cached in the globals.
    if args.model_type == "stf_v3" and g_audio_encoder is None:
        if wavlm_path is None:
            wavlm_path = f"{Path(__file__).parent.parent}/hf_wavlm"
        if verbose:
            print(f"@@@@@@@@@@@@@@@@@@ {wavlm_path}")
        g_audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wavlm_path)
        g_audio_encoder = WavLMModel.from_pretrained(wavlm_path)

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoint files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    # The weights are now copied into the model; release the CPU copy before
    # moving the model to the target device.
    del state_dict
    del checkpoint

    if device == "cuda" and torch.cuda.device_count() > 1:
        gpus = list(range(torch.cuda.device_count()))
        print("Multi GPU activate, gpus : ", gpus)
        model = torch.nn.DataParallel(model, device_ids=gpus)
        model.to(device)
        model.eval()
        if args.model_type == "stf_v3":
            # g_audio_encoder is process-wide: an earlier create_model() call
            # may already have wrapped it, and wrapping it again would nest
            # DataParallel modules.
            if not isinstance(g_audio_encoder, torch.nn.DataParallel):
                g_audio_encoder = torch.nn.DataParallel(
                    g_audio_encoder, device_ids=gpus
                )
            g_audio_encoder.to(device)
            g_audio_encoder.eval()
    else:
        model.to(device).eval()
        if args.model_type == "stf_v3":
            g_audio_encoder.to(device).eval()

    model_data = ModelInfo(
        model=model,
        audio_processor=g_audio_processor,
        audio_encoder=g_audio_encoder,
        args=args,
        device=device,
        work_root_path=work_root_path,
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        verbose=verbose,
    )

    gc.collect()

    if verbose:
        print("load model complete")

    return model_data