import streamlit as st
import numpy as np
import pandas as pd
from smolagents import CodeAgent, tool
from typing import Union, List, Dict, Optional
import matplotlib.pyplot as plt
import seaborn as sns
import os
from groq import Groq
from dataclasses import dataclass
import tempfile
import base64
import io
class GroqLLM:
    """Compatible LLM interface for smolagents CodeAgent"""
    def __init__(self, model_name="llama-3.1-8b-instant"):
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
        """Make the class callable as required by smolagents"""
        try:
            # Normalize the prompt to a plain string regardless of input format
            prompt_str = prompt if isinstance(prompt, str) else str(prompt)
            # Create a properly formatted message
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{
                    "role": "user",
                    "content": prompt_str
                }],
                temperature=0.7,
                max_tokens=1024,
                stream=False
            )
            return completion.choices[0].message.content if completion.choices else "Error: No response generated"
        except Exception as e:
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return error_msg
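
# Quick sanity-check sketch (not part of the app flow; assumes GROQ_API_KEY is
# exported in the environment):
#   llm = GroqLLM()
#   print(llm("List three common data quality checks."))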
class DataAnalysisAgent(CodeAgent):
    """Extended CodeAgent with dataset awareness"""
    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._dataset = dataset

    @property
    def dataset(self) -> pd.DataFrame:
        """Access the stored dataset"""
        return self._dataset

    def run(self, prompt: str) -> str:
        """Override run method to include dataset context"""
        dataset_info = f"""
        Dataset Shape: {self.dataset.shape}
        Columns: {', '.join(self.dataset.columns)}
        Data Types: {self.dataset.dtypes.to_dict()}
        """
        enhanced_prompt = f"""
        Analyze the following dataset:
        {dataset_info}
        Task: {prompt}
        Use the provided tools to analyze this specific dataset and return detailed results.
        """
        return super().run(enhanced_prompt)
@tool
def analyze_basic_stats(data: pd.DataFrame) -> str:
    """Calculate basic statistical measures for numerical columns in the dataset.

    This function computes fundamental statistical metrics including mean, median,
    standard deviation, skewness, and counts of missing values for all numerical
    columns in the provided DataFrame.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least one numerical column for meaningful analysis.

    Returns:
        str: A string containing formatted basic statistics for each numerical column,
            including mean, median, standard deviation, skewness, and missing value counts.
    """
    # Access dataset from agent if no data provided
    if data is None:
        data = tool.agent.dataset

    stats = {}
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    for col in numeric_cols:
        stats[col] = {
            'mean': float(data[col].mean()),
            'median': float(data[col].median()),
            'std': float(data[col].std()),
            'skew': float(data[col].skew()),
            'missing': int(data[col].isnull().sum())
        }
    return str(stats)
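
# Shape of the returned string, with purely illustrative values:
#   "{'price': {'mean': 350000.0, 'median': 320000.0, 'std': 95000.0,
#               'skew': 1.2, 'missing': 0}}"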
@tool
def generate_correlation_matrix(data: pd.DataFrame) -> str:
    """Generate a visual correlation matrix for numerical columns in the dataset.

    This function creates a heatmap visualization showing the correlations between
    all numerical columns in the dataset. The correlation values are displayed
    using a color-coded matrix for easy interpretation.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least two numerical columns for correlation analysis.

    Returns:
        str: A base64 encoded string representing the correlation matrix plot image,
            which can be displayed in a web interface or saved as an image file.
    """
    # Access dataset from agent if no data provided
    if data is None:
        data = tool.agent.dataset

    numeric_data = data.select_dtypes(include=[np.number])
    plt.figure(figsize=(10, 8))
    sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm')
    plt.title('Correlation Matrix')

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    return base64.b64encode(buf.getvalue()).decode()
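
# Decoding sketch: the tool returns raw base64 for a PNG, so the string can be
# turned back into image bytes outside Streamlit if needed. Here "encoded_matrix"
# stands for the string returned by this tool and the file name is illustrative:
#   png_bytes = base64.b64decode(encoded_matrix)
#   with open("correlation_matrix.png", "wb") as f:
#       f.write(png_bytes)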
@tool
def analyze_categorical_columns(data: pd.DataFrame) -> str:
    """Analyze categorical columns in the dataset for distribution and frequencies.

    This function examines categorical columns to identify unique values, top categories,
    and missing value counts, providing insights into the categorical data distribution.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least one categorical column for meaningful analysis.

    Returns:
        str: A string containing formatted analysis results for each categorical column,
            including unique value counts, top categories, and missing value counts.
    """
    # Access dataset from agent if no data provided
    if data is None:
        data = tool.agent.dataset

    categorical_cols = data.select_dtypes(include=['object', 'category']).columns
    analysis = {}
    for col in categorical_cols:
        analysis[col] = {
            'unique_values': int(data[col].nunique()),
            'top_categories': data[col].value_counts().head(5).to_dict(),
            'missing': int(data[col].isnull().sum())
        }
    return str(analysis)
@tool
def suggest_features(data: pd.DataFrame) -> str:
    """Suggest potential feature engineering steps based on data characteristics.

    This function analyzes the dataset's structure and statistical properties to
    recommend possible feature engineering steps that could improve model performance.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            can contain both numerical and categorical columns.

    Returns:
        str: A string containing suggestions for feature engineering based on
            the characteristics of the input data.
    """
    # Access dataset from agent if no data provided
    if data is None:
        data = tool.agent.dataset

    suggestions = []
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    if len(numeric_cols) >= 2:
        suggestions.append("Consider creating interaction terms between numerical features")
    if len(categorical_cols) > 0:
        suggestions.append("Consider one-hot encoding for categorical variables")
    for col in numeric_cols:
        # Flag heavily skewed columns (|skew| > 1) as candidates for a log transform
        if abs(data[col].skew()) > 1:
            suggestions.append(f"Consider log transformation for {col} due to skewness")

    return '\n'.join(suggestions)
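
# Standalone usage sketch (hypothetical toy DataFrame, outside the Streamlit flow):
#   df = pd.DataFrame({"price": [250, 320, 410], "city": ["NY", "LA", "SF"]})
#   agent = DataAnalysisAgent(dataset=df,
#                             tools=[analyze_basic_stats, suggest_features],
#                             model=GroqLLM())
#   print(agent.run("Summarize the numeric columns."))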
def main():
    st.title("Data Analysis Assistant")
    st.write("Upload your dataset and get automated analysis with natural language interaction.")

    # Initialize session state
    if 'data' not in st.session_state:
        st.session_state['data'] = None
    if 'agent' not in st.session_state:
        st.session_state['agent'] = None

    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    try:
        if uploaded_file is not None:
            with st.spinner('Loading and processing your data...'):
                # Load the dataset
                data = pd.read_csv(uploaded_file)
                st.session_state['data'] = data

                # Initialize the agent with the dataset
                st.session_state['agent'] = DataAnalysisAgent(
                    dataset=data,
                    tools=[analyze_basic_stats, generate_correlation_matrix,
                           analyze_categorical_columns, suggest_features],
                    model=GroqLLM(),
                    additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
                )
                st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
                st.subheader("Data Preview")
                st.dataframe(data.head())

        if st.session_state['data'] is not None:
            analysis_type = st.selectbox(
                "Choose analysis type",
                ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
                 "Feature Engineering", "Custom Question"]
            )

            if analysis_type == "Basic Statistics":
                with st.spinner('Analyzing basic statistics...'):
                    result = st.session_state['agent'].run(
                        "Use the analyze_basic_stats tool to analyze this dataset and "
                        "provide insights about the numerical distributions."
                    )
                    st.write(result)

            elif analysis_type == "Correlation Analysis":
                with st.spinner('Generating correlation matrix...'):
                    result = st.session_state['agent'].run(
                        "Use the generate_correlation_matrix tool to analyze correlations "
                        "and explain any strong relationships found."
                    )
                    if isinstance(result, str) and (result.startswith('data:image') or ',' in result):
                        st.image(f"data:image/png;base64,{result.split(',')[-1]}")
                    else:
                        st.write(result)

            elif analysis_type == "Categorical Analysis":
                with st.spinner('Analyzing categorical columns...'):
                    result = st.session_state['agent'].run(
                        "Use the analyze_categorical_columns tool to examine the "
                        "categorical variables and explain the distributions."
                    )
                    st.write(result)

            elif analysis_type == "Feature Engineering":
                with st.spinner('Generating feature suggestions...'):
                    result = st.session_state['agent'].run(
                        "Use the suggest_features tool to recommend potential "
                        "feature engineering steps for this dataset."
                    )
                    st.write(result)

            elif analysis_type == "Custom Question":
                question = st.text_input("What would you like to know about your data?")
                if question:
                    with st.spinner('Analyzing...'):
                        result = st.session_state['agent'].run(question)
                        st.write(result)

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
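
# Running the app locally (sketch; the script name "app.py" is an assumption,
# adjust it to wherever this file lives):
#   export GROQ_API_KEY="your-key-here"
#   streamlit run app.py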