prithivMLmods commited on
Commit
4982b30
·
verified ·
1 Parent(s): 9a88eba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -0
app.py CHANGED
@@ -13,6 +13,7 @@ import torch
13
  import numpy as np
14
  from PIL import Image
15
  import edge_tts
 
16
 
17
  from transformers import (
18
  AutoModelForCausalLM,
@@ -240,6 +241,31 @@ def generate(
240
  yield gr.Image(image_path)
241
  return
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  # Otherwise, handle text/chat (and TTS) generation.
244
  tts_prefix = "@tts"
245
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
@@ -345,6 +371,7 @@ demo = gr.ChatInterface(
345
  ['@turbov3 "Abstract art, colorful and vibrant"'],
346
  ["Write a Python function to check if a number is prime."],
347
  ["@tts2 What causes rainbows to form?"],
 
348
  ],
349
  cache_examples=False,
350
  type="messages",
 
13
  import numpy as np
14
  from PIL import Image
15
  import edge_tts
16
+ import openai
17
 
18
  from transformers import (
19
  AutoModelForCausalLM,
 
241
  yield gr.Image(image_path)
242
  return
243
 
244
+ # New reasoning feature implementation.
245
+ if lower_text.startswith("@reasoning"):
246
+ reasoning_prompt = text.replace("@reasoning", "").strip()
247
+ messages = [
248
+ {"role": "system", "content": "You are a helpful assistant"},
249
+ {"role": "user", "content": reasoning_prompt}
250
+ ]
251
+ try:
252
+ client = openai.OpenAI(
253
+ api_key=os.environ.get("SAMBANOVA_API_KEY"),
254
+ base_url="https://api.sambanova.ai/v1",
255
+ )
256
+ response = client.chat.completions.create(
257
+ model="DeepSeek-R1-Distill-Llama-70B",
258
+ messages=messages,
259
+ temperature=0.1,
260
+ top_p=0.1
261
+ )
262
+ reasoning_response = response.choices[0].message.content
263
+ except Exception as e:
264
+ reasoning_response = "Error in reasoning request: " + str(e)
265
+ yield "Thinking..."
266
+ yield reasoning_response
267
+ return
268
+
269
  # Otherwise, handle text/chat (and TTS) generation.
270
  tts_prefix = "@tts"
271
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
 
371
  ['@turbov3 "Abstract art, colorful and vibrant"'],
372
  ["Write a Python function to check if a number is prime."],
373
  ["@tts2 What causes rainbows to form?"],
374
+ ["@reasoning Explain the theory of relativity."],
375
  ],
376
  cache_examples=False,
377
  type="messages",