sagar007 committed on
Commit
b8c63a2
·
verified ·
1 Parent(s): a6e4f9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +375 -153
app.py CHANGED
@@ -9,43 +9,42 @@ import os
9
  import subprocess
10
  import numpy as np
11
  from typing import List, Dict, Tuple, Any
12
-
13
- # Install required dependencies for Kokoro with better error handling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  try:
15
- subprocess.run(['git', 'lfs', 'install'], check=True)
16
- if not os.path.exists('Kokoro-82M'):
17
- subprocess.run(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'], check=True)
18
-
19
- # Try installing espeak with proper package manager commands
20
- try:
21
- subprocess.run(['apt-get', 'update'], check=True)
22
- subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
23
- except subprocess.CalledProcessError:
24
- print("Warning: Could not install espeak. Attempting espeak-ng...")
25
- try:
26
- subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
27
- except subprocess.CalledProcessError:
28
- print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")
29
-
30
  except Exception as e:
31
- print(f"Warning: Initial setup error: {str(e)}")
32
- print("Continuing with limited functionality...")
33
-
34
- # --- Initialization (Do this ONCE) ---
35
-
36
- model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
37
- tokenizer = AutoTokenizer.from_pretrained(model_name)
38
- tokenizer.pad_token = tokenizer.eos_token
39
- # Initialize DeepSeek model
40
- model = AutoModelForCausalLM.from_pretrained(
41
- model_name,
42
- device_map="auto",
43
- offload_folder="offload",
44
- low_cpu_mem_usage=True,
45
- torch_dtype=torch.float16
46
- )
47
-
48
- # Initialize Kokoro TTS (with error handling)
49
  VOICE_CHOICES = {
50
  '🇺🇸 Female (Default)': 'af',
51
  '🇺🇸 Bella': 'af_bella',
@@ -54,41 +53,81 @@ VOICE_CHOICES = {
54
  }
55
  TTS_ENABLED = False
56
  TTS_MODEL = None
57
- VOICEPACK = None
58
-
59
- try:
60
- if os.path.exists('Kokoro-82M'):
61
- import sys
62
- sys.path.append('Kokoro-82M')
63
- from models import build_model # type: ignore
64
- from kokoro import generate # type: ignore
65
-
66
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
67
- TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
68
 
69
- # Load default voice
 
 
 
 
 
 
 
 
 
 
70
  try:
71
- VOICEPACK = torch.load('Kokoro-82M/voices/af.pt', map_location=device, weights_only=True)
72
- except Exception as e:
73
- print(f"Warning: Could not load default voice: {e}")
74
- raise
75
-
76
- TTS_ENABLED = True
77
- else:
78
- print("Warning: Kokoro-82M directory not found. TTS disabled.")
79
- except Exception as e:
80
- print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
81
- TTS_ENABLED = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- def get_web_results(query: str, max_results: int = 5) -> List[Dict[str, str]]:
84
- """Get web search results using DuckDuckGo"""
 
 
 
 
 
85
  try:
86
  with DDGS() as ddgs:
87
  results = list(ddgs.text(query, max_results=max_results))
88
  return [{
89
  "title": result.get("title", ""),
90
- "snippet": result["body"],
91
- "url": result["href"],
92
  "date": result.get("published", "")
93
  } for result in results]
94
  except Exception as e:
@@ -116,23 +155,24 @@ def format_sources(web_results: List[Dict[str, str]]) -> str:
116
  sources_html = "<div class='sources-container'>"
117
  for i, res in enumerate(web_results, 1):
118
  title = res["title"] or "Source"
119
- date = f"<span class='source-date'>{res['date']}</span>" if res['date'] else ""
 
120
  sources_html += f"""
121
  <div class='source-item'>
122
  <div class='source-number'>[{i}]</div>
123
  <div class='source-content'>
124
  <a href="{res['url']}" target="_blank" class='source-title'>{title}</a>
125
  {date}
126
- <div class='source-snippet'>{res['snippet'][:150]}...</div>
127
  </div>
128
  </div>
129
  """
130
  sources_html += "</div>"
131
  return sources_html
132
 
133
- @spaces.GPU(duration=30)
134
  def generate_answer(prompt: str) -> str:
135
- """Generate answer using the DeepSeek model"""
136
  inputs = tokenizer(
137
  prompt,
138
  return_tensors="pt",
@@ -142,52 +182,56 @@ def generate_answer(prompt: str) -> str:
142
  return_attention_mask=True
143
  ).to(model.device)
144
 
145
- outputs = model.generate(
146
- inputs.input_ids,
147
- attention_mask=inputs.attention_mask,
148
- max_new_tokens=256,
149
- temperature=0.7,
150
- top_p=0.95,
151
- pad_token_id=tokenizer.eos_token_id,
152
- do_sample=True,
153
- early_stopping=True
154
- )
 
155
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
156
 
157
- @spaces.GPU(duration=30)
158
- def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model=TTS_MODEL, voicepack=VOICEPACK) -> Tuple[int, np.ndarray] | None:
159
- """Generate speech from text using Kokoro TTS model."""
160
- if not TTS_ENABLED or tts_model is None:
161
- print("TTS is not enabled or model is not loaded.")
 
162
  return None
163
 
164
  try:
 
165
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
166
 
167
- # Handle voicepack loading
168
- voice_file = f'Kokoro-82M/voices/{voice_name}.pt'
169
- if voice_name == 'af' and voicepack is not None:
170
- # Use the pre-loaded default voicepack
171
- pass
172
- elif os.path.exists(voice_file):
173
- # Load the selected voicepack if it exists
174
- voicepack = torch.load(voice_file, map_location=device, weights_only=True)
175
- else:
176
- # Fall back to default 'af' if selected voicepack is missing
177
- print(f"Voicepack {voice_name}.pt not found. Falling back to default 'af'.")
178
- voice_file = 'Kokoro-82M/voices/af.pt'
179
- if os.path.exists(voice_file):
180
- voicepack = torch.load(voice_file, map_location=device, weights_only=True)
 
 
181
  else:
182
- print("Default voicepack 'af.pt' not found. Cannot generate audio.")
183
- return None
184
-
185
  # Clean the text
186
  clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
187
  clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
188
 
189
  # Split long text into chunks
190
- max_chars = 1000
191
  chunks = []
192
  if len(clean_text) > max_chars:
193
  sentences = clean_text.split('.')
@@ -207,7 +251,7 @@ def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model=TTS_MO
207
  audio_chunks = []
208
  for chunk in chunks:
209
  if chunk.strip():
210
- chunk_audio, _ = generate(tts_model, chunk, voicepack, lang='a')
211
  if isinstance(chunk_audio, torch.Tensor):
212
  chunk_audio = chunk_audio.cpu().numpy()
213
  audio_chunks.append(chunk_audio)
@@ -215,35 +259,61 @@ def generate_speech_with_gpu(text: str, voice_name: str = 'af', tts_model=TTS_MO
215
  # Concatenate chunks
216
  if audio_chunks:
217
  final_audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
218
- return (24000, final_audio)
219
- else:
220
- return None
221
 
222
  except Exception as e:
223
  print(f"Error generating speech: {str(e)}")
224
  return None
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  def process_query(query: str, history: List[List[str]], selected_voice: str = 'af'):
227
- """Process user query with streaming effect"""
228
  try:
229
  if history is None:
230
  history = []
231
 
232
- # Get web results first
233
- web_results = get_web_results(query)
234
- sources_html = format_sources(web_results)
235
-
236
  current_history = history + [[query, "*Searching...*"]]
237
-
238
  # Yield initial searching state
239
  yield (
240
  "*Searching & Thinking...*", # answer_output (Markdown)
241
- sources_html, # sources_output (HTML)
242
  "Searching...", # search_btn (Button)
243
  current_history, # chat_history_display (Chatbot)
244
  None # audio_output (Audio)
245
  )
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  # Generate answer
248
  prompt = format_prompt(query, web_results)
249
  answer = generate_answer(prompt)
@@ -251,26 +321,27 @@ def process_query(query: str, history: List[List[str]], selected_voice: str = 'a
251
 
252
  # Update history before TTS
253
  updated_history = history + [[query, final_answer]]
 
 
 
 
 
 
 
 
 
254
 
255
- # Generate speech from the answer (only if enabled)
256
- if TTS_ENABLED:
257
- yield (
258
- final_answer, # answer_output
259
- sources_html, # sources_output
260
- "Generating audio...", # search_btn
261
- updated_history, # chat_history_display
262
- None # audio_output
263
- )
264
  try:
265
- audio = generate_speech_with_gpu(final_answer, selected_voice)
266
  if audio is None:
267
  final_answer += "\n\n*Audio generation failed. The voicepack may be missing or incompatible.*"
268
  except Exception as e:
269
  final_answer += f"\n\n*Error generating audio: {str(e)}*"
270
- audio = None
271
  else:
272
- final_answer += "\n\n*TTS is disabled. Audio not available.*"
273
- audio = None
274
 
275
  # Yield final result
276
  yield (
@@ -278,7 +349,7 @@ def process_query(query: str, history: List[List[str]], selected_voice: str = 'a
278
  sources_html, # sources_output
279
  "Search", # search_btn
280
  updated_history, # chat_history_display
281
- audio if audio is not None else None # audio_output
282
  )
283
 
284
  except Exception as e:
@@ -287,13 +358,13 @@ def process_query(query: str, history: List[List[str]], selected_voice: str = 'a
287
  error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
288
  yield (
289
  f"Error: {error_message}", # answer_output
290
- sources_html, # sources_output
291
  "Search", # search_btn
292
  history + [[query, f"*Error: {error_message}*"]], # chat_history_display
293
  None # audio_output
294
  )
295
 
296
- # Update the CSS for better contrast and readability
297
  css = """
298
  .gradio-container {
299
  max-width: 1200px !important;
@@ -303,36 +374,44 @@ css = """
303
  text-align: center;
304
  margin-bottom: 2rem;
305
  padding: 2rem 0;
306
- background: #1a1b1e;
307
  border-radius: 12px;
308
  color: white;
 
309
  }
310
  #header h1 {
311
  color: white;
312
  font-size: 2.5rem;
313
  margin-bottom: 0.5rem;
 
314
  }
315
  #header h3 {
316
  color: #a8a9ab;
317
  }
318
  .search-container {
319
- background: #1a1b1e;
320
  border-radius: 12px;
321
- box-shadow: 0 4px 12px rgba(0,0,0,0.1);
322
- padding: 1rem;
323
- margin-bottom: 1rem;
324
  }
325
  .search-box {
326
  padding: 1rem;
327
  background: #2c2d30;
328
- border-radius: 8px;
329
  margin-bottom: 1rem;
 
330
  }
331
  .search-box input[type="text"] {
332
  background: #3a3b3e !important;
333
  border: 1px solid #4a4b4e !important;
334
  color: white !important;
335
  border-radius: 8px !important;
 
 
 
 
 
336
  }
337
  .search-box input[type="text"]::placeholder {
338
  color: #a8a9ab !important;
@@ -340,23 +419,43 @@ css = """
340
  .search-box button {
341
  background: #2563eb !important;
342
  border: none !important;
 
 
 
 
 
 
 
 
 
343
  }
344
  .results-container {
345
  background: #2c2d30;
346
- border-radius: 8px;
347
- padding: 1rem;
348
- margin-top: 1rem;
 
349
  }
350
  .answer-box {
351
  background: #3a3b3e;
352
- border-radius: 8px;
353
  padding: 1.5rem;
354
  color: white;
355
- margin-bottom: 1rem;
 
 
 
 
 
356
  }
357
  .answer-box p {
358
  color: #e5e7eb;
359
- line-height: 1.6;
 
 
 
 
 
360
  }
361
  .sources-container {
362
  margin-top: 1rem;
@@ -367,13 +466,16 @@ css = """
367
  .source-item {
368
  display: flex;
369
  padding: 12px;
370
- margin: 8px 0;
371
  background: #3a3b3e;
372
  border-radius: 8px;
373
  transition: all 0.2s;
 
374
  }
375
  .source-item:hover {
376
  background: #4a4b4e;
 
 
377
  }
378
  .source-number {
379
  font-weight: bold;
@@ -388,7 +490,12 @@ css = """
388
  font-weight: 500;
389
  text-decoration: none;
390
  display: block;
391
- margin-bottom: 4px;
 
 
 
 
 
392
  }
393
  .source-date {
394
  color: #a8a9ab;
@@ -398,7 +505,7 @@ css = """
398
  .source-snippet {
399
  color: #e5e7eb;
400
  font-size: 0.9em;
401
- line-height: 1.4;
402
  }
403
  .chat-history {
404
  max-height: 400px;
@@ -407,6 +514,18 @@ css = """
407
  background: #2c2d30;
408
  border-radius: 8px;
409
  margin-top: 1rem;
 
 
 
 
 
 
 
 
 
 
 
 
410
  }
411
  .examples-container {
412
  background: #2c2d30;
@@ -418,20 +537,73 @@ css = """
418
  background: #3a3b3e !important;
419
  border: 1px solid #4a4b4e !important;
420
  color: #e5e7eb !important;
 
 
 
 
 
 
421
  }
422
  .markdown-content {
423
  color: #e5e7eb !important;
424
  }
425
  .markdown-content h1, .markdown-content h2, .markdown-content h3 {
426
  color: white !important;
 
 
 
 
 
 
 
 
 
 
 
427
  }
428
  .markdown-content a {
429
  color: #60a5fa !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  }
431
  .accordion {
432
  background: #2c2d30 !important;
433
  border-radius: 8px !important;
434
  margin-top: 1rem !important;
 
435
  }
436
  .voice-selector {
437
  margin-top: 1rem;
@@ -443,10 +615,54 @@ css = """
443
  background: #3a3b3e !important;
444
  color: white !important;
445
  border: 1px solid #4a4b4e !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  }
447
  """
448
 
449
- # Update the Gradio interface layout
450
  with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
451
  chat_history = gr.State([])
452
 
@@ -462,13 +678,14 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
462
  scale=5,
463
  container=False
464
  )
465
- search_btn = gr.Button("Search", variant="primary", scale=1)
466
  voice_select = gr.Dropdown(
467
- choices=list(VOICE_CHOICES.items()),
468
- value='af',
469
- label="Select Voice",
470
- elem_classes="voice-selector"
 
471
  )
 
472
 
473
  with gr.Row(elem_classes="results-container"):
474
  with gr.Column(scale=2):
@@ -486,28 +703,33 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
486
  with gr.Row(elem_classes="examples-container"):
487
  gr.Examples(
488
  examples=[
489
- "musk explores blockchain for doge",
490
- "nvidia to launch new gaming card",
491
  "What are the best practices for sustainable living?",
492
- "tesla mistaken for asteroid"
493
  ],
494
  inputs=search_input,
495
  label="Try these examples"
496
  )
497
 
 
 
 
 
498
  # Handle interactions
499
  search_btn.click(
500
  fn=process_query,
501
- inputs=[search_input, chat_history, voice_select],
502
  outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
503
  )
504
 
505
  # Also trigger search on Enter key
506
  search_input.submit(
507
  fn=process_query,
508
- inputs=[search_input, chat_history, voice_select],
509
  outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
510
  )
511
 
512
  if __name__ == "__main__":
513
- demo.launch(share=True)
 
 
9
  import subprocess
10
  import numpy as np
11
  from typing import List, Dict, Tuple, Any
12
+ from functools import lru_cache
13
+ import asyncio
14
+ import threading
15
+ from concurrent.futures import ThreadPoolExecutor
16
+
17
# --- Configuration ---
# Tunables shared by the search, generation and TTS code in this file.
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
MAX_SEARCH_RESULTS = 5
TTS_SAMPLE_RATE = 24000
MAX_TTS_CHARS = 1000
GPU_DURATION = 30  # for spaces.GPU decorator
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95


# --- Initialization ---
def _load_tokenizer_and_model():
    """Load the tokenizer and causal LM named by ``MODEL_NAME``.

    Returns:
        A ``(tokenizer, model)`` pair. The tokenizer's pad token is set to
        its EOS token before it is returned.

    Raises:
        Exception: whatever the loaders raise; the error is printed first
        and then re-raised unchanged so startup fails loudly.
    """
    try:
        print("Loading tokenizer...")
        loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

        print("Loading model...")
        loaded_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            offload_folder="offload",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        )
        print("Model and tokenizer loaded successfully")
    except Exception as e:
        print(f"Error initializing model: {str(e)}")
        raise
    return loaded_tokenizer, loaded_model


# Loaded once at import time; the rest of the module reads these globals.
tokenizer, model = _load_tokenizer_and_model()
46
+
47
+ # --- TTS Setup ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  VOICE_CHOICES = {
49
  '🇺🇸 Female (Default)': 'af',
50
  '🇺🇸 Bella': 'af_bella',
 
53
  }
54
  TTS_ENABLED = False
55
  TTS_MODEL = None
56
+ VOICEPACKS = {} # Cache voice packs
 
 
 
 
 
 
 
 
 
 
57
 
58
+ # Initialize Kokoro TTS in a separate thread to avoid blocking startup
59
def setup_tts():
    """Install Kokoro's prerequisites and load the TTS model and voice packs.

    Updates the module globals ``TTS_MODEL``, ``VOICEPACKS`` and
    ``TTS_ENABLED``. ``TTS_ENABLED`` is set to True only after the model and
    the default ``'af'`` voice have loaded. Any failure is reported with a
    printed warning instead of being raised, and leaves TTS disabled.

    The git / apt-get steps are best effort: if they fail but a usable
    ``Kokoro-82M`` directory is already present, the model is still loaded.
    """
    global TTS_ENABLED, TTS_MODEL, VOICEPACKS

    # A missing executable raises FileNotFoundError (an OSError), not
    # CalledProcessError, so both mean "this tool step failed".
    tool_errors = (subprocess.CalledProcessError, OSError)

    try:
        # Fetch the model repository. A failure here must not abort setup:
        # the directory check below decides whether TTS can be enabled.
        try:
            subprocess.run(['git', 'lfs', 'install'], check=True)
            if not os.path.exists('Kokoro-82M'):
                subprocess.run(
                    ['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'],
                    check=True,
                )
        except tool_errors as e:
            print(f"Warning: Could not fetch Kokoro-82M: {str(e)}")

        # Install espeak, falling back to espeak-ng if that fails.
        try:
            subprocess.run(['apt-get', 'update'], check=True)
            subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
        except tool_errors:
            try:
                subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
            except tool_errors:
                print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")

        # Set up Kokoro TTS
        if os.path.exists('Kokoro-82M'):
            import sys

            # The repository ships its own `models` and `kokoro` modules.
            if 'Kokoro-82M' not in sys.path:
                sys.path.append('Kokoro-82M')
            from models import build_model
            from kokoro import generate

            # Make these functions accessible globally
            globals()['build_model'] = build_model
            globals()['generate_tts'] = generate

            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)

            # The default voice is mandatory: a failure here propagates to
            # the outer handler and disables TTS.
            default_voice = 'af'
            VOICEPACKS[default_voice] = torch.load(
                f'Kokoro-82M/voices/{default_voice}.pt',
                map_location=device,
                weights_only=True,
            )

            # Other voices are optional; preload them to reduce latency and
            # only warn when one cannot be read.
            for voice_name in ['af_bella', 'af_sarah', 'af_nicole']:
                try:
                    voice_path = f'Kokoro-82M/voices/{voice_name}.pt'
                    if os.path.exists(voice_path):
                        VOICEPACKS[voice_name] = torch.load(
                            voice_path,
                            map_location=device,
                            weights_only=True,
                        )
                except Exception as e:
                    print(f"Warning: Could not preload voice {voice_name}: {str(e)}")

            # Set last so readers of the flag never see a half-loaded model.
            TTS_ENABLED = True
            print("TTS setup completed successfully")
        else:
            print("Warning: Kokoro-82M directory not found. TTS disabled.")
    except Exception as e:
        print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
        TTS_ENABLED = False
116
 
117
+ # Start TTS setup in a separate thread
118
+ threading.Thread(target=setup_tts, daemon=True).start()
119
+
120
+ # --- Search and Generation Functions ---
121
+ @lru_cache(maxsize=128)
122
+ def get_web_results(query: str, max_results: int = MAX_SEARCH_RESULTS) -> List[Dict[str, str]]:
123
+ """Get web search results using DuckDuckGo with caching for improved performance"""
124
  try:
125
  with DDGS() as ddgs:
126
  results = list(ddgs.text(query, max_results=max_results))
127
  return [{
128
  "title": result.get("title", ""),
129
+ "snippet": result.get("body", ""),
130
+ "url": result.get("href", ""),
131
  "date": result.get("published", "")
132
  } for result in results]
133
  except Exception as e:
 
155
  sources_html = "<div class='sources-container'>"
156
  for i, res in enumerate(web_results, 1):
157
  title = res["title"] or "Source"
158
+ date = f"<span class='source-date'>{res['date']}</span>" if res.get('date') else ""
159
+ snippet = res.get("snippet", "")[:150] + "..." if res.get("snippet") else ""
160
  sources_html += f"""
161
  <div class='source-item'>
162
  <div class='source-number'>[{i}]</div>
163
  <div class='source-content'>
164
  <a href="{res['url']}" target="_blank" class='source-title'>{title}</a>
165
  {date}
166
+ <div class='source-snippet'>{snippet}</div>
167
  </div>
168
  </div>
169
  """
170
  sources_html += "</div>"
171
  return sources_html
172
 
173
+ @spaces.GPU(duration=GPU_DURATION)
174
  def generate_answer(prompt: str) -> str:
175
+ """Generate answer using the DeepSeek model with optimized settings"""
176
  inputs = tokenizer(
177
  prompt,
178
  return_tensors="pt",
 
182
  return_attention_mask=True
183
  ).to(model.device)
184
 
185
+ with torch.no_grad(): # Disable gradient calculation for inference
186
+ outputs = model.generate(
187
+ inputs.input_ids,
188
+ attention_mask=inputs.attention_mask,
189
+ max_new_tokens=MAX_NEW_TOKENS,
190
+ temperature=TEMPERATURE,
191
+ top_p=TOP_P,
192
+ pad_token_id=tokenizer.eos_token_id,
193
+ do_sample=True,
194
+ early_stopping=True
195
+ )
196
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
197
 
198
+ @spaces.GPU(duration=GPU_DURATION)
199
+ def generate_speech(text: str, voice_name: str = 'af') -> Tuple[int, np.ndarray] | None:
200
+ """Generate speech from text using Kokoro TTS model with improved error handling and caching."""
201
+ global VOICEPACKS, TTS_MODEL, TTS_ENABLED
202
+
203
+ if not TTS_ENABLED or TTS_MODEL is None:
204
  return None
205
 
206
  try:
207
+ from kokoro import generate as generate_tts
208
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
209
 
210
+ # Load voicepack if needed
211
+ if voice_name not in VOICEPACKS:
212
+ voice_file = f'Kokoro-82M/voices/{voice_name}.pt'
213
+
214
+ if not os.path.exists(voice_file):
215
+ print(f"Voicepack {voice_name}.pt not found. Falling back to default 'af'.")
216
+ voice_name = 'af'
217
+
218
+ # Check if default is already loaded
219
+ if voice_name not in VOICEPACKS:
220
+ voice_file = f'Kokoro-82M/voices/{voice_name}.pt'
221
+ if os.path.exists(voice_file):
222
+ VOICEPACKS[voice_name] = torch.load(voice_file, map_location=device, weights_only=True)
223
+ else:
224
+ print("Default voicepack 'af.pt' not found. Cannot generate audio.")
225
+ return None
226
  else:
227
+ VOICEPACKS[voice_name] = torch.load(voice_file, map_location=device, weights_only=True)
228
+
 
229
  # Clean the text
230
  clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
231
  clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
232
 
233
  # Split long text into chunks
234
+ max_chars = MAX_TTS_CHARS
235
  chunks = []
236
  if len(clean_text) > max_chars:
237
  sentences = clean_text.split('.')
 
251
  audio_chunks = []
252
  for chunk in chunks:
253
  if chunk.strip():
254
+ chunk_audio, _ = generate_tts(TTS_MODEL, chunk, VOICEPACKS[voice_name], lang='a')
255
  if isinstance(chunk_audio, torch.Tensor):
256
  chunk_audio = chunk_audio.cpu().numpy()
257
  audio_chunks.append(chunk_audio)
 
259
  # Concatenate chunks
260
  if audio_chunks:
261
  final_audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
262
+ return (TTS_SAMPLE_RATE, final_audio)
263
+
264
+ return None
265
 
266
  except Exception as e:
267
  print(f"Error generating speech: {str(e)}")
268
  return None
269
 
270
+ # --- Asynchronous Processing ---
271
async def async_web_search(query: str) -> List[Dict[str, str]]:
    """Run ``get_web_results(query)`` off the event loop and await its result.

    Args:
        query: Search string forwarded unchanged to ``get_web_results``.

    Returns:
        Whatever ``get_web_results`` returns for ``query``.
    """
    # get_running_loop() is the preferred call inside a coroutine; it always
    # returns the loop that is currently executing this coroutine.
    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default executor.
    return await loop.run_in_executor(None, get_web_results, query)
275
+
276
async def async_answer_generation(prompt: str) -> str:
    """Run ``generate_answer(prompt)`` off the event loop and await its result.

    Args:
        prompt: Fully formatted prompt forwarded unchanged to
            ``generate_answer``.

    Returns:
        The text returned by ``generate_answer``.
    """
    # get_running_loop() is the preferred call inside a coroutine; it always
    # returns the loop that is currently executing this coroutine.
    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default executor.
    return await loop.run_in_executor(None, generate_answer, prompt)
280
+
281
async def async_speech_generation(text: str, voice_name: str) -> Tuple[int, np.ndarray] | None:
    """Run ``generate_speech(text, voice_name)`` off the event loop.

    Args:
        text: Text forwarded unchanged to ``generate_speech``.
        voice_name: Voice identifier forwarded unchanged to
            ``generate_speech``.

    Returns:
        Whatever ``generate_speech`` returns: a ``(sample_rate, audio)``
        tuple, or None when no audio was produced.
    """
    # get_running_loop() is the preferred call inside a coroutine; it always
    # returns the loop that is currently executing this coroutine.
    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default executor.
    return await loop.run_in_executor(None, generate_speech, text, voice_name)
285
+
286
  def process_query(query: str, history: List[List[str]], selected_voice: str = 'af'):
287
+ """Process user query with streaming effect and non-blocking operations"""
288
  try:
289
  if history is None:
290
  history = []
291
 
292
+ # Start the search task
 
 
 
293
  current_history = history + [[query, "*Searching...*"]]
294
+
295
  # Yield initial searching state
296
  yield (
297
  "*Searching & Thinking...*", # answer_output (Markdown)
298
+ "<div class='searching'>Searching for results...</div>", # sources_output (HTML)
299
  "Searching...", # search_btn (Button)
300
  current_history, # chat_history_display (Chatbot)
301
  None # audio_output (Audio)
302
  )
303
 
304
+ # Get web results
305
+ web_results = get_web_results(query)
306
+ sources_html = format_sources(web_results)
307
+
308
+ # Update with the search results obtained
309
+ yield (
310
+ "*Analyzing search results...*", # answer_output
311
+ sources_html, # sources_output
312
+ "Generating answer...", # search_btn
313
+ current_history, # chat_history_display
314
+ None # audio_output
315
+ )
316
+
317
  # Generate answer
318
  prompt = format_prompt(query, web_results)
319
  answer = generate_answer(prompt)
 
321
 
322
  # Update history before TTS
323
  updated_history = history + [[query, final_answer]]
324
+
325
+ # Update with the answer before generating speech
326
+ yield (
327
+ final_answer, # answer_output
328
+ sources_html, # sources_output
329
+ "Generating audio...", # search_btn
330
+ updated_history, # chat_history_display
331
+ None # audio_output
332
+ )
333
 
334
+ # Generate speech (but don't block if TTS is still initializing)
335
+ audio = None
336
+ if TTS_ENABLED and TTS_MODEL is not None:
 
 
 
 
 
 
337
  try:
338
+ audio = generate_speech(final_answer, selected_voice)
339
  if audio is None:
340
  final_answer += "\n\n*Audio generation failed. The voicepack may be missing or incompatible.*"
341
  except Exception as e:
342
  final_answer += f"\n\n*Error generating audio: {str(e)}*"
 
343
  else:
344
+ final_answer += "\n\n*TTS is still initializing or is disabled. Try again in a moment.*"
 
345
 
346
  # Yield final result
347
  yield (
 
349
  sources_html, # sources_output
350
  "Search", # search_btn
351
  updated_history, # chat_history_display
352
+ audio # audio_output
353
  )
354
 
355
  except Exception as e:
 
358
  error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
359
  yield (
360
  f"Error: {error_message}", # answer_output
361
+ "<div class='error'>An error occurred during search</div>", # sources_output
362
  "Search", # search_btn
363
  history + [[query, f"*Error: {error_message}*"]], # chat_history_display
364
  None # audio_output
365
  )
366
 
367
+ # --- Improved UI ---
368
  css = """
369
  .gradio-container {
370
  max-width: 1200px !important;
 
374
  text-align: center;
375
  margin-bottom: 2rem;
376
  padding: 2rem 0;
377
+ background: linear-gradient(135deg, #1a1b1e, #2d2e32);
378
  border-radius: 12px;
379
  color: white;
380
+ box-shadow: 0 8px 32px rgba(0,0,0,0.2);
381
  }
382
  #header h1 {
383
  color: white;
384
  font-size: 2.5rem;
385
  margin-bottom: 0.5rem;
386
+ text-shadow: 0 2px 4px rgba(0,0,0,0.3);
387
  }
388
  #header h3 {
389
  color: #a8a9ab;
390
  }
391
  .search-container {
392
+ background: linear-gradient(135deg, #1a1b1e, #2d2e32);
393
  border-radius: 12px;
394
+ box-shadow: 0 4px 16px rgba(0,0,0,0.15);
395
+ padding: 1.5rem;
396
+ margin-bottom: 1.5rem;
397
  }
398
  .search-box {
399
  padding: 1rem;
400
  background: #2c2d30;
401
+ border-radius: 10px;
402
  margin-bottom: 1rem;
403
+ box-shadow: inset 0 2px 4px rgba(0,0,0,0.1);
404
  }
405
  .search-box input[type="text"] {
406
  background: #3a3b3e !important;
407
  border: 1px solid #4a4b4e !important;
408
  color: white !important;
409
  border-radius: 8px !important;
410
+ transition: all 0.3s ease;
411
+ }
412
+ .search-box input[type="text"]:focus {
413
+ border-color: #60a5fa !important;
414
+ box-shadow: 0 0 0 2px rgba(96, 165, 250, 0.3) !important;
415
  }
416
  .search-box input[type="text"]::placeholder {
417
  color: #a8a9ab !important;
 
419
  .search-box button {
420
  background: #2563eb !important;
421
  border: none !important;
422
+ box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
423
+ transition: all 0.3s ease !important;
424
+ }
425
+ .search-box button:hover {
426
+ background: #1d4ed8 !important;
427
+ transform: translateY(-1px) !important;
428
+ }
429
+ .search-box button:active {
430
+ transform: translateY(1px) !important;
431
  }
432
  .results-container {
433
  background: #2c2d30;
434
+ border-radius: 10px;
435
+ padding: 1.5rem;
436
+ margin-top: 1.5rem;
437
+ box-shadow: 0 4px 12px rgba(0,0,0,0.1);
438
  }
439
  .answer-box {
440
  background: #3a3b3e;
441
+ border-radius: 10px;
442
  padding: 1.5rem;
443
  color: white;
444
+ margin-bottom: 1.5rem;
445
+ box-shadow: 0 2px 8px rgba(0,0,0,0.15);
446
+ transition: all 0.3s ease;
447
+ }
448
+ .answer-box:hover {
449
+ box-shadow: 0 4px 16px rgba(0,0,0,0.2);
450
  }
451
  .answer-box p {
452
  color: #e5e7eb;
453
+ line-height: 1.7;
454
+ }
455
+ .answer-box code {
456
+ background: #2c2d30;
457
+ border-radius: 4px;
458
+ padding: 2px 4px;
459
  }
460
  .sources-container {
461
  margin-top: 1rem;
 
466
  .source-item {
467
  display: flex;
468
  padding: 12px;
469
+ margin: 12px 0;
470
  background: #3a3b3e;
471
  border-radius: 8px;
472
  transition: all 0.2s;
473
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
474
  }
475
  .source-item:hover {
476
  background: #4a4b4e;
477
+ transform: translateY(-2px);
478
+ box-shadow: 0 4px 8px rgba(0,0,0,0.15);
479
  }
480
  .source-number {
481
  font-weight: bold;
 
490
  font-weight: 500;
491
  text-decoration: none;
492
  display: block;
493
+ margin-bottom: 6px;
494
+ transition: all 0.2s;
495
+ }
496
+ .source-title:hover {
497
+ color: #93c5fd;
498
+ text-decoration: underline;
499
  }
500
  .source-date {
501
  color: #a8a9ab;
 
505
  .source-snippet {
506
  color: #e5e7eb;
507
  font-size: 0.9em;
508
+ line-height: 1.5;
509
  }
510
  .chat-history {
511
  max-height: 400px;
 
514
  background: #2c2d30;
515
  border-radius: 8px;
516
  margin-top: 1rem;
517
+ scrollbar-width: thin;
518
+ scrollbar-color: #4a4b4e #2c2d30;
519
+ }
520
+ .chat-history::-webkit-scrollbar {
521
+ width: 8px;
522
+ }
523
+ .chat-history::-webkit-scrollbar-track {
524
+ background: #2c2d30;
525
+ }
526
+ .chat-history::-webkit-scrollbar-thumb {
527
+ background-color: #4a4b4e;
528
+ border-radius: 20px;
529
  }
530
  .examples-container {
531
  background: #2c2d30;
 
537
  background: #3a3b3e !important;
538
  border: 1px solid #4a4b4e !important;
539
  color: #e5e7eb !important;
540
+ transition: all 0.2s;
541
+ margin: 4px !important;
542
+ }
543
+ .examples-container button:hover {
544
+ background: #4a4b4e !important;
545
+ transform: translateY(-1px);
546
  }
547
  .markdown-content {
548
  color: #e5e7eb !important;
549
  }
550
  .markdown-content h1, .markdown-content h2, .markdown-content h3 {
551
  color: white !important;
552
+ margin-top: 1.2em !important;
553
+ margin-bottom: 0.8em !important;
554
+ }
555
+ .markdown-content h1 {
556
+ font-size: 1.7em !important;
557
+ }
558
+ .markdown-content h2 {
559
+ font-size: 1.5em !important;
560
+ }
561
+ .markdown-content h3 {
562
+ font-size: 1.3em !important;
563
  }
564
  .markdown-content a {
565
  color: #60a5fa !important;
566
+ text-decoration: none !important;
567
+ transition: all 0.2s;
568
+ }
569
+ .markdown-content a:hover {
570
+ color: #93c5fd !important;
571
+ text-decoration: underline !important;
572
+ }
573
+ .markdown-content code {
574
+ background: #2c2d30 !important;
575
+ padding: 2px 6px !important;
576
+ border-radius: 4px !important;
577
+ font-family: monospace !important;
578
+ }
579
+ .markdown-content pre {
580
+ background: #2c2d30 !important;
581
+ padding: 12px !important;
582
+ border-radius: 8px !important;
583
+ overflow-x: auto !important;
584
+ }
585
+ .markdown-content blockquote {
586
+ border-left: 4px solid #60a5fa !important;
587
+ padding-left: 1em !important;
588
+ margin-left: 0 !important;
589
+ color: #a8a9ab !important;
590
+ }
591
+ .markdown-content table {
592
+ border-collapse: collapse !important;
593
+ width: 100% !important;
594
+ }
595
+ .markdown-content th, .markdown-content td {
596
+ padding: 8px 12px !important;
597
+ border: 1px solid #4a4b4e !important;
598
+ }
599
+ .markdown-content th {
600
+ background: #2c2d30 !important;
601
  }
602
  .accordion {
603
  background: #2c2d30 !important;
604
  border-radius: 8px !important;
605
  margin-top: 1rem !important;
606
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1) !important;
607
  }
608
  .voice-selector {
609
  margin-top: 1rem;
 
615
  background: #3a3b3e !important;
616
  color: white !important;
617
  border: 1px solid #4a4b4e !important;
618
+ border-radius: 4px !important;
619
+ padding: 8px !important;
620
+ transition: all 0.2s;
621
+ }
622
+ .voice-selector select:focus {
623
+ border-color: #60a5fa !important;
624
+ }
625
+ .audio-player {
626
+ margin-top: 1rem;
627
+ background: #2c2d30 !important;
628
+ border-radius: 8px !important;
629
+ padding: 0.5rem !important;
630
+ }
631
+ .audio-player audio {
632
+ width: 100% !important;
633
+ }
634
+ .searching, .error {
635
+ padding: 1rem;
636
+ border-radius: 8px;
637
+ text-align: center;
638
+ margin: 1rem 0;
639
+ }
640
+ .searching {
641
+ background: rgba(96, 165, 250, 0.1);
642
+ color: #60a5fa;
643
+ }
644
+ .error {
645
+ background: rgba(239, 68, 68, 0.1);
646
+ color: #ef4444;
647
+ }
648
+ .no-sources {
649
+ padding: 1rem;
650
+ text-align: center;
651
+ color: #a8a9ab;
652
+ background: #2c2d30;
653
+ border-radius: 8px;
654
+ }
655
+ @keyframes pulse {
656
+ 0% { opacity: 0.6; }
657
+ 50% { opacity: 1; }
658
+ 100% { opacity: 0.6; }
659
+ }
660
+ .searching {
661
+ animation: pulse 1.5s infinite;
662
  }
663
  """
664
 
665
+ # --- Gradio Interface ---
666
  with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
667
  chat_history = gr.State([])
668
 
 
678
  scale=5,
679
  container=False
680
  )
 
681
  voice_select = gr.Dropdown(
682
+ choices=list(VOICE_CHOICES.keys()),
683
+ value=list(VOICE_CHOICES.keys())[0],
684
+ label="Voice",
685
+ elem_classes="voice-selector",
686
+ scale=1
687
  )
688
+ search_btn = gr.Button("Search", variant="primary", scale=1)
689
 
690
  with gr.Row(elem_classes="results-container"):
691
  with gr.Column(scale=2):
 
703
  with gr.Row(elem_classes="examples-container"):
704
  gr.Examples(
705
  examples=[
706
+ "Latest news about artificial intelligence advances",
707
+ "How does blockchain technology work?",
708
  "What are the best practices for sustainable living?",
709
+ "Compare electric vehicles and traditional cars"
710
  ],
711
  inputs=search_input,
712
  label="Try these examples"
713
  )
714
 
715
+ # Handle voice selection mapping
716
# NOTE(review): in app.py these statements appear to sit inside the
# `with gr.Blocks(...) as demo:` context; re-indent to match when applying.

# Handle voice selection mapping
def get_voice_id(voice_name):
    """Translate a dropdown label into a Kokoro voice id.

    Args:
        voice_name: A key of ``VOICE_CHOICES`` (the label shown in the UI).

    Returns:
        The mapped voice id, or ``'af'`` when the label is unknown.
    """
    return VOICE_CHOICES.get(voice_name, 'af')


def process_query_with_voice(query, history, voice_label):
    """Event handler: resolve the selected voice label, then stream results.

    ``voice_select`` lists the keys of ``VOICE_CHOICES``, so the handler
    receives a display label; ``process_query`` expects the voice id. This is
    a generator (``yield from``) so every intermediate update produced by
    ``process_query`` is passed through unchanged.
    """
    yield from process_query(query, history, get_voice_id(voice_label))


# Handle interactions.
# `inputs` may only contain components, one per handler parameter; the
# label-to-id mapping therefore happens inside the handler, not in this list.
search_btn.click(
    fn=process_query_with_voice,
    inputs=[search_input, chat_history, voice_select],
    outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
)

# Also trigger search on Enter key
search_input.submit(
    fn=process_query_with_voice,
    inputs=[search_input, chat_history, voice_select],
    outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
)
732
 
733
if __name__ == "__main__":
    # Start the app with request queuing enabled.
    # NOTE(review): newer Gradio releases reject `concurrency_count` on
    # queue() in favour of `default_concurrency_limit`; confirm against the
    # Gradio version pinned for this Space. The newer keyword is tried first
    # and the older one is used only if queue() does not accept it.
    try:
        queued_demo = demo.queue(default_concurrency_limit=5, max_size=20)
    except TypeError:
        queued_demo = demo.queue(concurrency_count=5, max_size=20)
    queued_demo.launch(share=True)