File size: 6,703 Bytes
b46f992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import torch
import torch.nn as nn
import whisper
from whisper.model import AudioEncoder, ModelDimensions
from typing import Dict, Optional
from whisperspeech.vq_stoks import RQBottleneckTransformer, Tunables
from huggingface_hub import hf_hub_download
import torch.nn.functional as F
import os
from typing import List, Optional, Union
import io
import urllib
from tqdm import tqdm
import torchaudio

# Short model names accepted by CustomWhisperEncoder, mapped to the URL of the
# corresponding encoder-only checkpoint.
_HF_MODELS = dict(
    medium="https://huggingface.co/jan-hq/WhisperVQ/resolve/main/medium_encoder_only.pt",
)


def available_models() -> List[str]:
    """Return the short names that can be passed to CustomWhisperEncoder."""
    names = [*_HF_MODELS]
    return names


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    """Fetch ``url`` into ``root`` once and return its contents or local path.

    The file is cached as ``os.path.join(root, os.path.basename(url))``. If that
    file already exists it is reused without contacting the network.

    Args:
        url: URL of the checkpoint to fetch.
        root: Cache directory; created if it does not exist.
        in_memory: If True, return the file contents as ``bytes``; otherwise
            return the path of the cached file.

    Returns:
        The file bytes (``in_memory=True``) or the cached file path.

    Raises:
        RuntimeError: If the cache target exists but is not a regular file.
    """
    # ``import urllib`` alone does not guarantee the ``request`` submodule is
    # loaded, so import it explicitly here.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")

    if not os.path.isfile(download_target):
        # TLS certificates are verified with Python's default context. The
        # downloaded checkpoint is later unpickled by torch.load, so accepting
        # unverified connections would let a network attacker substitute it.
        # Write to a side file and move it into place only after a complete
        # transfer, so an interrupted download is never mistaken for a cache hit.
        partial_target = download_target + ".part"
        try:
            with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
                # Content-Length may be missing; tqdm accepts total=None.
                content_length = source.info().get("Content-Length")
                with tqdm(
                    total=int(content_length) if content_length else None,
                    ncols=80,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as loop:
                    while True:
                        buffer = source.read(8192)
                        if not buffer:
                            break
                        output.write(buffer)
                        loop.update(len(buffer))
            os.replace(partial_target, download_target)
        finally:
            # No-op after a successful replace; removes leftovers on failure.
            if os.path.exists(partial_target):
                os.remove(partial_target)

    if not in_memory:
        return download_target
    with open(download_target, "rb") as f:
        return f.read()


class CustomWhisperEncoder(nn.Module):
    """
    Lightweight wrapper that only loads the AudioEncoder part of Whisper.

    The checkpoint must contain a ``"dims"`` mapping (used to build
    ``ModelDimensions``) and a ``"model_state_dict"`` with the encoder weights.
    """

    def __init__(
        self,
        name: str,
        device: Optional[str] = None,
        download_root: Optional[str] = None,
        in_memory: bool = False,
    ):
        """Build the audio encoder and load its weights.

        Args:
            name: A key of ``_HF_MODELS`` (downloaded on first use) or a path to
                a local checkpoint file.
            device: Target device. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
            download_root: Cache directory for downloaded checkpoints. Defaults
                to the parent of the directory containing this module.
            in_memory: Read the whole checkpoint into memory before loading it
                instead of streaming it from disk.

        Raises:
            RuntimeError: If ``name`` is neither a known model nor a file.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if download_root is None:
            # Cache beside the package directory rather than in a per-user cache.
            download_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

        if name in _HF_MODELS:
            checkpoint_file = _download(
                _HF_MODELS[name], download_root, in_memory)
        elif os.path.isfile(name):
            if in_memory:
                with open(name, "rb") as f:
                    checkpoint_file = f.read()
            else:
                checkpoint_file = name
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )

        # Load weights. NOTE(review): torch.load unpickles the file, so only
        # load checkpoints from trusted sources.
        with (
            io.BytesIO(checkpoint_file) if in_memory else open(
                checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
        # Drop the raw bytes (when in_memory) before building the model.
        del checkpoint_file

        dims = ModelDimensions(**checkpoint["dims"])
        self.encoder = AudioEncoder(
            dims.n_mels,
            dims.n_audio_ctx,
            dims.n_audio_state,
            dims.n_audio_head,
            dims.n_audio_layer,
        )
        self.encoder.load_state_dict(checkpoint["model_state_dict"])

        if device:
            self.to(device)

        self.eval()

    def forward(self, mel: torch.Tensor):
        """Run the wrapped AudioEncoder on ``mel`` and return its output."""
        return self.encoder(mel)


class CustomRQBottleneckTransformer(RQBottleneckTransformer):
    """RQBottleneckTransformer that loads only a subset of checkpoint weights
    and uses CustomWhisperEncoder (encoder-only) as its Whisper model."""

    @classmethod
    def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
                     repo_id=None, filename=None, local_filename=None):
        """Build an instance from a checkpoint, loading only selected weights.

        Args:
            ref: ``"repo_id:filename"`` on the Hugging Face Hub, or a local
                path (no ``":"``). Used only when none of the explicit
                arguments below are given.
            repo_id, filename: Hub location, downloaded via hf_hub_download.
            local_filename: Path to a local checkpoint file.

        Returns:
            An instance in eval mode whose parameters with the prefixes
            ``rq``/``mlp``/``mlp_ln`` come from the checkpoint.
        """
        if repo_id is None and filename is None and local_filename is None:
            if ":" in ref:
                repo_id, filename = ref.split(":", 1)
            else:
                local_filename = ref
        if not local_filename:
            local_filename = hf_hub_download(
                repo_id=repo_id, filename=filename)

        # Load tensors on CPU: load_state_dict below copies them into the
        # module's own parameters anyway, and this avoids failing on hosts that
        # lack the device the checkpoint was saved from.
        spec = torch.load(local_filename, map_location="cpu")

        instance = cls(**spec['config'], tunables=Tunables(
            **Tunables.upgrade(spec.get('tunables', {}))))

        # Keep only parameters whose names start with these prefixes
        # (presumably the residual quantizer and its projection layers).
        required_prefixes = ('rq', 'mlp', 'mlp_ln')
        filtered_state_dict = {
            k: v for k, v in spec['state_dict'].items()
            if k.startswith(required_prefixes)
        }

        # strict=False because the filtered dict deliberately omits all
        # other parameters.
        instance.load_state_dict(filtered_state_dict, strict=False)
        instance.eval()
        return instance

    def load_encoder(self, device=None):
        """Lazily attach a CustomWhisperEncoder and the matching tokenizer.

        Does nothing if ``self.whmodel`` is already set. ``device`` defaults to
        ``self.device``.
        """
        if self.whmodel is not None:
            return
        device = device or self.device
        # Use the encoder-only model instead of a full Whisper model.
        self.whmodel = CustomWhisperEncoder(
            self.whisper_model_name, device=device)
        multilingual = not self.whisper_model_name.endswith('.en')
        self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)

    def optimzed_encode_mel(self, mel):
        """Quantize a batched mel spectrogram into semantic tokens.

        ``mel`` must be 3-D ``(batch, channels, time)``. Inputs longer than
        ``whisper.audio.N_FRAMES`` frames are truncated to that length; shorter
        ones are right-padded with ``-1.5`` up to ``N_FRAMES``.
        """
        assert len(
            mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
        self.load_encoder()
        n = mel.shape[-1]
        if n > whisper.audio.N_FRAMES:
            padded = mel[:, :, :whisper.audio.N_FRAMES]
        else:
            padding = -n % whisper.audio.N_FRAMES
            padded = F.pad(mel, (0, padding), value=-1.5)
        embs = self.whmodel.encoder(padded)
        stoks = self.quantize(embs)
        if self.tunables.mask_embs:
            # Drop tokens that correspond to padding. The //2 presumably
            # mirrors the encoder's temporal stride — TODO confirm.
            return stoks[:, :n//2//self.downsample]
        return stoks

    def encode_audio(self, audio):
        """Override: encode a file path or waveform into semantic tokens.

        A string is loaded with torchaudio, resampled to 16 kHz, reduced to its
        first channel and given a leading batch dimension before encoding.
        """
        if isinstance(audio, str):
            x, sr = torchaudio.load(audio)
            x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
            audio = x.unsqueeze(0)
        return self.optimzed_encode_mel(self.log_mel_spectrogram(audio).to(self.device))


if __name__ == "__main__":
    # Smoke run: load the VQ model from a local checkpoint file and attach the
    # encoder. Fall back to CPU so the script also runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vqmodel = CustomRQBottleneckTransformer.load_vq_only(
        "whisper-vq-stoks-v3-7lang-fixed.model"
    ).to(device)
    vqmodel.load_encoder(device)
    vqmodel.eval()