chenjgtea committed
Commit a536c15
1 Parent(s): 0699795

Split the GPU and CPU run modes

Files changed (9)
  1. Chat2TTS/core.py +78 -61
  2. test/audio_test.py +48 -0
  3. test/common_test.py +1 -1
  4. tool/__init__.py +1 -2
  5. tool/func.py +29 -2
  6. tool/np.py +19 -2
  7. tool/pcm.py +0 -21
  8. web/app_cpu.py +1 -1
  9. web/app_gpu.py +31 -20
Chat2TTS/core.py CHANGED
@@ -1,4 +1,3 @@
-
 import os
 import logging
 from omegaconf import OmegaConf
@@ -11,9 +10,11 @@ from .utils.gpu_utils import select_device
 from .utils.io_utils import get_latest_modified_file
 from .infer.api import refine_text, infer_code
 from dataclasses import dataclass
-from typing import Literal, Optional, List, Tuple, Dict
+from typing import Literal, Optional, List, Tuple, Dict, Union
+import numpy as np
 from tool.logger import get_logger
-from tool.normalizer import normalizer_en_nemo_text,normalizer_cn_tn
+from tool.normalizer import normalizer_en_nemo_text, normalizer_cn_tn
+from tool.func import encode_prompt
 
 from ChatTTS.norm import Normalizer
 
@@ -23,31 +24,31 @@ from huggingface_hub import snapshot_download
 class Chat:
     def __init__(self, ):
         self.pretrain_models = {}
-        self.logger = get_logger(__name__,lv=logging.INFO)
+        self.logger = get_logger(__name__, lv=logging.INFO)
         self.normalizer = Normalizer(
             os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
             self.logger,
         )
-
-    def check_model(self, level = logging.INFO, use_decoder = False):
+
+    def check_model(self, level=logging.INFO, use_decoder=False):
         not_finish = False
         check_list = ['vocos', 'gpt', 'tokenizer']
-
+
         if use_decoder:
             check_list.append('decoder')
         else:
             check_list.append('dvae')
-
+
         for module in check_list:
             if module not in self.pretrain_models:
                 self.logger.log(logging.WARNING, f'{module} not initialized.')
                 not_finish = True
-
+
         if not not_finish:
             self.logger.log(level, f'All initialized.')
-
+
         return not not_finish
-
+
     def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
         if source == 'huggingface':
             hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
@@ -55,25 +56,27 @@ class Chat:
                 download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
             except:
                 download_path = None
-            if download_path is None or force_redownload:
+            if download_path is None or force_redownload:
                 self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
                 download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
             else:
                 self.logger.log(logging.INFO, f'Load from cache: {download_path}')
-            self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
+            self._load(**{k: os.path.join(download_path, v) for k, v in
+                          OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
            self._regist_normalizer()
         elif source == 'local':
             self.logger.log(logging.INFO, f'Load from local: {local_path}')
-            self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
+            self._load(**{k: os.path.join(local_path, v) for k, v in
+                          OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
 
     def _regist_normalizer(self):
 
         self.logger.info("========== start registering normalizer ===========")
 
         try:
-            self.normalizer.register("en",normalizer_en_nemo_text())
+            self.normalizer.register("en", normalizer_en_nemo_text())
         except ValueError as e:
-            self.logger.error('normalizer_en_nemo_text register fail' , e)
+            self.logger.error('normalizer_en_nemo_text register fail', e)
         except:
             self.logger.error("Package nemo_text_processing not found!")
             self.logger.error(
@@ -81,40 +84,40 @@ class Chat:
             )
 
         try:
-            self.normalizer.register("zh",normalizer_cn_tn())
+            self.normalizer.register("zh", normalizer_cn_tn())
         except ValueError as e:
-            self.logger.error('normalizer_cn_tn register fail' , e)
+            self.logger.error('normalizer_cn_tn register fail', e)
         except:
             self.logger.error("Package WeTextProcessing not found!")
             self.logger.error(
                 "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing",
             )
 
-
     def _load(
-        self,
-        vocos_config_path: str = None,
-        vocos_ckpt_path: str = None,
-        dvae_config_path: str = None,
-        dvae_ckpt_path: str = None,
-        gpt_config_path: str = None,
-        gpt_ckpt_path: str = None,
-        decoder_config_path: str = None,
-        decoder_ckpt_path: str = None,
-        tokenizer_path: str = None,
-        device: str = None
+            self,
+            vocos_config_path: str = None,
+            vocos_ckpt_path: str = None,
+            dvae_config_path: str = None,
+            dvae_ckpt_path: str = None,
+            gpt_config_path: str = None,
+            gpt_ckpt_path: str = None,
+            decoder_config_path: str = None,
+            decoder_ckpt_path: str = None,
+            tokenizer_path: str = None,
+            device: str = None
     ):
         if not device:
             device = select_device(4096)
             self.logger.log(logging.INFO, f'use {device}')
-
+
+        self.device = device
         if vocos_config_path:
             vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
             assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
             vocos.load_state_dict(torch.load(vocos_ckpt_path))
             self.pretrain_models['vocos'] = vocos
             self.logger.log(logging.INFO, 'vocos loaded.')
-
+
         if dvae_config_path:
             cfg = OmegaConf.load(dvae_config_path)
             dvae = DVAE(**cfg).to(device).eval()
@@ -122,7 +125,7 @@ class Chat:
             dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
             self.pretrain_models['dvae'] = dvae
             self.logger.log(logging.INFO, 'dvae loaded.')
-
+
         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
             gpt = GPT_warpper(**cfg).to(device).eval()
@@ -139,7 +142,6 @@ class Chat:
                 spk_stat_path, weights_only=True, mmap=True, map_location='cpu'
             ).to(device)
 
-
         if decoder_config_path:
             cfg = OmegaConf.load(decoder_config_path)
             decoder = DVAE(**cfg).to(device).eval()
@@ -147,13 +149,13 @@ class Chat:
             decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
             self.pretrain_models['decoder'] = decoder
             self.logger.log(logging.INFO, 'decoder loaded.')
-
+
         if tokenizer_path:
             tokenizer = torch.load(tokenizer_path, map_location='cpu')
             tokenizer.padding_side = 'left'
             self.pretrain_models['tokenizer'] = tokenizer
             self.logger.log(logging.INFO, 'tokenizer loaded.')
-
+
         self.check_model()
 
     @dataclass(repr=False, eq=False)
@@ -177,16 +179,19 @@ class Chat:
         max_new_token: int = 2048
 
     def infer(
-        self,
-        text,
-        skip_refine_text=False,
-        refine_text_only=False,
-        params_refine_text={},
-        params_infer_code={},
-        use_decoder=False,
-        lang=None
+            self,
+            text,
+            skip_refine_text=False,
+            refine_text_only=False,
+            params_refine_text={},
+            params_infer_code={},
+            use_decoder=False,
+            lang=None
     ):
-
+
+        self.logger.info(
+            f"======== start model infer, use_decoder:{use_decoder}, lang:{lang}, "
+            f"skip_refine_text:{skip_refine_text}, refine_text_only:{refine_text_only} ======")
         assert self.check_model(use_decoder=use_decoder)
 
         if not isinstance(text, list):
@@ -203,36 +208,48 @@ class Chat:
         ]
 
         if skip_refine_text:
-            self.logger.info(f"======== skip model refinement of the text, rule-based processing only, lang:{lang} ======")
+            self.logger.info(f"======== skip model refinement of the text, rule-based processing only ======")
         else:
             self.logger.info(f"======== refine the text with the model, lang:{lang} ======")
             text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
-            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
+            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in
+                           text_tokens]
             text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
             if refine_text_only:
                 return text
-
+
         text = [params_infer_code.get('prompt', '') + i for i in text]
         params_infer_code.pop('prompt', '')
         result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
-
+
         if use_decoder:
-            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
+            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0, 2, 1)) for i in result['hiddens']]
         else:
-            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
-
+            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0, 2, 1)) for i in result['ids']]
+
         wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
-
+
         return wav
 
+    # return an empty wav audio clip
     def emptpy_audio(self):
-        return self.infer(" ",
-            skip_refine_text=True,
-            refine_text_only=False,
-            params_refine_text={},
-            params_infer_code={},
-            use_decoder=False)
-
+        return self.infer(" ",
+                          skip_refine_text=True,
+                          refine_text_only=False,
+                          params_refine_text={},
+                          params_infer_code={},
+                          use_decoder=False)
+
+    '''
+    Transcode an audio tensor into an encoded speaker string.
+    '''
+
+    @torch.inference_mode()
+    def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
+        if isinstance(wav, np.ndarray):
+            wav = torch.from_numpy(wav).to(self.device)
+        squeeze = self.pretrain_models['dvae'](wav, "encode").squeeze_(0)
+        return encode_prompt(squeeze)
 
     # def sample_random_speaker(self) -> str:
     #     return self._encode_spk_emb(self.sample_random_speaker_tensor())
@@ -266,4 +283,4 @@ class Chat:
             .add_(mean)
         )
         del out, std, mean
-        return spk
+        return spk
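
For reviewers, a minimal sketch of how the new sample_audio_speaker path fits next to the existing infer path, mirroring test/audio_test.py below (which the author notes still needs debugging); the model directory and wav file are placeholders:

    import Chat2TTS
    from tool.av import load_audio

    chat = Chat2TTS.Chat()
    chat.load_models(source="local", local_path="/path/to/chattts")  # placeholder path

    # dvae-encode a 24 kHz reference clip into a compressed base16384 speaker string
    sample_audio = load_audio("reference.wav", 24000)  # placeholder wav file
    spk_smp = chat.sample_audio_speaker(sample_audio)

    # plain synthesis path: decoder model when use_decoder=True, dvae otherwise
    wavs = chat.infer(["Hello from Chat2TTS."], use_decoder=True, lang="en")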
test/audio_test.py ADDED
@@ -0,0 +1,48 @@
+import os, sys
+
+if sys.platform == "darwin":
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
+import Chat2TTS
+from tool.av import load_audio
+from tool.logger import get_logger
+
+
+logger = get_logger("audio_test")
+# Initialize and load the model:
+chat = Chat2TTS.Chat()
+
+def init_chat():
+    global chat
+    source = "local"
+    # get the launch mode
+    MODEL = os.getenv('MODEL')
+    # in huggingface deployment mode, use the HF model data directly
+    if MODEL == "HF":
+        source = "huggingface"
+
+    logger.info("loading Chat2TTS model..., start source:" + source)
+
+    if chat.load_models(source=source, local_path="D:\\chenjgspace\\ai-model\\chattts"):
+        print("Models loaded successfully.")
+        logger.info("Models loaded end.")
+    # else:
+    #     logger.error("=========Models load failed.")
+    #     sys.exit(1)
+
+def audo_encode():
+    sample_audio = load_audio("D:\\Download\\audio_test.wav", 24000)
+    logger.info("================sample_audio:" + str(sample_audio))
+    spk_smp = chat.sample_audio_speaker(sample_audio)
+    logger.info("================spk_smp:" + str(spk_smp))
+
+
+if __name__ == "__main__":
+    init_chat()
+    # still needs further debugging
+    audo_encode()
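
Note on running this test: the model source is chosen via the MODEL environment variable, with anything else falling back to the hard-coded local Windows path, so on another machine it would be run roughly as `MODEL=HF python test/audio_test.py` from the repository root to pull 2Noise/ChatTTS from the Hub instead.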
test/common_test.py CHANGED
@@ -8,7 +8,7 @@ from tool.logger import get_logger
 
 logger=get_logger("common-test")
 def save_mp3_file(wav, index, prefix_name):
-    from tool.pcm import pcm_arr_to_mp3_view
+    from tool.np import pcm_arr_to_mp3_view
     data = pcm_arr_to_mp3_view(wav)
     mp3_filename = prefix_name + "_" + str(index) + ".mp3"
     with open(mp3_filename, "wb") as f:
tool/__init__.py CHANGED
@@ -1,5 +1,4 @@
 from .av import load_audio
-from .pcm import pcm_arr_to_mp3_view
-from .np import float_to_int16
+from .np import float_to_int16,pcm_arr_to_mp3_view
 from .ctx import TorchSeedContext
 from .gpu import select_device
tool/func.py CHANGED
@@ -1,6 +1,11 @@
 
 import gradio as gr
 import random
+import torch
+import lzma
+
+import numpy as np
+import pybase16384 as b14
 
 seed_min = 1
 seed_max = 4294967295
@@ -30,6 +35,28 @@ voices = {
 def on_voice_change(vocie_selection):
     return voices.get(vocie_selection)["seed"]
 
-
+'''
+Generate a random seed.
+'''
 def generate_seed():
-    return gr.update(value=random.randint(seed_min, seed_max))
+    return gr.update(value=random.randint(seed_min, seed_max))
+
+'''
+Encode an audio tensor as a compact string.
+'''
+
+@torch.no_grad()
+def encode_prompt(prompt: torch.Tensor) -> str:
+    arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy()
+    shp = arr.shape
+    assert len(shp) == 2, "prompt must be a 2D tensor"
+    s = b14.encode_to_string(
+        np.array(shp, dtype="<u2").tobytes()
+        + lzma.compress(
+            arr.astype("<u2").tobytes(),
+            format=lzma.FORMAT_RAW,
+            filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
+        ),
+    )
+    del arr
+    return s
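
For context, encode_prompt packs the two uint16 shape values as four little-endian bytes, appends the raw-LZMA2-compressed payload, and base16384-encodes the result. A sketch of the inverse, assuming pybase16384 exposes decode_from_string and that the same filter chain round-trips; decode_prompt is hypothetical and not part of this commit:

    import lzma

    import numpy as np
    import pybase16384 as b14
    import torch


    def decode_prompt(encoded: str) -> torch.Tensor:  # hypothetical inverse of encode_prompt
        raw = b14.decode_from_string(encoded)
        # the first two little-endian uint16 values carry the 2-D shape
        shp = np.frombuffer(raw[:4], dtype="<u2")
        arr = np.frombuffer(
            lzma.decompress(
                raw[4:],
                format=lzma.FORMAT_RAW,
                filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
            ),
            dtype="<u2",
        ).copy().reshape(int(shp[0]), int(shp[1]))
        # widen to int32, since torch's uint16 support varies across builds
        return torch.from_numpy(arr.astype(np.int32))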
tool/np.py CHANGED
@@ -1,11 +1,28 @@
 import math
 
-import numpy as np
 from numba import jit
+import wave
+from io import BytesIO
+
+import numpy as np
+from .av import wav2
+
 
 
-@jit
 def float_to_int16(audio: np.ndarray) -> np.ndarray:
     am = int(math.ceil(float(np.abs(audio).max())) * 32768)
     am = 32767 * 32768 // am
     return np.multiply(audio, am).astype(np.int16)
+
+def pcm_arr_to_mp3_view(wav: np.ndarray):
+    buf = BytesIO()
+    with wave.open(buf, "wb") as wf:
+        wf.setnchannels(1)  # Mono channel
+        wf.setsampwidth(2)  # Sample width in bytes
+        wf.setframerate(24000)  # Sample rate in Hz
+        wf.writeframes(float_to_int16(wav))
+    buf.seek(0, 0)
+    buf2 = BytesIO()
+    wav2(buf, buf2, "mp3")
+    buf.seek(0, 0)
+    return buf2.getbuffer()
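
The moved helper keeps the old tool/pcm.py behavior: 16-bit mono WAV at 24 kHz, transcoded to MP3 via tool.av.wav2. A minimal usage sketch, with a synthetic tone standing in for real model output:

    import numpy as np
    from tool.np import pcm_arr_to_mp3_view

    # a 440 Hz test tone stands in for one element of Chat.infer()'s output
    wav = 0.5 * np.sin(2 * np.pi * 440.0 * np.arange(24000) / 24000).astype(np.float32)
    with open("out.mp3", "wb") as f:
        f.write(pcm_arr_to_mp3_view(wav))

Note that float_to_int16 normalizes by the clip's peak amplitude, so an all-zero buffer would raise a ZeroDivisionError.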
tool/pcm.py DELETED
@@ -1,21 +0,0 @@
-import wave
-from io import BytesIO
-
-import numpy as np
-
-from .np import float_to_int16
-from .av import wav2
-
-
-def pcm_arr_to_mp3_view(wav: np.ndarray):
-    buf = BytesIO()
-    with wave.open(buf, "wb") as wf:
-        wf.setnchannels(1)  # Mono channel
-        wf.setsampwidth(2)  # Sample width in bytes
-        wf.setframerate(24000)  # Sample rate in Hz
-        wf.writeframes(float_to_int16(wav))
-    buf.seek(0, 0)
-    buf2 = BytesIO()
-    wav2(buf, buf2, "mp3")
-    buf.seek(0, 0)
-    return buf2.getbuffer()
web/app_cpu.py CHANGED
@@ -45,7 +45,7 @@ def init_chat(args):
 
 def main(args):
     with gr.Blocks() as demo:
-        gr.Markdown("# ChatTTS demo")
+        gr.Markdown("# ChatTTS demo (running in CPU mode)")
         with gr.Row():
             with gr.Column(scale=1):
                 text_input = gr.Textbox(
web/app_gpu.py CHANGED
@@ -48,7 +48,7 @@ def init_chat(args):
 
 def main(args):
     with gr.Blocks() as demo:
-        gr.Markdown("# ChatTTS demo")
+        gr.Markdown("# ChatTTS demo (running in GPU mode)")
         with gr.Row():
             with gr.Column(scale=1):
                 text_input = gr.Textbox(
@@ -71,6 +71,12 @@ def main(args):
                     interactive=True,
                     value=True
                 )
+
+                use_decoder_checkBox = gr.Checkbox(
+                    label="use the decoder model (otherwise the dvae model)",
+                    interactive=True,
+                    value=True
+                )
                 temperature_slider = gr.Slider(
                     minimum=0.00001,
                     maximum=1.0,
@@ -79,22 +85,23 @@ def main(args):
                     interactive=True,
                     label="model Temperature setting"
                 )
-                top_p_slider = gr.Slider(
-                    minimum=0.1,
-                    maximum=0.9,
-                    step=0.05,
-                    value=0.7,
-                    label="model top_P setting",
-                    interactive=True,
-                )
-                top_k_slider = gr.Slider(
-                    minimum=1,
-                    maximum=20,
-                    step=1,
-                    value=20,
-                    label="model top_K setting",
-                    interactive=True,
-                )
+                with gr.Column():
+                    top_p_slider = gr.Slider(
+                        minimum=0.1,
+                        maximum=0.9,
+                        step=0.05,
+                        value=0.7,
+                        label="model top_P setting",
+                        interactive=True,
+                    )
+                    top_k_slider = gr.Slider(
+                        minimum=1,
+                        maximum=20,
+                        step=1,
+                        value=20,
+                        label="model top_K setting",
+                        interactive=True,
+                    )
         with gr.Row():
             lang_selection = gr.Dropdown(
                 label="language",
@@ -139,7 +146,7 @@ def main(args):
         # )
 
         with gr.Row():
-            reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
+            # reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
             generate_button = gr.Button("generate audio file", scale=1, interactive=True)
 
         with gr.Row():
@@ -177,11 +184,12 @@ def main(args):
                 text_seed_input,
                 refine_text_checkBox,
                 refine_audio_checkBox,
+                use_decoder_checkBox,
                 temperature_slider,
                 top_p_slider,
                 top_k_slider,
                 audio_seed_input,
-                lang_selection
+                lang_selection
             ],
             outputs=[text_output,audio_output])
         # initialize the spk_emb_text value
@@ -212,6 +220,7 @@ def general_chat_infer_audio(text,
                              text_seed_input,
                              refine_text_checkBox,
                              refine_audio_checkBox,
+                             use_decoder_checkBox,
                              temperature_slider,
                              top_p_slider,
                              top_k_slider,
@@ -239,7 +248,8 @@ def general_chat_infer_audio(text,
                          skip_refine_text=False,
                          refine_text_only=True,  # return only the refined text
                          params_refine_text=params_refine_text,
-                         lang=lang
+                         lang=lang,
+                         use_decoder=use_decoder_checkBox
                          )
 
 
@@ -265,6 +275,7 @@ def general_chat_infer_audio(text,
                          skip_refine_text=True,  # skip text refinement
                          params_refine_text=params_refine_text,
                          params_infer_code=params_infer_code,
+                         use_decoder=use_decoder_checkBox
                          )
 
     #yield 24000, float_to_int16(wav[0]).T