import io
import os
import urllib.request
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import whisper
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from whisper.model import AudioEncoder, ModelDimensions
from whisperspeech.vq_stoks import RQBottleneckTransformer, Tunables

_HF_MODELS = {
    "medium": "https://huggingface.co/jan-hq/WhisperVQ/resolve/main/medium_encoder_only.pt",
}


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_HF_MODELS.keys())


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    """Download a checkpoint into `root` (reusing any cached copy) and return
    either the raw bytes or the path to the local file."""
    os.makedirs(root, exist_ok=True)
    download_target = os.path.join(root, os.path.basename(url))
    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")
    if os.path.isfile(download_target):
        # Cached copy found; skip the download.
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        return model_bytes if in_memory else download_target
    # NOTE: this disables TLS certificate verification for the download.
    import ssl
    ssl._create_default_https_context = ssl._create_unverified_context
    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))
    with open(download_target, "rb") as f:
        model_bytes = f.read()
    return model_bytes if in_memory else download_target


class CustomWhisperEncoder(nn.Module):
    """
    Lightweight wrapper that only loads the AudioEncoder part of Whisper
    """

    def __init__(
        self,
        name: str,
        device: Optional[str] = None,
        download_root: Optional[str] = None,
        in_memory: bool = False,
    ):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if download_root is None:
            # Cache checkpoints next to this package rather than in
            # the default ~/.cache/whisper location.
            download_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        if name in _HF_MODELS:
            checkpoint_file = _download(
                _HF_MODELS[name], download_root, in_memory)
        elif os.path.isfile(name):
            checkpoint_file = open(name, "rb").read() if in_memory else name
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )
        # Load the encoder-only checkpoint (model dims + encoder state dict).
        with (
            io.BytesIO(checkpoint_file) if in_memory else open(
                checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
        del checkpoint_file
        dims = ModelDimensions(**checkpoint["dims"])
        self.encoder = AudioEncoder(
            dims.n_mels,
            dims.n_audio_ctx,
            dims.n_audio_state,
            dims.n_audio_head,
            dims.n_audio_layer,
        )
        self.encoder.load_state_dict(checkpoint["model_state_dict"])
        self.to(device)
        self.eval()

    def forward(self, mel: torch.Tensor):
        return self.encoder(mel)
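

# Sketch (not part of the original file): minimal standalone use of the
# encoder-only wrapper. Assumes the "medium" checkpoint above; the mel input
# follows whisper.log_mel_spectrogram, i.e. (batch, n_mels, n_frames).
#
#   enc = CustomWhisperEncoder("medium", device="cpu")
#   mel = torch.zeros(1, 80, whisper.audio.N_FRAMES)  # dummy 30 s window
#   embs = enc(mel)  # -> (1, dims.n_audio_ctx, dims.n_audio_state)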


class CustomRQBottleneckTransformer(RQBottleneckTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
                     repo_id=None, filename=None, local_filename=None):
        """Load only the quantizer weights, without the full Whisper model
        that the parent class would pull in."""
        if repo_id is None and filename is None and local_filename is None:
            if ":" in ref:
                repo_id, filename = ref.split(":", 1)
            else:
                local_filename = ref
        if not local_filename:
            local_filename = hf_hub_download(
                repo_id=repo_id, filename=filename)
        # Load the spec
        spec = torch.load(local_filename)
        # Create instance with minimal required components
        instance = cls(
            **spec['config'],
            tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))),
        )
        # Keep only the quantizer-related entries of the state dict.
        required_components = {'rq', 'mlp', 'mlp_ln'}
        filtered_state_dict = {
            k: v for k, v in spec['state_dict'].items()
            if any(k.startswith(comp) for comp in required_components)
        }
        instance.load_state_dict(filtered_state_dict, strict=False)
        instance.eval()
        return instance
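
    # Sketch: the reference forms load_vq_only accepts, assuming the default
    # WhisperSpeech repo layout on Hugging Face.
    #
    #   CustomRQBottleneckTransformer.load_vq_only()  # default "repo:filename" ref
    #   CustomRQBottleneckTransformer.load_vq_only(
    #       repo_id="collabora/spear-tts-pytorch",
    #       filename="whisper-vq-stoks-medium-en+pl.model")
    #   CustomRQBottleneckTransformer.load_vq_only("local-file.model")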

    def load_encoder(self, device=None):
        """Attach the encoder-only Whisper model instead of the full one."""
        if self.whmodel is not None:
            return
        device = device or self.device
        # Use our custom encoder-only model
        self.whmodel = CustomWhisperEncoder(
            self.whisper_model_name, device=device)
        multilingual = not self.whisper_model_name.endswith('.en')
        self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)

    def optimized_encode_mel(self, mel):
        assert len(
            mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
        self.load_encoder()
        n = mel.shape[-1]
        if n > whisper.audio.N_FRAMES:
            # Truncate anything longer than Whisper's 30-second window.
            padding = 0
            padded = mel[:, :, :whisper.audio.N_FRAMES]
        else:
            # Pad shorter inputs up to a full window.
            padding = -n % whisper.audio.N_FRAMES
            padded = F.pad(mel, (0, padding), value=-1.5)
        embs = self.whmodel.encoder(padded)
        stoks = self.quantize(embs)
        if self.tunables.mask_embs:
            # Drop tokens that correspond to the padded region.
            return stoks[:, :n//2//self.downsample]
        else:
            return stoks

    # override
    def encode_audio(self, audio):
        if isinstance(audio, str):
            # Accept a file path: load and resample to Whisper's 16 kHz.
            x, sr = torchaudio.load(audio)
            x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
            audio = x.unsqueeze(0)
        return self.optimized_encode_mel(self.log_mel_spectrogram(audio).to(self.device))


if __name__ == "__main__":
    # Load the quantizer weights only, then attach the encoder-only Whisper.
    vqmodel = CustomRQBottleneckTransformer.load_vq_only(
        "whisper-vq-stoks-v3-7lang-fixed.model"
    ).to("cuda")
    vqmodel.load_encoder('cuda')
    vqmodel.eval()
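    # Hedged example (hypothetical file path): convert a wav file into
    # semantic token ids; resampling to 16 kHz happens inside encode_audio.
    #
    #   stoks = vqmodel.encode_audio("sample.wav")
    #   print(stoks.shape)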