yerang's picture
Create model_ori.py
b37595d verified
raw
history blame
4.54 kB
import errno
import gc
import os
import sys
import torch
# from .s2f_dir.src.speech_encoder.WavLM import WavLM, WavLMConfig
from transformers import Wav2Vec2FeatureExtractor, WavLMModel
from .s2f_dir.src import autoencoder as ae
from .util import *
# Process-wide state shared by every create_model() call in this module.
g_fix_seed = False  # flipped to True after the one-time seeding has run
# WavLM feature extractor / encoder, loaded lazily on the first "stf_v3"
# model and then reused by later calls.
g_audio_processor = g_audio_encoder = None
class ModelInfo:
    """Bundle a loaded model with its audio front-end, parsed config and paths.

    Attributes:
        model: The network object handed in by the caller (may be ``None``).
        audio_processor: Audio feature extractor, or ``None``.
        audio_encoder: Audio encoder module, or ``None``.
        args: Parsed config; ``args.model_type`` is read by ``__del__``.
        device: Device the caller placed the model on.
        work_root_path, config_path, checkpoint_path: Kept only for debugging.
        verbose: When true, ``__del__`` prints a reference-count diagnostic.
    """

    def __init__(
        self,
        model,
        audio_processor,
        audio_encoder,
        args,
        device,
        work_root_path,
        config_path,
        checkpoint_path,
        verbose=False,
    ):
        # Objects used at inference time.
        self.model, self.audio_processor, self.audio_encoder = (
            model,
            audio_processor,
            audio_encoder,
        )
        self.args, self.device = args, device
        # snow: the following are stored only to help with debugging.
        self.work_root_path, self.config_path, self.checkpoint_path = (
            work_root_path,
            config_path,
            checkpoint_path,
        )
        self.verbose = verbose

    def __del__(self):
        """Drop this object's model reference and, for ``stf_v3``, its audio refs.

        Only the attributes are removed; objects still referenced elsewhere
        stay alive.
        """
        if self.verbose:
            # sys.getrefcount counts its own temporary argument reference too.
            print("del model , gc:", sys.getrefcount(self.model))
        del self.model
        if self.args.model_type != "stf_v3":
            return
        del self.audio_encoder, self.audio_processor
def __init_fix_seed(random_seed, verbose=False):
    """Call ``fix_seed`` at most once per process.

    The first call forwards ``random_seed`` to ``fix_seed`` (presumably the
    helper provided by ``from .util import *``) and sets the module-level
    ``g_fix_seed`` flag. Every later call returns immediately, regardless of
    the ``random_seed`` it is given.

    Args:
        random_seed: Seed forwarded to ``fix_seed`` on the first call.
        verbose: Print a message when seeding actually happens.
    """
    global g_fix_seed
    # PEP 8: test the flag's truthiness instead of comparing with `== True`.
    if g_fix_seed:
        return
    if verbose:
        print("fix seed")
    fix_seed(random_seed)
    g_fix_seed = True
def create_model(
    config_path, checkpoint_path, work_root_path, device, verbose=False, wavlm_path=None
):
    """Build a Speech2Face model from a config file and a checkpoint.

    Args:
        config_path: Path to the model config, parsed with ``read_config``.
        checkpoint_path: Path to a ``torch.load``-able checkpoint. It is either
            a bare state dict or a dict holding one under ``"state_dict"``.
        work_root_path: Stored on the returned ``ModelInfo`` for debugging.
        device: Target passed to ``Module.to``. If it is exactly the string
            ``"cuda"`` and more than one GPU is visible, the model is wrapped
            in ``DataParallel`` over all GPUs. For ``"stf_v3"``, the audio
            encoder is wrapped as well.
        verbose: Print progress messages.
        wavlm_path: Directory holding the Hugging Face WavLM files used by
            ``"stf_v3"`` models. It defaults to ``hf_wavlm`` inside the parent
            of this file's directory and is only consulted on the first load.

    Returns:
        ModelInfo. For ``model_type == "remote"`` nothing is loaded, and
        ``model``, ``audio_processor`` and ``audio_encoder`` are ``None``.

    Raises:
        FileNotFoundError: If ``config_path`` is missing or, for non-remote
            models, ``checkpoint_path`` is missing.
    """
    from pathlib import Path  # used for the default wavlm_path below

    global g_audio_encoder
    global g_audio_processor

    __init_fix_seed(random_seed=1234, verbose=verbose)
    if verbose:
        print("load model")

    if not os.path.exists(config_path):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_path)
    args = read_config(config_path)

    # "remote" configs load no local weights at all.
    if args.model_type == "remote":
        return ModelInfo(
            model=None,
            audio_processor=None,
            audio_encoder=None,
            args=args,
            device=device,
            work_root_path=work_root_path,
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            verbose=verbose,
        )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path
        )

    # snow: model_type was added to the config later, so it may be missing
    # (empty) in older configs; those mean "stf_v1". Normalizing with a
    # truthiness test also handles None, where len() would raise TypeError.
    if not args.model_type:
        args.model_type = "stf_v1"
    model = ae.Speech2Face(
        3,
        (3, args.img_size, args.img_size),
        (1, 96, args.mel_step_size),
        args.model_type,
    )

    # The WavLM front-end is cached in module globals and loaded only once.
    if args.model_type == "stf_v3" and g_audio_encoder is None:
        if wavlm_path is None:
            wavlm_path = f"{Path(__file__).parent.parent}/hf_wavlm"
        if verbose:
            print(f"@@@@@@@@@@@@@@@@@@ {wavlm_path}")
        g_audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wavlm_path)
        g_audio_encoder = WavLMModel.from_pretrained(wavlm_path)

    # NOTE(review): torch.load unpickles arbitrary objects. Only load trusted
    # checkpoints, and consider weights_only=True if they hold tensors only.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    if device == "cuda" and torch.cuda.device_count() > 1:
        gpus = list(range(torch.cuda.device_count()))
        print("Multi GPU activate, gpus : ", gpus)
        model = torch.nn.DataParallel(model, device_ids=gpus)
        model.to(device)
        model.eval()
        if args.model_type == "stf_v3":
            # The encoder is shared across calls. Wrap it only once so that
            # repeated create_model() calls don't nest DataParallel layers.
            if not isinstance(g_audio_encoder, torch.nn.DataParallel):
                g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
            g_audio_encoder.to(device)
            g_audio_encoder.eval()
    else:
        model.to(device).eval()
        if args.model_type == "stf_v3":
            # NOTE(review): this moves the shared encoder, which also affects
            # ModelInfo objects created earlier with a different device.
            g_audio_encoder.to(device).eval()

    model_data = ModelInfo(
        model=model,
        audio_processor=g_audio_processor,
        audio_encoder=g_audio_encoder,
        args=args,
        device=device,
        work_root_path=work_root_path,
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        verbose=verbose,
    )
    # Drop every reference to the CPU copy of the weights before collecting.
    del checkpoint, state_dict
    gc.collect()
    if verbose:
        print("load model complete")
    return model_data