import errno
import gc
import os
import sys
from pathlib import Path

import torch

# from .s2f_dir.src.speech_encoder.WavLM import WavLM, WavLMConfig
from transformers import Wav2Vec2FeatureExtractor, WavLMModel

from .s2f_dir.src import autoencoder as ae
from .util import *

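# Module-level caches: seed the RNGs only once per process and load the (large)
# WavLM processor/encoder only once, shared across create_model() calls.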
g_fix_seed = False
g_audio_processor = None
g_audio_encoder = None


class ModelInfo:
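    """Bundle of a loaded Speech2Face model, its (optional) WavLM audio
    processor/encoder, and the paths/args used to load it."""
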
    def __init__(
        self,
        model,
        audio_processor,
        audio_encoder,
        args,
        device,
        work_root_path,
        config_path,
        checkpoint_path,
        verbose=False,
    ):
        self.model = model
        self.audio_processor = audio_processor
        self.audio_encoder = audio_encoder
        self.args = args
        self.device = device
        # snow: the fields below are kept only for debugging
        self.work_root_path = work_root_path
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.verbose = verbose

    def __del__(self):
        if self.verbose:
            print("del model , gc:", sys.getrefcount(self.model))
        del self.model
        if self.args.model_type == "stf_v3":
            del self.audio_encoder
            del self.audio_processor


def __init_fix_seed(random_seed, verbose=False):
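    """Seed global RNGs once per process; subsequent calls are no-ops."""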
    global g_fix_seed
    if g_fix_seed:
        return

    if verbose:
        print("fix seed")
    fix_seed(random_seed)
    g_fix_seed = True


def create_model(
    config_path, checkpoint_path, work_root_path, device, verbose=False, wavlm_path=None
):
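    """Load a Speech2Face model described by config_path/checkpoint_path.

    For a "remote" config no weights are loaded and an empty ModelInfo is
    returned. For "stf_v3" configs a WavLM processor/encoder pair is loaded
    from wavlm_path (defaulting to a sibling hf_wavlm directory) and cached
    in module-level globals so repeated calls reuse it.
    """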
    __init_fix_seed(random_seed=1234, verbose=verbose)
    global g_audio_encoder
    global g_audio_processor
    if verbose:
        print("load model")

    if not os.path.exists(config_path):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_path)

    args = read_config(config_path)
    if args.model_type == "remote":
        return ModelInfo(
            model=None,
            audio_processor=None,
            audio_encoder=None,
            args=args,
            device=device,
            work_root_path=work_root_path,
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            verbose=verbose,
        )

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_path
        )

    # snow: model_type was added to the config later, so older configs may not set it.
    if not args.model_type:
        args.model_type = "stf_v1"

    model = ae.Speech2Face(
        3,
        (3, args.img_size, args.img_size),
        (1, 96, args.mel_step_size),
        args.model_type,
    )

    if args.model_type == "stf_v3":
        if g_audio_encoder is None:
            if wavlm_path is None:
                wavlm_path = f"{Path(__file__).parent.parent}/hf_wavlm"

            if verbose:
                print(f"loading WavLM from: {wavlm_path}")
            g_audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(wavlm_path)
            g_audio_encoder = WavLMModel.from_pretrained(wavlm_path)

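    # Checkpoints may store weights either under a "state_dict" key or as a
    # bare state dict; handle both layouts.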
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "state_dict" in checkpoint:
        model.load_state_dict(checkpoint["state_dict"])
    else:
        model.load_state_dict(checkpoint)
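    # With more than one visible GPU, wrap the face model (and, for stf_v3,
    # the WavLM encoder) in DataParallel before moving them to the device.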
    if device == "cuda" and torch.cuda.device_count() > 1:
        gpus = list(range(torch.cuda.device_count()))
        print("Multi GPU activate, gpus : ", gpus)
        model = torch.nn.DataParallel(model, device_ids=gpus)
        model.to(device)
        model.eval()

        if args.model_type == "stf_v3":
            g_audio_encoder = torch.nn.DataParallel(g_audio_encoder, device_ids=gpus)
            g_audio_encoder.to(device)
            g_audio_encoder.eval()
    else:
        model.to(device).eval()
        if args.model_type == "stf_v3":
            g_audio_encoder.to(device).eval()

    model_data = ModelInfo(
        model=model,
        audio_processor=g_audio_processor,
        audio_encoder=g_audio_encoder,
        args=args,
        device=device,
        work_root_path=work_root_path,
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        verbose=verbose,
    )
    del checkpoint
    gc.collect()
    if verbose:
        print("load model complete")

    return model_data
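
# Minimal usage sketch (the import path and file names below are hypothetical;
# adjust them to wherever this module and your config/checkpoint actually live):
#
#   import torch
#   from stf import model_util  # hypothetical module name for this file
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   info = model_util.create_model(
#       config_path="work/config.json",         # hypothetical path
#       checkpoint_path="work/checkpoint.pth",  # hypothetical path
#       work_root_path="work",
#       device=device,
#       verbose=True,
#   )
#   # info.model is on `device` in eval mode; for "stf_v3" configs,
#   # info.audio_processor / info.audio_encoder hold the WavLM stack.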