Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -14,24 +14,39 @@ import io
|
|
14 |
|
15 |
class GroqLLM:
|
16 |
"""Compatible LLM interface for smolagents CodeAgent"""
|
17 |
-
def __init__(self, model_name="
|
18 |
self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
19 |
self.model_name = model_name
|
20 |
|
21 |
def __call__(self, prompt: str) -> str:
|
22 |
"""Make the class callable as required by smolagents"""
|
23 |
-
|
|
|
|
|
|
|
24 |
try:
|
25 |
-
|
|
|
26 |
model=self.model_name,
|
27 |
-
messages=
|
|
|
|
|
|
|
28 |
temperature=0.7,
|
29 |
max_tokens=1024,
|
30 |
stream=False
|
31 |
)
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
33 |
except Exception as e:
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
@tool
|
37 |
def analyze_basic_stats(data: pd.DataFrame) -> str:
|
|
|
14 |
|
15 |
class GroqLLM:
|
16 |
"""Compatible LLM interface for smolagents CodeAgent"""
|
17 |
+
def __init__(self, model_name="llama-3.1-8B-Instant"):
|
18 |
self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
19 |
self.model_name = model_name
|
20 |
|
21 |
def __call__(self, prompt: str) -> str:
|
22 |
"""Make the class callable as required by smolagents"""
|
23 |
+
# Ensure the prompt is a string
|
24 |
+
if not isinstance(prompt, str):
|
25 |
+
return "Error: Prompt must be a string"
|
26 |
+
|
27 |
try:
|
28 |
+
# Create a properly formatted message
|
29 |
+
completion = self.client.chat.completions.create(
|
30 |
model=self.model_name,
|
31 |
+
messages=[{
|
32 |
+
"role": "user",
|
33 |
+
"content": str(prompt) # Ensure content is string
|
34 |
+
}],
|
35 |
temperature=0.7,
|
36 |
max_tokens=1024,
|
37 |
stream=False
|
38 |
)
|
39 |
+
|
40 |
+
# Extract and return the response content
|
41 |
+
if completion.choices and len(completion.choices) > 0:
|
42 |
+
return completion.choices[0].message.content
|
43 |
+
return "Error: No response generated"
|
44 |
+
|
45 |
except Exception as e:
|
46 |
+
# Provide more detailed error handling
|
47 |
+
error_msg = f"Error generating response: {str(e)}"
|
48 |
+
print(error_msg) # Log the error
|
49 |
+
return error_msg
|
50 |
|
51 |
@tool
|
52 |
def analyze_basic_stats(data: pd.DataFrame) -> str:
|