yerang committed
Commit b37595d · verified
1 Parent(s): 85f2666

Create model_ori.py

stf/stf-api-alternative/src/stf_alternative/model_ori.py ADDED
@@ -0,0 +1,156 @@
import errno
import gc
import os
import sys
from pathlib import Path

import torch

# from .s2f_dir.src.speech_encoder.WavLM import WavLM, WavLMConfig
from transformers import Wav2Vec2FeatureExtractor, WavLMModel

from .s2f_dir.src import autoencoder as ae
from .util import *

# Module-level state: the seed is fixed once per process, and the WavLM
# processor/encoder are cached so they are loaded only once.
g_fix_seed = False
g_audio_processor = None
g_audio_encoder = None


class ModelInfo:
    """Bundles a loaded model with its audio components, config, and paths."""

    def __init__(
        self,
        model,
        audio_processor,
        audio_encoder,
        args,
        device,
        work_root_path,
        config_path,
        checkpoint_path,
        verbose=False,
    ):
        self.model = model
        self.audio_processor = audio_processor
        self.audio_encoder = audio_encoder
        self.args = args
        self.device = device
        # snow: the fields below are kept for debugging
        self.work_root_path = work_root_path
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.verbose = verbose

    def __del__(self):
        if self.verbose:
            print("del model, refcount:", sys.getrefcount(self.model))
        del self.model
        if self.args.model_type == "stf_v3":
            del self.audio_encoder
            del self.audio_processor


def __init_fix_seed(random_seed, verbose=False):
    global g_fix_seed
    if g_fix_seed:
        return

    if verbose:
        print("fix seed")
    fix_seed(random_seed)
    g_fix_seed = True


def create_model(
    config_path, checkpoint_path, work_root_path, device, verbose=False, wavlm_path=None
):
    """Load a Speech2Face model (plus the WavLM encoder for stf_v3) into a ModelInfo."""
    __init_fix_seed(random_seed=1234, verbose=verbose)
    global g_audio_encoder
    global g_audio_processor
    if verbose:
        print("load model")

    if not os.path.exists(config_path):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_path)

    args = read_config(config_path)
    if args.model_type == "remote":
        # Remote models carry no local weights; return the parsed config only.
        return ModelInfo(
            model=None,
            audio_processor=None,
            audio_encoder=None,
            args=args,
            device=device,
            work_root_path=work_root_path,
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            verbose=verbose,
        )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path
        )

    if args.model_type:
        model = ae.Speech2Face(
            3,
            (3, args.img_size, args.img_size),
            (1, 96, args.mel_step_size),
            args.model_type,
        )
    else:
        model = ae.Speech2Face(
            3, (3, args.img_size, args.img_size), (1, 96, args.mel_step_size), "stf_v1"
        )

    if not args.model_type:  # snow: this setting was added later, so older configs may lack it
        args.model_type = "stf_v1"

    if args.model_type == "stf_v3" and g_audio_encoder is None:
        if wavlm_path is None:
            wavlm_path = f"{Path(__file__).parent.parent}/hf_wavlm"

        if verbose:
            print(f"loading WavLM from {wavlm_path}")
        g_audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wavlm_path)
        g_audio_encoder = WavLMModel.from_pretrained(wavlm_path)

    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "state_dict" in checkpoint:
        model.load_state_dict(checkpoint["state_dict"])
    else:
        model.load_state_dict(checkpoint)

    if device == "cuda" and torch.cuda.device_count() > 1:
        gpus = list(range(torch.cuda.device_count()))
        print("Multi GPU activated, gpus:", gpus)
        model = torch.nn.DataParallel(model, device_ids=gpus)
        model.to(device)
        model.eval()

        if args.model_type == "stf_v3":
            g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
            g_audio_encoder.to(device)
            g_audio_encoder.eval()
    else:
        model.to(device).eval()
        if args.model_type == "stf_v3":
            g_audio_encoder.to(device).eval()

    model_data = ModelInfo(
        model=model,
        audio_processor=g_audio_processor,
        audio_encoder=g_audio_encoder,
        args=args,
        device=device,
        work_root_path=work_root_path,
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        verbose=verbose,
    )
    del checkpoint
    gc.collect()
    if verbose:
        print("load model complete")

    return model_data
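
For reference, a minimal usage sketch of create_model. This is an illustration only: the config, checkpoint, and work-root paths below are hypothetical placeholders, not files shipped in this commit.

# Hypothetical usage sketch of create_model; every path below is a placeholder.
import torch

from stf_alternative.model_ori import ModelInfo, create_model

device = "cuda" if torch.cuda.is_available() else "cpu"

model_info = create_model(
    config_path="work/config.json",          # hypothetical config path
    checkpoint_path="work/checkpoint.pth",   # hypothetical checkpoint path
    work_root_path="work",                   # hypothetical work root
    device=device,
    verbose=True,
)

# The returned ModelInfo carries the eval-mode model, the parsed config,
# and (for stf_v3 configs) the process-wide WavLM processor/encoder.
assert isinstance(model_info, ModelInfo)
print(model_info.args.model_type)

Because g_audio_processor and g_audio_encoder are module-level, repeated create_model calls in the same process reuse the already-loaded WavLM instance rather than reloading it from disk.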