sagegu committed
Commit 135bd1e · 1 Parent(s): 0f5fe38
Files changed (3)
  1. app.py +132 -13
  2. requirements.txt +4 -3
  3. streamer.py +137 -0
app.py CHANGED
@@ -1,18 +1,137 @@
  import gradio as gr
- from transformers import pipeline

- pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")

- def predict(input_img):
-     predictions = pipeline(input_img)
-     return input_img, {p["label"]: p["score"] for p in predictions}

- gradio_app = gr.Interface(
-     predict,
-     inputs=gr.Image(label="Select hot dog candidate", sources=['upload', 'webcam'], type="pil"),
-     outputs=[gr.Image(label="Processed Image"), gr.Label(label="Result", num_top_classes=2)],
-     title="Hot Dog? Or Not?",
- )

- if __name__ == "__main__":
-     gradio_app.launch()

+ import io
+ from threading import Thread
+ import random
+ import os
+
+ import numpy as np
+ import spaces
  import gradio as gr
+ import torch
+
+ from parler_tts import ParlerTTSForConditionalGeneration
+ from pydub import AudioSegment
+ from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
+ from huggingface_hub import InferenceClient
+ from streamer import ParlerTTSStreamer
+ import time
+
+ device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ torch_dtype = torch.float16 if device != "cpu" else torch.float32
+
+ repo_id = "parler-tts/parler_tts_mini_v0.1"
+
+ jenny_repo_id = "ylacombe/parler-tts-mini-jenny-30H"
+
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
+     jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+ ).to(device)
+
+ client = InferenceClient(token=os.getenv("HF_TOKEN"))
+
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
+ feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
+
+ SAMPLE_RATE = feature_extractor.sampling_rate
+ SEED = 42
+
+
+ def numpy_to_mp3(audio_array, sampling_rate):
+     # Normalize audio_array if it's floating-point
+     if np.issubdtype(audio_array.dtype, np.floating):
+         max_val = np.max(np.abs(audio_array))
+         audio_array = (audio_array / max_val) * 32767  # Normalize to 16-bit range
+         audio_array = audio_array.astype(np.int16)
+
+     # Create an audio segment from the numpy array
+     audio_segment = AudioSegment(
+         audio_array.tobytes(),
+         frame_rate=sampling_rate,
+         sample_width=audio_array.dtype.itemsize,
+         channels=1
+     )
+
+     # Export the audio segment to MP3 bytes - use a high bitrate to maximise quality
+     mp3_io = io.BytesIO()
+     audio_segment.export(mp3_io, format="mp3", bitrate="320k")
+
+     # Get the MP3 bytes
+     mp3_bytes = mp3_io.getvalue()
+     mp3_io.close()
+
+     return mp3_bytes
+
+
+ sampling_rate = model.audio_encoder.config.sampling_rate
+ frame_rate = model.audio_encoder.config.frame_rate
+
+
+ def generate_response(audio):
+     gr.Info("Transcribing Audio", duration=5)
+     question = client.automatic_speech_recognition(audio).text
+     messages = [{"role": "system", "content": ("You are a magic 8 ball. "
+                                                "Someone will present you with a situation or question, and your job "
+                                                "is to answer with a cryptic adage or proverb such as "
+                                                "'Curiosity killed the cat' or 'The early bird gets the worm'. "
+                                                "Keep your answers short and do not include the phrase 'Magic 8 Ball' in your response. "
+                                                "If the question does not make sense or is off-topic, say 'Foolish questions get foolish answers.' "
+                                                "For example: 'Magic 8 Ball, should I get a dog?' -> 'A dog is ready for you, but are you ready for the dog?'")},
+                 {"role": "user", "content": f"Magic 8 Ball please answer this question - {question}"}]
+
+     response = client.chat_completion(messages, max_tokens=64, seed=random.randint(1, 5000),
+                                       model="mistralai/Mistral-7B-Instruct-v0.3")
+     response = response.choices[0].message.content.replace("Magic 8 Ball", "")
+     return response, None, None
+
+
+ @spaces.GPU
+ def read_response(answer):
+     play_steps_in_s = 2.0
+     play_steps = int(frame_rate * play_steps_in_s)
+
+     description = "Jenny speaks at an average pace with a calm delivery in a very confined sounding environment with clear audio quality."
+     description_tokens = tokenizer(description, return_tensors="pt").to(device)
+
+     streamer = ParlerTTSStreamer(model, device=device, play_steps=play_steps)
+     prompt = tokenizer(answer, return_tensors="pt").to(device)
+
+     generation_kwargs = dict(
+         input_ids=description_tokens.input_ids,
+         prompt_input_ids=prompt.input_ids,
+         streamer=streamer,
+         do_sample=True,
+         temperature=1.0,
+         min_new_tokens=10,
+     )
+
+     set_seed(SEED)
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     start = time.time()
+     for new_audio in streamer:
+         print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds after {time.time() - start} seconds")
+         yield answer, numpy_to_mp3(new_audio, sampling_rate=sampling_rate)
+
+
+ with gr.Blocks() as block:
+     gr.HTML(
+         f"""
+         <h1 style='text-align: center;'> Magic 8 Ball 🎱 </h1>
+         <h3 style='text-align: center;'> Ask a question and receive wisdom </h3>
+         <p style='text-align: center;'> Powered by <a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a></p>
+         """
+     )
+     with gr.Group():
+         with gr.Row():
+             audio_out = gr.Audio(label="Spoken Answer", streaming=True, autoplay=True, loop=False)
+             answer = gr.Textbox(label="Answer")
+             state = gr.State()
+         with gr.Row():
+             audio_in = gr.Audio(label="Speak your question", sources="microphone", type="filepath")
+         with gr.Row():
+             gr.HTML("""<h3 style='text-align: center;'> Examples: 'What is the meaning of life?', 'Should I get a dog?' </h3>""")
+     audio_in.stop_recording(generate_response, audio_in, [state, answer, audio_out]).then(
+         fn=read_response, inputs=state, outputs=[answer, audio_out]
+     )
+
+ block.launch()
 
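Note: the `numpy_to_mp3` helper added above is easy to sanity-check outside the Space. Below is a minimal sketch of the same float-to-int16-to-MP3 round trip, assuming `numpy`, `pydub`, and an `ffmpeg` binary are available locally; the 440 Hz sine is a made-up test signal, not part of this commit.

import io

import numpy as np
from pydub import AudioSegment

# Made-up test signal: one second of a 440 Hz sine at 44.1 kHz, float32 in [-1, 1]
sr = 44100
t = np.linspace(0, 1, sr, endpoint=False)
wave = 0.5 * np.sin(2 * np.pi * 440 * t).astype(np.float32)

# Same recipe as numpy_to_mp3 above: rescale floats to the int16 range,
# then hand pydub the raw PCM bytes for MP3 encoding
pcm = (wave / np.max(np.abs(wave)) * 32767).astype(np.int16)
segment = AudioSegment(pcm.tobytes(), frame_rate=sr, sample_width=pcm.dtype.itemsize, channels=1)

buf = io.BytesIO()
segment.export(buf, format="mp3", bitrate="320k")
print(f"{len(buf.getvalue())} MP3 bytes")  # writing buf.getvalue() to a file should give a clean 440 Hz tone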
requirements.txt CHANGED
@@ -1,3 +1,4 @@
- huggingface_hub==0.25.2
- transformers
- torch

+ https://gradio-builds.s3.amazonaws.com/bed454c3d22cfacedc047eb3b0ba987b485ac3fd/gradio-4.40.0-py3-none-any.whl
+ git+https://github.com/huggingface/parler-tts.git
+ accelerate
+ nltk
streamer.py ADDED
@@ -0,0 +1,137 @@
+ from queue import Queue
+ from transformers.generation.streamers import BaseStreamer
+ from typing import Optional
+ from parler_tts import ParlerTTSForConditionalGeneration
+ import numpy as np
+ import math
+ import torch
+
+
+ class ParlerTTSStreamer(BaseStreamer):
+     def __init__(
+         self,
+         model: ParlerTTSForConditionalGeneration,
+         device: Optional[str] = None,
+         play_steps: Optional[int] = 10,
+         stride: Optional[int] = None,
+         timeout: Optional[float] = None,
+     ):
+         """
+         Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
+         useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
+         Gradio demo).
+         Parameters:
+             model (`ParlerTTSForConditionalGeneration`):
+                 The Parler-TTS model used to generate the audio waveform.
+             device (`str`, *optional*):
+                 The torch device on which to run the computation. If `None`, will default to the device of the model.
+             play_steps (`int`, *optional*, defaults to 10):
+                 The number of generation steps with which to return the generated audio array. Using fewer steps will
+                 mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
+                 should be tuned to your device and latency requirements.
+             stride (`int`, *optional*):
+                 The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
+                 the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
+                 `play_steps // 6` in the audio space.
+             timeout (`int`, *optional*):
+                 The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
+                 in `.generate()` when it is called in a separate thread.
+         """
+         self.decoder = model.decoder
+         self.audio_encoder = model.audio_encoder
+         self.generation_config = model.generation_config
+         self.device = device if device is not None else model.device
+
+         # variables used in the streaming process
+         self.play_steps = play_steps
+         if stride is not None:
+             self.stride = stride
+         else:
+             hop_length = math.floor(self.audio_encoder.config.sampling_rate / self.audio_encoder.config.frame_rate)
+             self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
+         self.token_cache = None
+         self.to_yield = 0
+
+         # variables used in the thread process
+         self.audio_queue = Queue()
+         self.stop_signal = None
+         self.timeout = timeout
+
+     def apply_delay_pattern_mask(self, input_ids):
+         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Parler)
+         _, delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+             input_ids[:, :1],
+             bos_token_id=self.generation_config.bos_token_id,
+             pad_token_id=self.generation_config.decoder_start_token_id,
+             max_length=input_ids.shape[-1],
+         )
+         # apply the pattern mask to the input ids
+         input_ids = self.decoder.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
+
+         # revert the pattern delay mask by filtering the pad token id
+         mask = (delay_pattern_mask != self.generation_config.bos_token_id) & (delay_pattern_mask != self.generation_config.pad_token_id)
+         input_ids = input_ids[mask].reshape(1, self.decoder.num_codebooks, -1)
+         # append the frame dimension back to the audio codes
+         input_ids = input_ids[None, ...]
+
+         # send the input_ids to the correct device
+         input_ids = input_ids.to(self.audio_encoder.device)
+
+         decode_sequentially = (
+             self.generation_config.bos_token_id in input_ids
+             or self.generation_config.pad_token_id in input_ids
+             or self.generation_config.eos_token_id in input_ids
+         )
+         if not decode_sequentially:
+             output_values = self.audio_encoder.decode(
+                 input_ids,
+                 audio_scales=[None],
+             )
+         else:
+             sample = input_ids[:, 0]
+             sample_mask = (sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0
+             sample = sample[:, :, sample_mask]
+             output_values = self.audio_encoder.decode(sample[None, ...], [None])
+
+         audio_values = output_values.audio_values[0, 0]
+         return audio_values.cpu().float().numpy()
+
+     def put(self, value):
+         batch_size = value.shape[0] // self.decoder.num_codebooks
+         if batch_size > 1:
+             raise ValueError("ParlerTTSStreamer only supports batch size 1")
+
+         if self.token_cache is None:
+             self.token_cache = value
+         else:
+             self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
+
+         if self.token_cache.shape[-1] % self.play_steps == 0:
+             audio_values = self.apply_delay_pattern_mask(self.token_cache)
+             self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
+             self.to_yield += len(audio_values) - self.to_yield - self.stride
+
+     def end(self):
+         """Flushes any remaining cache and appends the stop symbol."""
+         if self.token_cache is not None:
+             audio_values = self.apply_delay_pattern_mask(self.token_cache)
+         else:
+             audio_values = np.zeros(self.to_yield)
+
+         self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)
+
+     def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
+         """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
+         self.audio_queue.put(audio, timeout=self.timeout)
+         if stream_end:
+             self.audio_queue.put(self.stop_signal, timeout=self.timeout)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.audio_queue.get(timeout=self.timeout)
+         if not isinstance(value, np.ndarray) and value == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value
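For reference, here is the consumer side of the new streamer as `app.py` uses it: `model.generate` runs in a background thread and pushes decoded chunks into the queue via `put()`/`end()`, while the caller drains them by iterating. This is a minimal sketch assuming the `model`, `tokenizer`, `device`, and `frame_rate` objects set up in `app.py` above; the two prompt strings are illustrative only.

from threading import Thread

# ~2 seconds of audio per chunk, as in read_response() in app.py
streamer = ParlerTTSStreamer(model, device=device, play_steps=int(frame_rate * 2.0))
description = tokenizer("Jenny speaks at an average pace.", return_tensors="pt").to(device)
prompt = tokenizer("The early bird gets the worm.", return_tensors="pt").to(device)

# Producer: generation fills the streamer's queue from its own thread
thread = Thread(target=model.generate, kwargs=dict(
    input_ids=description.input_ids,
    prompt_input_ids=prompt.input_ids,
    streamer=streamer,
    do_sample=True,
))
thread.start()

# Consumer: __next__ blocks until roughly play_steps worth of new audio has been decoded
for chunk in streamer:
    print(f"received {chunk.shape[0]} samples")
thread.join()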