girishwangikar commited on
Commit
0e7b5e0
·
verified ·
1 Parent(s): 1b4aded

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -83
app.py CHANGED
@@ -12,7 +12,6 @@ import tempfile
12
  import base64
13
  import io
14
 
15
- # Custom Groq Model Class remains unchanged
16
  class GroqModel:
17
  def __init__(self, model_name="llama2-70b-4096"):
18
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
@@ -29,15 +28,10 @@ class GroqModel:
29
  )
30
  return response.choices[0].message.content
31
 
 
32
  @tool
33
  def analyze_basic_stats(data: pd.DataFrame) -> str:
34
- """Calculate basic statistics for numerical columns.
35
-
36
- Args:
37
- data: Input DataFrame
38
- Returns:
39
- String containing formatted basic statistics
40
- """
41
  stats = {}
42
  numeric_cols = data.select_dtypes(include=[np.number]).columns
43
 
@@ -54,13 +48,7 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
54
 
55
  @tool
56
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
57
- """Generate correlation matrix visualization for numerical columns.
58
-
59
- Args:
60
- data: Input DataFrame
61
- Returns:
62
- Base64 encoded string of correlation matrix plot
63
- """
64
  numeric_data = data.select_dtypes(include=[np.number])
65
 
66
  plt.figure(figsize=(10, 8))
@@ -74,13 +62,7 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
74
 
75
  @tool
76
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
77
- """Analyze categorical columns in the dataset.
78
-
79
- Args:
80
- data: Input DataFrame
81
- Returns:
82
- String containing formatted categorical analysis
83
- """
84
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
85
  analysis = {}
86
 
@@ -95,13 +77,7 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
95
 
96
  @tool
97
  def suggest_features(data: pd.DataFrame) -> str:
98
- """Suggest potential feature engineering steps.
99
-
100
- Args:
101
- data: Input DataFrame
102
- Returns:
103
- String containing feature engineering suggestions
104
- """
105
  suggestions = []
106
  numeric_cols = data.select_dtypes(include=[np.number]).columns
107
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
@@ -118,70 +94,110 @@ def suggest_features(data: pd.DataFrame) -> str:
118
 
119
  return '\n'.join(suggestions)
120
 
121
- # Streamlit App
 
 
 
 
 
 
 
 
 
 
122
  def main():
123
  st.title("Data Analysis Assistant")
124
  st.write("Upload your dataset and get automated analysis with natural language interaction.")
125
 
 
 
 
 
126
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
127
 
128
- if uploaded_file is not None:
129
- data = pd.read_csv(uploaded_file)
130
- st.session_state['data'] = data
131
-
132
- # Initialize agent
133
- agent = CodeAgent(
134
- tools=[analyze_basic_stats, generate_correlation_matrix,
135
- analyze_categorical_columns, suggest_features],
136
- model=GroqModel(),
137
- additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
138
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- # Analysis options
141
- analysis_type = st.selectbox(
142
- "Choose analysis type",
143
- ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
144
- "Feature Engineering", "Custom Question"]
145
- )
146
-
147
- if analysis_type == "Basic Statistics":
148
- result = agent.run(
149
- f"Analyze and explain the basic statistics of this dataset. "
150
- f"Dataset info: {data.info()}\n"
151
- f"Use the analyze_basic_stats tool and provide natural language explanations."
152
  )
153
- st.write(result)
154
 
155
- elif analysis_type == "Correlation Analysis":
156
- correlation_plot = agent.run(
157
- "Generate and explain correlations between numerical variables. "
158
- "Use the generate_correlation_matrix tool."
159
- )
160
- if correlation_plot:
161
- st.image(f"data:image/png;base64,{correlation_plot}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- elif analysis_type == "Categorical Analysis":
164
- result = agent.run(
165
- "Analyze categorical variables in the dataset. "
166
- "Use the analyze_categorical_columns tool and explain the findings."
167
- )
168
- st.write(result)
169
-
170
- elif analysis_type == "Feature Engineering":
171
- result = agent.run(
172
- "Suggest potential feature engineering steps for this dataset. "
173
- "Use the suggest_features tool and explain your suggestions."
174
- )
175
- st.write(result)
176
-
177
- elif analysis_type == "Custom Question":
178
- question = st.text_input("What would you like to know about your data?")
179
- if question:
180
- result = agent.run(
181
- f"Answer this question about the dataset: {question}\n"
182
- f"Use appropriate tools to analyze and explain."
183
- )
184
- st.write(result)
185
 
186
  if __name__ == "__main__":
187
  main()
 
12
  import base64
13
  import io
14
 
 
15
  class GroqModel:
16
  def __init__(self, model_name="llama2-70b-4096"):
17
  self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
 
28
  )
29
  return response.choices[0].message.content
30
 
31
+ # Tool functions remain unchanged
32
  @tool
33
  def analyze_basic_stats(data: pd.DataFrame) -> str:
34
+ """Calculate basic statistics for numerical columns."""
 
 
 
 
 
 
35
  stats = {}
