import errno
import gc
import os
import sys
from pathlib import Path

import torch
# from .s2f_dir.src.speech_encoder.WavLM import WavLM, WavLMConfig
from transformers import Wav2Vec2FeatureExtractor, WavLMModel

from .s2f_dir.src import autoencoder as ae
from .util import *
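
# Module-level caches: the random seed is fixed once per process, and the
# WavLM audio processor/encoder are shared across create_model() calls so the
# heavy encoder is only loaded once.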
g_fix_seed = False
g_audio_processor = None
g_audio_encoder = None


class ModelInfo:
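    """Container for a loaded model plus the audio encoder, config, and paths
    needed to run (and later debug) inference."""
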
def __init__(
self,
model,
audio_processor,
audio_encoder,
args,
device,
work_root_path,
config_path,
checkpoint_path,
verbose=False,
):
self.model = model
self.audio_processor = audio_processor
self.audio_encoder = audio_encoder
self.args = args
self.device = device
        # snow: kept here so it can be inspected while debugging
self.work_root_path = work_root_path
self.config_path = config_path
self.checkpoint_path = checkpoint_path
self.verbose = verbose

    def __del__(self):
        if self.verbose:
            print("del model, refcount:", sys.getrefcount(self.model))
del self.model
if self.args.model_type == "stf_v3":
del self.audio_encoder
del self.audio_processor


def __init_fix_seed(random_seed, verbose=False):
    global g_fix_seed
    if g_fix_seed:
return
if verbose:
print("fix seed")
fix_seed(random_seed)
g_fix_seed = True


def create_model(
config_path, checkpoint_path, work_root_path, device, verbose=False, wavlm_path=None
):
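    """Load a Speech2Face model from config_path/checkpoint_path onto device.

    Returns a ModelInfo. For "remote" configs no weights are loaded; for
    "stf_v3" configs the shared WavLM processor/encoder are also prepared.
    """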
__init_fix_seed(random_seed=1234, verbose=verbose)
global g_audio_encoder
global g_audio_processor
if verbose:
print("load model")
if not os.path.exists(config_path):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_path)
args = read_config(config_path)
    # A "remote" config has no local weights to load; return a stub ModelInfo.
    if args.model_type == "remote":
return ModelInfo(
model=None,
audio_processor=None,
audio_encoder=None,
args=args,
device=device,
work_root_path=work_root_path,
config_path=config_path,
checkpoint_path=checkpoint_path,
verbose=verbose,
)
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path
)
    # snow: model_type was added to the config later, so older configs may not
    # set it; fall back to "stf_v1" in that case.
    if not args.model_type:
        args.model_type = "stf_v1"
    model = ae.Speech2Face(
        3,
        (3, args.img_size, args.img_size),
        (1, 96, args.mel_step_size),
        args.model_type,
    )
    if args.model_type == "stf_v3":
        if g_audio_encoder is None:
            if wavlm_path is None:
                wavlm_path = f"{Path(__file__).parent.parent}/hf_wavlm"
            if verbose:
                print(f"loading WavLM from {wavlm_path}")
            g_audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wavlm_path)
            g_audio_encoder = WavLMModel.from_pretrained(wavlm_path)
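    # Checkpoints may wrap the weights under a "state_dict" key or be a bare
    # state dict; handle both layouts.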
checkpoint = torch.load(checkpoint_path, map_location="cpu")
if "state_dict" in checkpoint:
model.load_state_dict(checkpoint["state_dict"])
else:
model.load_state_dict(checkpoint)
if device == "cuda" and torch.cuda.device_count() > 1:
gpus = list(range(torch.cuda.device_count()))
print("Multi GPU activate, gpus : ", gpus)
model = torch.nn.DataParallel(model, device_ids=gpus)
model.to(device)
model.eval()
if args.model_type == "stf_v3":
g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
g_audio_encoder.to(device)
g_audio_encoder.eval()
else:
model.to(device).eval()
if args.model_type == "stf_v3":
g_audio_encoder.to(device).eval()
model_data = ModelInfo(
model=model,
audio_processor=g_audio_processor,
audio_encoder=g_audio_encoder,
args=args,
device=device,
work_root_path=work_root_path,
config_path=config_path,
checkpoint_path=checkpoint_path,
verbose=verbose,
)
del checkpoint
gc.collect()
if verbose:
print("load model complete")
return model_data
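

# A minimal usage sketch (commented out because this module uses relative
# imports and is meant to be imported as a package; the import path and file
# paths below are hypothetical):
#
#     from stf_tools.model_util import create_model  # hypothetical package path
#
#     info = create_model(
#         config_path="configs/front_config.json",        # hypothetical
#         checkpoint_path="checkpoints/front_model.pth",  # hypothetical
#         work_root_path="./work",
#         device="cuda" if torch.cuda.is_available() else "cpu",
#         verbose=True,
#     )
#     # info.model is now on `device` and in eval mode, ready for inference.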