Spaces: Running on Zero
Update app.py: remove the SambaNova-backed "@reasoning" feature and the section-banner comments.

app.py CHANGED
```diff
@@ -24,9 +24,6 @@ from transformers import (
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
-# -----------------------------
-# Existing global variables and model setup
-# -----------------------------
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
```
```diff
@@ -54,38 +51,6 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# -----------------------------
-# New reasoning feature setup
-# -----------------------------
-from openai import OpenAI
-
-api_key = os.getenv("SAMBANOVA_API_KEY")
-client_reasoning = OpenAI(
-    base_url="https://api.sambanova.ai/v1/",
-    api_key=api_key,
-)
-
-def reasoning_predict(message, history):
-    """
-    This function appends the user's reasoning request to the history,
-    then streams the response from the Sambanova API using the model
-    'DeepSeek-R1-Distill-Llama-70B'.
-    """
-    history.append({"role": "user", "content": message})
-    stream = client_reasoning.chat.completions.create(
-        messages=history,
-        model="DeepSeek-R1-Distill-Llama-70B",
-        stream=True,
-    )
-    chunks = []
-    for chunk in stream:
-        # Accumulate streamed content and yield the current full response
-        chunks.append(chunk.choices[0].delta.content or "")
-        yield "".join(chunks)
-
-# -----------------------------
-# Utility Functions and Checks
-# -----------------------------
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
```
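For readers tracing what was removed: `reasoning_predict` used an accumulate-and-yield streaming pattern, where each iteration yields the full text so far rather than the latest delta, which is the shape Gradio chat handlers expect. A minimal sketch of that pattern, with a plain list standing in for the API stream (no SambaNova client or key involved):

```python
# Minimal sketch of the accumulate-and-yield pattern from the removed
# reasoning_predict. The list stands in for chunk.choices[0].delta.content
# values coming off a real OpenAI-compatible stream.
def stream_cumulative(deltas):
    parts = []
    for delta in deltas:
        parts.append(delta or "")   # None deltas become empty strings
        yield "".join(parts)        # yield the full response so far

for partial in stream_cumulative(["Quantum ", "entanglement ", None, "links states."]):
    print(repr(partial))
```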
```diff
@@ -124,9 +89,6 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# -----------------------------
-# Image Generation Models Setup
-# -----------------------------
 if torch.cuda.is_available():
     # Lightning 5 model
     pipe = StableDiffusionXLPipeline.from_pretrained(
```
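This hunk borders the pipeline setup that the diff only shows the first line of. As a hedged sketch of the surrounding pattern (the actual checkpoint id and scheduler wiring are not visible here; `stabilityai/stable-diffusion-xl-base-1.0` is a placeholder):

```python
import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

# Sketch of the dtype/device selection visible above: fp16 on CUDA, fp32 on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Placeholder checkpoint; the Space's real model id is not shown in this hunk.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype
).to(device)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
```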
```diff
@@ -212,9 +174,6 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-# -----------------------------
-# Main Generation Function with Reasoning Integration
-# -----------------------------
 @spaces.GPU
 def generate(
     input_dict: dict,
```
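The context lines show only the tail of `save_image`. A plausible sketch of the whole helper, assuming a UUID-based file name (the naming scheme is a guess; only the `img.save` and `return` lines appear in the diff):

```python
import uuid
from PIL import Image

def save_image(img: Image.Image) -> str:
    # Assumed shape of the helper shown above: write to a unique file name
    # so concurrent generations never collide.
    unique_name = f"{uuid.uuid4()}.png"
    img.save(unique_name)
    return unique_name
```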
```diff
@@ -229,7 +188,6 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
     # Check if the prompt is an image generation command using model flags.
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
```
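The `startswith` chain above is the routing step of `generate`: each message is dispatched by its leading `@flag` before any model runs. A self-contained sketch of that dispatch (the returned labels are illustrative stand-ins for the real image/TTS/chat branches):

```python
# Sketch of the prefix-flag routing in generate(); after this commit the
# branches are image generation, TTS, and plain chat.
IMAGE_FLAGS = ("@lightningv5", "@lightningv4", "@turbov3")

def route(text: str) -> str:
    lower_text = text.lower().strip()
    if any(lower_text.startswith(flag) for flag in IMAGE_FLAGS):
        return "image"
    if lower_text.startswith("@tts"):
        return "tts"
    return "chat"

print(route('@turbov3 "Abstract art"'))              # image
print(route("@tts2 What causes rainbows to form?"))  # tts
```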
```diff
@@ -282,20 +240,7 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    # -----------------------------
-    # NEW: Reasoning Branch
-    # -----------------------------
-    if lower_text.startswith("@reasoning"):
-        reasoning_text = text.replace("@reasoning", "").strip()
-        reasoning_history = clean_chat_history(chat_history)
-        yield "Reasoning..."
-        for response in reasoning_predict(reasoning_text, reasoning_history):
-            yield response
-        return
-
-    # -----------------------------
     # Otherwise, handle text/chat (and TTS) generation.
-    # -----------------------------
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
```
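The `is_tts` / `voice_index` lines kept above scan for `@tts1` or `@tts2` and turn the digit into a voice choice. A sketch of that parse in isolation (the voice names are assumptions; the diff does not show the Space's actual voice table):

```python
# Assumed voice table; the real names are defined elsewhere in app.py.
TTS_VOICES = ["en-US-JennyNeural", "en-US-GuyNeural"]

def parse_tts(text: str):
    stripped = text.strip()
    for i in range(1, 3):
        prefix = f"@tts{i}"
        if stripped.lower().startswith(prefix):
            # Map @tts1 -> voice 0, @tts2 -> voice 1, and drop the flag.
            return TTS_VOICES[i - 1], stripped[len(prefix):].strip()
    return None, stripped

print(parse_tts("@tts2 What causes rainbows to form?"))
```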
```diff
@@ -400,7 +345,6 @@ demo = gr.ChatInterface(
         ['@turbov3 "Abstract art, colorful and vibrant"'],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
-        ["@reasoning How does quantum entanglement work and what are its implications?"],
     ],
     cache_examples=False,
     type="messages",
```
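This final hunk trims the `@reasoning` entry from the examples list. For context, a minimal sketch of the `gr.ChatInterface` wiring these kwargs belong to (the stub `generate` and the trimmed example list are assumptions; app.py's real interface passes more options):

```python
import gradio as gr

def generate(message, history):
    # Stand-in for the Space's @spaces.GPU generator shown above.
    yield f"Echo: {message}"

demo = gr.ChatInterface(
    fn=generate,
    examples=[
        ['@turbov3 "Abstract art, colorful and vibrant"'],
        ["Write a Python function to check if a number is prime."],
        ["@tts2 What causes rainbows to form?"],
    ],
    cache_examples=False,
    type="messages",
)

if __name__ == "__main__":
    demo.launch()
```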