36
  numeric_cols = data.select_dtypes(include=[np.number]).columns
37
 
 
48
 
49
  @tool
50
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
51
+ """Generate correlation matrix visualization for numerical columns."""
 
 
 
 
 
 
52
  numeric_data = data.select_dtypes(include=[np.number])
53
 
54
  plt.figure(figsize=(10, 8))
 
62
 
63
  @tool
64
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
65
+ """Analyze categorical columns in the dataset."""
 
 
 
 
 
 
66
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
67
  analysis = {}
68
 
 
77
 
78
  @tool
79
  def suggest_features(data: pd.DataFrame) -> str:
80
+ """Suggest potential feature engineering steps."""
 
 
 
 
 
 
81
  suggestions = []
82
  numeric_cols = data.select_dtypes(include=[np.number]).columns
83
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
 
94
 
95
  return '\n'.join(suggestions)
96
 
97
+ def initialize_session_state():
98
+ """Initialize session state variables"""
99
+ if 'data' not in st.session_state:
100
+ st.session_state['data'] = None
101
+ if 'agent' not in st.session_state:
102
+ st.session_state['agent'] = None
103
+ if 'file_uploaded' not in st.session_state:
104
+ st.session_state['file_uploaded'] = False
105
+ if 'processing' not in st.session_state:
106
+ st.session_state['processing'] = False
107
+
108
  def main():
109
  st.title("Data Analysis Assistant")
110
  st.write("Upload your dataset and get automated analysis with natural language interaction.")
111
 
112
+ # Initialize session state
113
+ initialize_session_state()
114
+
115
+ # File uploader with error handling
116
  uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
117
 
118
+ try:
119
+ if uploaded_file is not None and not st.session_state['file_uploaded']:
120
+ # Show loading spinner while processing the file
121
+ with st.spinner('Loading and processing your data...'):
122
+ try:
123
+ data = pd.read_csv(uploaded_file)
124
+ st.session_state['data'] = data
125
+ st.session_state['file_uploaded'] = True
126
+
127
+ # Initialize agent
128
+ st.session_state['agent'] = CodeAgent(
129
+ tools=[analyze_basic_stats, generate_correlation_matrix,
130
+ analyze_categorical_columns, suggest_features],
131
+ model=GroqModel(),
132
+ additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
133
+ )
134
+
135
+ # Show success message
136
+ st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
137
+
138
+ # Display data preview
139
+ st.subheader("Data Preview")
140
+ st.dataframe(data.head())
141
+
142
+ except Exception as e:
143
+ st.error(f"Error loading file: {str(e)}")
144
+ st.session_state['file_uploaded'] = False
145
+ return
146
 
147
+ # Only show analysis options if data is loaded
148
+ if st.session_state['file_uploaded'] and st.session_state['data'] is not None:
149
+ # Analysis options
150
+ analysis_type = st.selectbox(
151
+ "Choose analysis type",
152
+ ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
153
+ "Feature Engineering", "Custom Question"]
 
 
 
 
 
154
  )
 
155
 
156
+ # Process analysis with loading indicators
157
+ if analysis_type:
158
+ with st.spinner(f'Performing {analysis_type.lower()}...'):
159
+ if analysis_type == "Basic Statistics":
160
+ result = st.session_state['agent'].run(
161
+ f"Analyze and explain the basic statistics of this dataset. "
162
+ f"Dataset info: {st.session_state['data'].info()}\n"
163
+ f"Use the analyze_basic_stats tool and provide natural language explanations."
164
+ )
165
+ st.write(result)
166
+
167
+ elif analysis_type == "Correlation Analysis":
168
+ correlation_plot = st.session_state['agent'].run(
169
+ "Generate and explain correlations between numerical variables. "
170
+ "Use the generate_correlation_matrix tool."
171
+ )
172
+ if correlation_plot:
173
+ st.image(f"data:image/png;base64,{correlation_plot}")
174
+
175
+ elif analysis_type == "Categorical Analysis":
176
+ result = st.session_state['agent'].run(
177
+ "Analyze categorical variables in the dataset. "
178
+ "Use the analyze_categorical_columns tool and explain the findings."
179
+ )
180
+ st.write(result)
181
+
182
+ elif analysis_type == "Feature Engineering":
183
+ result = st.session_state['agent'].run(
184
+ "Suggest potential feature engineering steps for this dataset. "
185
+ "Use the suggest_features tool and explain your suggestions."
186
+ )
187
+ st.write(result)
188
+
189
+ elif analysis_type == "Custom Question":
190
+ question = st.text_input("What would you like to know about your data?")
191
+ if question:
192
+ result = st.session_state['agent'].run(
193
+ f"Answer this question about the dataset: {question}\n"
194
+ f"Use appropriate tools to analyze and explain."
195
+ )
196
+ st.write(result)
197
 
198
+ except Exception as e:
199
+ st.error(f"An error occurred: {str(e)}")
200
+ st.session_state['file_uploaded'] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  if __name__ == "__main__":
203
  main()