# Chat2TTS/core.py

import os
import logging
from omegaconf import OmegaConf
import torch
from vocos import Vocos
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.io_utils import get_latest_modified_file
from .infer.api import refine_text, infer_code
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
from tool.logger import get_logger
from tool.normalizer import normalizer_en_nemo_text, normalizer_cn_tn
from tool.func import encode_prompt
from ChatTTS.norm import Normalizer
from huggingface_hub import snapshot_download


class Chat:
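    """High-level TTS entry point: loads the pretrained vocos/dvae/gpt/decoder/tokenizer
    modules and exposes text-to-speech inference and speaker-sampling helpers."""
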
    def __init__(self):
        self.pretrain_models = {}
        self.logger = get_logger(__name__, lv=logging.INFO)
        self.normalizer = Normalizer(
            os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
            self.logger,
        )

    def check_model(self, level=logging.INFO, use_decoder=False):
        not_finish = False
        check_list = ['vocos', 'gpt', 'tokenizer']
        if use_decoder:
            check_list.append('decoder')
        else:
            check_list.append('dvae')
        for module in check_list:
            if module not in self.pretrain_models:
                self.logger.log(logging.WARNING, f'{module} not initialized.')
                not_finish = True
        if not not_finish:
            self.logger.log(level, 'All initialized.')
        return not not_finish

    def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
        if source == 'huggingface':
            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
            try:
                download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
            except Exception:
                download_path = None
            if download_path is None or force_redownload:
                self.logger.log(logging.INFO, 'Download from HF: https://huggingface.co/2Noise/ChatTTS')
                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            else:
                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
            self._load(**{k: os.path.join(download_path, v) for k, v in
                          OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
            self._register_normalizer()
        elif source == 'local':
            self.logger.log(logging.INFO, f'Load from local: {local_path}')
            self._load(**{k: os.path.join(local_path, v) for k, v in
                          OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
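
    # Usage sketch: load weights from a local checkout instead of the Hub
    # (local_path must contain config/path.yaml plus the checkpoints it references):
    #   chat = Chat()
    #   chat.load_models(source='local', local_path='<LOCAL_PATH>')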

    def _register_normalizer(self):
        self.logger.info("========== registering normalizers ==========")
        try:
            self.normalizer.register("en", normalizer_en_nemo_text())
        except ValueError as e:
            self.logger.error('normalizer_en_nemo_text register fail: %s', e)
        except Exception:
            self.logger.error("Package nemo_text_processing not found!")
            self.logger.error(
                "Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing",
            )
        try:
            self.normalizer.register("zh", normalizer_cn_tn())
        except ValueError as e:
            self.logger.error('normalizer_cn_tn register fail: %s', e)
        except Exception:
            self.logger.error("Package WeTextProcessing not found!")
            self.logger.error(
                "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
            )

    def _load(
            self,
            vocos_config_path: Optional[str] = None,
            vocos_ckpt_path: Optional[str] = None,
            dvae_config_path: Optional[str] = None,
            dvae_ckpt_path: Optional[str] = None,
            gpt_config_path: Optional[str] = None,
            gpt_ckpt_path: Optional[str] = None,
            decoder_config_path: Optional[str] = None,
            decoder_ckpt_path: Optional[str] = None,
            tokenizer_path: Optional[str] = None,
            device: Optional[str] = None,
    ):
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')
        self.device = device

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            vocos.load_state_dict(torch.load(vocos_ckpt_path))
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**cfg).to(device).eval()
            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
            dvae.load_state_dict(torch.load(dvae_ckpt_path))
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**cfg).to(device).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path))
            self.pretrain_models['gpt'] = gpt
            self.gpt = gpt
            self.logger.log(logging.INFO, 'gpt loaded.')
            # Speaker statistics (std/mean of the speaker embedding) live next to the GPT checkpoint.
            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
            assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}"
            self.pretrain_models["spk_stat"] = torch.load(
                spk_stat_path, weights_only=True, mmap=True
            ).to(device)

        if decoder_config_path:
            cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**cfg).to(device).eval()
            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
            decoder.load_state_dict(torch.load(decoder_ckpt_path))
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            tokenizer = torch.load(tokenizer_path)
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()

    @dataclass(repr=False, eq=False)
    class RefineTextParams:
        prompt: str = ""
        top_P: float = 0.7
        top_K: int = 20
        temperature: float = 0.7
        repetition_penalty: float = 1.0
        max_new_token: int = 384
        min_new_token: int = 0
        show_tqdm: bool = True
        ensure_non_empty: bool = True

    @dataclass(repr=False, eq=False)
    class InferCodeParams(RefineTextParams):
        prompt: str = "[speed_5]"
        spk_emb: Optional[str] = None
        temperature: float = 0.3
        repetition_penalty: float = 1.05
        max_new_token: int = 2048
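
    # Note: infer() below takes plain dicts (params_refine_text / params_infer_code);
    # the dataclasses above document the accepted keys and their defaults. A sketch of
    # equivalent dict literals (key names taken from the fields above):
    #   params_refine = {'top_P': 0.7, 'top_K': 20, 'temperature': 0.7}
    #   params_code = {'prompt': '[speed_5]', 'temperature': 0.3}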

    def infer(
            self,
            text,
            skip_refine_text=False,
            refine_text_only=False,
            params_refine_text=None,
            params_infer_code=None,
            use_decoder=False,
            lang=None,
    ):
        # Avoid mutable default arguments; copy params_infer_code since it is mutated (pop) below.
        params_refine_text = params_refine_text or {}
        params_infer_code = dict(params_infer_code or {})
        self.logger.info(
            f"======== infer start, use_decoder: {use_decoder}, lang: {lang}, "
            f"skip_refine_text: {skip_refine_text}, refine_text_only: {refine_text_only} ======")
        assert self.check_model(use_decoder=use_decoder)
        if not isinstance(text, list):
            text = [text]
        text = [
            self.normalizer(
                text=t,
                do_text_normalization=True,
                do_homophone_replacement=True,
                lang=lang,
            )
            for t in text
        ]
        if skip_refine_text:
            self.logger.info("======== skipping model-based text refinement, rule-based normalization only ======")
        else:
            self.logger.info(f"======== refining text with the model, lang: {lang} ======")
            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
            # Truncate each sequence at the first control token ([break_0] and above).
            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')]
                           for i in text_tokens]
            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
            if refine_text_only:
                return text
        text = [params_infer_code.get('prompt', '') + i for i in text]
        params_infer_code.pop('prompt', '')
        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
        if use_decoder:
            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0, 2, 1)) for i in result['hiddens']]
        else:
            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0, 2, 1)) for i in result['ids']]
        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
        return wav

    def empty_audio(self):
        """Return an empty wav clip (silence synthesized from a single space)."""
        return self.infer(" ",
                          skip_refine_text=True,
                          refine_text_only=False,
                          params_refine_text={},
                          params_infer_code={},
                          use_decoder=False)

    @torch.inference_mode()
    def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
        """Encode an audio tensor into a speaker prompt string."""
        if isinstance(wav, np.ndarray):
            wav = torch.from_numpy(wav).to(self.device)
        tokens = self.pretrain_models['dvae'](wav, "encode").squeeze_(0)
        return encode_prompt(tokens)
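
    # Usage sketch: clone a voice from a reference clip, assuming infer_code accepts
    # a spk_emb key (as suggested by InferCodeParams above):
    #   spk = chat.sample_audio_speaker(ref_wav)  # ref_wav: mono waveform (ndarray or tensor)
    #   wavs = chat.infer(["Some text"], params_infer_code={'spk_emb': spk})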

    def sample_random_speaker_tensor(self) -> torch.Tensor:
        with torch.no_grad():
            # Embedding width matches the GPT MLP input dimension.
            dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
            out: torch.Tensor = self.pretrain_models["spk_stat"]
            std, mean = out.chunk(2)
            # Draw a random speaker embedding from N(mean, std^2).
            spk = (
                torch.randn(dim, device=std.device, dtype=torch.float16)
                .mul_(std)
                .add_(mean)
            )
            del out, std, mean
            return spk
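

# Minimal usage sketch (assumptions: the 2Noise/ChatTTS weights are reachable on
# Hugging Face, and `soundfile` is installed as an extra dependency; ChatTTS's
# vocos vocoder outputs 24 kHz audio).
if __name__ == "__main__":
    import soundfile as sf

    chat = Chat()
    chat.load_models(source='huggingface')
    wavs = chat.infer(["Hello from Chat2TTS."],
                      params_infer_code={'prompt': '[speed_5]'})
    sf.write("output.wav", wavs[0].squeeze(), 24000)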