Spaces:
Sleeping
Sleeping
import errno | |
import gc | |
import os | |
import sys | |
import torch | |
# from .s2f_dir.src.speech_encoder.WavLM import WavLM, WavLMConfig | |
from transformers import Wav2Vec2FeatureExtractor, WavLMModel | |
from .s2f_dir.src import autoencoder as ae | |
from .util import * | |
# Process-wide state shared by every create_model() call in this module.
# True once __init_fix_seed() has seeded the random generators.
g_fix_seed = False
# WavLM feature extractor and encoder: loaded on the first "stf_v3" model,
# then cached here and reused by later calls.
g_audio_processor = g_audio_encoder = None
class ModelInfo:
    """Bundle returned by ``create_model``.

    Holds the loaded model, the optional WavLM audio front-end
    (processor + encoder), the parsed config ``args``, the target device and
    the paths everything was loaded from.
    """

    def __init__(
        self,
        model,
        audio_processor,
        audio_encoder,
        args,
        device,
        work_root_path,
        config_path,
        checkpoint_path,
        verbose=False,
    ):
        self.model = model
        self.audio_processor = audio_processor
        self.audio_encoder = audio_encoder
        self.args = args
        self.device = device
        # snow: the fields below are stored only to help with debugging.
        self.work_root_path = work_root_path
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.verbose = verbose

    def __del__(self):
        """Drop this instance's references to the model and audio front-end.

        The audio processor/encoder references are dropped only when
        ``args.model_type == "stf_v3"``. This removes the references held by
        this object only; any other holder keeps the objects alive.
        """
        # __del__ can run on an instance whose __init__ never completed (e.g.
        # the constructor was called with bad arguments), and it may be called
        # more than once. Look attributes up defensively so the finalizer never
        # raises ("Exception ignored in __del__").
        attrs = self.__dict__
        if attrs.get("verbose", False) and "model" in attrs:
            print("del model , gc:", sys.getrefcount(attrs["model"]))
        attrs.pop("model", None)
        if getattr(attrs.get("args"), "model_type", None) == "stf_v3":
            attrs.pop("audio_encoder", None)
            attrs.pop("audio_processor", None)
def __init_fix_seed(random_seed, verbose=False):
    """Seed the random generators at most once per process.

    The first call forwards ``random_seed`` to ``fix_seed`` (presumably
    provided by ``from .util import *`` — confirm) and then sets the
    module-level ``g_fix_seed`` flag. Every later call returns immediately.
    The flag is set only after ``fix_seed`` returns, so a failed seeding
    attempt will be retried on the next call.

    Note: the double leading underscore does not trigger name mangling here,
    because this is a module-level function, not a class member.

    Args:
        random_seed: Seed value forwarded to ``fix_seed``.
        verbose: When true, print a short message before seeding.
    """
    global g_fix_seed
    if g_fix_seed:
        return
    if verbose:
        print("fix seed")
    fix_seed(random_seed)
    g_fix_seed = True
def create_model(
    config_path, checkpoint_path, work_root_path, device, verbose=False, wavlm_path=None
):
    """Build a Speech2Face model from a config and checkpoint, wrapped in ModelInfo.

    For ``model_type == "stf_v3"`` the WavLM feature extractor and encoder are
    also loaded. They are cached in the module globals ``g_audio_processor`` /
    ``g_audio_encoder`` and reused by later calls. Note that the cached objects
    are passed to ModelInfo for every local model type, not only "stf_v3".

    Args:
        config_path: Path to the model config, parsed with ``read_config``.
        checkpoint_path: Path to a checkpoint readable by ``torch.load``. It is
            either a bare state dict or a dict holding one under
            ``"state_dict"``. It is not checked for "remote" configs.
        work_root_path: Stored on the returned ModelInfo; not otherwise used.
        device: Passed to ``.to()``. When it equals ``"cuda"`` and more than
            one GPU is visible, the model (and, for "stf_v3", the WavLM
            encoder) is wrapped in ``torch.nn.DataParallel`` over all GPUs.
        verbose: Print progress messages.
        wavlm_path: Directory handed to ``from_pretrained`` for the WavLM
            extractor/encoder. Defaults to ``hf_wavlm`` in the grandparent
            directory of this file.

    Returns:
        ModelInfo. For "remote" configs, ``model``, ``audio_processor`` and
        ``audio_encoder`` are None and nothing is loaded.

    Raises:
        FileNotFoundError: If ``config_path`` is missing, or if
            ``checkpoint_path`` is missing for a non-remote config.
    """
    # Imported locally: Path is otherwise only available if .util re-exports it.
    from pathlib import Path

    global g_audio_encoder
    global g_audio_processor

    __init_fix_seed(random_seed=1234, verbose=verbose)
    if verbose:
        print("load model")

    if not os.path.exists(config_path):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_path)
    args = read_config(config_path)

    # "remote" configs skip loading any local weights or audio encoder.
    if getattr(args, "model_type", None) == "remote":
        return ModelInfo(
            model=None,
            audio_processor=None,
            audio_encoder=None,
            args=args,
            device=device,
            work_root_path=work_root_path,
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            verbose=verbose,
        )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path
        )

    # snow: model_type was added to the config format later, so older configs
    # may leave it empty (or unset); those mean the original "stf_v1" model.
    # A falsy check (rather than len()) also copes with a None value.
    if not getattr(args, "model_type", None):
        args.model_type = "stf_v1"

    model = ae.Speech2Face(
        3,
        (3, args.img_size, args.img_size),
        (1, 96, args.mel_step_size),
        args.model_type,
    )

    use_wavlm = args.model_type == "stf_v3"
    if use_wavlm and g_audio_encoder is None:
        if wavlm_path is None:
            wavlm_path = f"{Path(__file__).parent.parent}/hf_wavlm"
        if verbose:
            print(f"@@@@@@@@@@@@@@@@@@ {wavlm_path}")
        g_audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wavlm_path)
        g_audio_encoder = WavLMModel.from_pretrained(wavlm_path)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Accept either a dict wrapping the weights under "state_dict" or a bare state dict.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    # Release the CPU copy of the weights before the model is moved to the device.
    del state_dict, checkpoint
    gc.collect()

    if device == "cuda" and torch.cuda.device_count() > 1:
        gpus = list(range(torch.cuda.device_count()))
        print("Multi GPU activate, gpus : ", gpus)
        model = torch.nn.DataParallel(model, device_ids=gpus)
        # g_audio_encoder is a module-level cache shared across calls. Wrap it
        # only once, so repeated create_model() calls don't nest DataParallel.
        if use_wavlm and not isinstance(g_audio_encoder, torch.nn.DataParallel):
            g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)

    # nn.Module.to() moves the module in place and returns it, so chaining is safe.
    model.to(device).eval()
    if use_wavlm:
        g_audio_encoder.to(device).eval()

    model_data = ModelInfo(
        model=model,
        audio_processor=g_audio_processor,
        audio_encoder=g_audio_encoder,
        args=args,
        device=device,
        work_root_path=work_root_path,
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        verbose=verbose,
    )
    if verbose:
        print("load model complete")
    return model_data