srinivasbilla committed on
Commit dbff21d · verified
1 Parent(s): 91ae083

Create app.py

Files changed (1)
app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import spaces
import gradio as gr
import tempfile

llasa_3b = 'srinivasbilla/llasa-3b'

tokenizer = AutoTokenizer.from_pretrained(llasa_3b)

model = AutoModelForCausalLM.from_pretrained(
    llasa_3b,
    trust_remote_code=True,
    use_cache=False,
    torch_dtype=torch.bfloat16,
    device_map='cuda',
    return_dict=True
)

model_path = "srinivasbilla/xcodec2"

Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().cuda()

whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device='cuda',
)
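
# Three models work together here: llasa-3b predicts discrete speech tokens
# from text, XCodec2 maps between waveforms and those tokens, and Whisper
# transcribes the reference clip so its text can prefix the prompt.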


def ids_to_speech_tokens(speech_ids):
    """Convert integer codec ids to '<|s_{id}|>' token strings."""
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str


def extract_speech_ids(speech_tokens_str):
    """Convert '<|s_{id}|>' token strings back to integer codec ids."""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            speech_ids.append(int(num_str))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids
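
# The two helpers above are inverses of each other, e.g.:
#   ids_to_speech_tokens([12, 345])               -> ['<|s_12|>', '<|s_345|>']
#   extract_speech_ids(['<|s_12|>', '<|s_345|>']) -> [12, 345]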

@spaces.GPU(duration=120)
def infer(sample_audio_path, target_text):
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        waveform, sample_rate = torchaudio.load(sample_audio_path)

        # Downmix stereo reference audio to mono by averaging the channels
        if waveform.size(0) > 1:
            waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
        else:
            waveform_mono = waveform

        # Resample to 16 kHz, the rate both Whisper and XCodec2 expect
        prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
        # Transcribe the reference clip; its text is prepended to the target text
        prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()

        input_text = prompt_text + ' ' + target_text

        # TTS start!
        with torch.no_grad():
            # Encode the prompt wav into discrete codec ids
            vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
            vq_code_prompt = vq_code_prompt[0, 0, :]

            # Convert int 12345 to token <|s_12345|>
            speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)

            formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

            # Tokenize the text and the speech prefix; continue_final_message=True
            # lets generation continue the assistant turn instead of opening a new one
            chat = [
                {"role": "user", "content": "Convert the text to speech:" + formatted_text},
                {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
            ]

            input_ids = tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                return_tensors='pt',
                continue_final_message=True
            )
            input_ids = input_ids.to('cuda')
            speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

            # Generate the speech tokens autoregressively
            outputs = model.generate(
                input_ids,
                max_length=2048,  # The model was trained with a max length of 2048
                eos_token_id=speech_end_id,
                do_sample=True,
                top_p=1,
                temperature=0.8
            )

            # Extract the speech tokens (prompt prefix included, trailing EOS dropped)
            generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
            speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            # Convert token <|s_23456|> back to int 23456
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

            # Decode the speech tokens to a speech waveform
            gen_wav = Codec_model.decode_code(speech_tokens)

            # Keep only the newly generated part, dropping the reference prompt audio
            gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]

        # Write the waveform to the temp file and return its path so that
        # gr.Audio can play it (a bare numpy array is not a valid output)
        sf.write(f.name, gen_wav[0, 0, :].cpu().numpy(), 16000)
        return f.name
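
# A minimal sketch of calling infer() directly, outside the Gradio UI
# (the reference path 'ref.wav' and the text are illustrative only):
#   out_path = infer('ref.wav', 'Hello from llasa!')
#   # out_path is a 16 kHz mono WAV voiced like the reference clip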

with gr.Blocks() as app_tts:
    gr.Markdown("# Zero-Shot Voice Clone TTS")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)

    generate_btn = gr.Button("Synthesize", variant="primary")

    audio_output = gr.Audio(label="Synthesized Audio")

    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            gen_text_input,
        ],
        outputs=[audio_output],
    )

with gr.Blocks() as app:
    gr.Markdown(
        """
        # llasa 3b TTS

        This is a local web UI for llasa 3b zero-shot voice cloning and TTS.

        The checkpoints support English and Chinese.

        If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15 seconds, and shortening your prompt.
        """
    )
    gr.TabbedInterface([app_tts], ["TTS"])


app.launch()
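
# To run locally: python app.py (app.launch(share=True) would additionally
# create a temporary public Gradio link).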