DrishtiSharma committed on
Commit
43546ab
·
verified ·
1 Parent(s): 72ea0b4

Update test.py

Browse files
Files changed (1) hide show
  1. test.py +344 -123
test.py CHANGED
@@ -25,7 +25,7 @@ from datasets import load_dataset
25
  import tempfile
26
 
27
  st.title("SQL-RAG Using CrewAI 🚀")
28
- st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
29
 
30
  # Initialize LLM
31
  llm = None
@@ -86,88 +86,348 @@ if st.session_state.df is not None and st.session_state.show_preview:
86
  st.subheader("📂 Dataset Preview")
87
  st.dataframe(st.session_state.df.head())
88
 
89
- # Function to create TXT file
90
- def create_text_report_with_viz_temp(report, conclusion, visualizations):
91
- content = f"### Analysis Report\n\n{report}\n\n### Visualizations\n"
92
-
93
- for i, fig in enumerate(visualizations, start=1):
94
- fig_title = fig.layout.title.text if fig.layout.title.text else f"Visualization {i}"
95
- x_axis = fig.layout.xaxis.title.text if fig.layout.xaxis.title.text else "X-axis"
96
- y_axis = fig.layout.yaxis.title.text if fig.layout.yaxis.title.text else "Y-axis"
97
-
98
- content += f"\n{i}. {fig_title}\n"
99
- content += f" - X-axis: {x_axis}\n"
100
- content += f" - Y-axis: {y_axis}\n"
101
-
102
- if fig.data:
103
- trace_types = set(trace.type for trace in fig.data)
104
- content += f" - Chart Type(s): {', '.join(trace_types)}\n"
105
- else:
106
- content += " - No data available in this visualization.\n"
107
-
108
- content += f"\n\n\n{conclusion}"
109
-
110
- with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode='w', encoding='utf-8') as temp_txt:
111
- temp_txt.write(content)
112
- return temp_txt.name
113
-
114
-
115
- # Function to create PDF with report text and visualizations
116
- def create_pdf_report_with_viz(report, conclusion, visualizations):
117
- pdf = FPDF()
118
- pdf.set_auto_page_break(auto=True, margin=15)
119
- pdf.add_page()
120
- pdf.set_font("Arial", size=12)
121
-
122
- # Title
123
- pdf.set_font("Arial", style="B", size=18)
124
- pdf.cell(0, 10, "📊 Analysis Report", ln=True, align="C")
125
- pdf.ln(10)
126
-
127
- # Report Content
128
- pdf.set_font("Arial", style="B", size=14)
129
- pdf.cell(0, 10, "Analysis", ln=True)
130
- pdf.set_font("Arial", size=12)
131
- pdf.multi_cell(0, 10, report)
132
-
133
- pdf.ln(10)
134
- pdf.set_font("Arial", style="B", size=14)
135
- pdf.cell(0, 10, "Conclusion", ln=True)
136
- pdf.set_font("Arial", size=12)
137
- pdf.multi_cell(0, 10, conclusion)
138
-
139
- # Add Visualizations
140
- pdf.add_page()
141
- pdf.set_font("Arial", style="B", size=16)
142
- pdf.cell(0, 10, "📈 Visualizations", ln=True)
143
- pdf.ln(5)
144
-
145
- with tempfile.TemporaryDirectory() as temp_dir:
146
- for i, fig in enumerate(visualizations, start=1):
147
- fig_title = fig.layout.title.text if fig.layout.title.text else f"Visualization {i}"
148
- x_axis = fig.layout.xaxis.title.text if fig.layout.xaxis.title.text else "X-axis"
149
- y_axis = fig.layout.yaxis.title.text if fig.layout.yaxis.title.text else "Y-axis"
150
-
151
- # Save each visualization as a PNG image
152
- img_path = os.path.join(temp_dir, f"viz_{i}.png")
153
- fig.write_image(img_path)
154
-
155
- # Insert Title and Description
156
- pdf.set_font("Arial", style="B", size=14)
157
- pdf.multi_cell(0, 10, f"{i}. {fig_title}")
158
- pdf.set_font("Arial", size=12)
159
- pdf.multi_cell(0, 10, f"X-axis: {x_axis} | Y-axis: {y_axis}")
160
- pdf.ln(3)
161
-
162
- # Embed Visualization
163
- pdf.image(img_path, w=170)
164
- pdf.ln(10)
165
-
166
- # Save PDF
167
- temp_pdf = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
168
- pdf.output(temp_pdf.name)
169
-
170
- return temp_pdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  def escape_markdown(text):
173
  # Ensure text is a string
@@ -301,27 +561,11 @@ if st.session_state.df is not None:
301
  st.markdown(report_result if report_result else "⚠️ No Report Generated.")
302
 
303
  # Step 4: Generate Visualizations
304
- visualizations = []
305
-
306
- fig_salary = px.box(st.session_state.df, x="job_title", y="salary_in_usd",
307
- title="Salary Distribution by Job Title")
308
- visualizations.append(fig_salary)
309
 
310
- fig_experience = px.bar(
311
- st.session_state.df.groupby("experience_level")["salary_in_usd"].mean().reset_index(),
312
- x="experience_level", y="salary_in_usd",
313
- title="Average Salary by Experience Level"
314
- )
315
- visualizations.append(fig_experience)
316
-
317
- fig_employment = px.box(st.session_state.df, x="employment_type", y="salary_in_usd",
318
- title="Salary Distribution by Employment Type")
319
- visualizations.append(fig_employment)
320
 
321
  # Step 5: Insert Visual Insights
322
  st.markdown("### Visual Insights")
323
- for fig in visualizations:
324
- st.plotly_chart(fig, use_container_width=True)
325
 
326
  # Step 6: Display Concise Conclusion
327
  #st.markdown("#### Conclusion")
@@ -329,31 +573,8 @@ if st.session_state.df is not None:
329
  safe_conclusion = escape_markdown(conclusion_result if conclusion_result else "⚠️ No Conclusion Generated.")
330
  st.markdown(safe_conclusion)
331
 
332
- # Full Data Visualization Tab
333
- with tab2:
334
- st.subheader("📊 Comprehensive Data Visualizations")
335
-
336
- fig1 = px.histogram(st.session_state.df, x="job_title", title="Job Title Frequency")
337
- st.plotly_chart(fig1)
338
-
339
- fig2 = px.bar(
340
- st.session_state.df.groupby("experience_level")["salary_in_usd"].mean().reset_index(),
341
- x="experience_level", y="salary_in_usd",
342
- title="Average Salary by Experience Level"
343
- )
344
- st.plotly_chart(fig2)
345
-
346
- fig3 = px.box(st.session_state.df, x="employment_type", y="salary_in_usd",
347
- title="Salary Distribution by Employment Type")
348
- st.plotly_chart(fig3)
349
-
350
- temp_dir.cleanup()
351
- else:
352
- st.info("Please load a dataset to proceed.")
353
-
354
 
355
  # Sidebar Reference
356
  with st.sidebar:
357
  st.header("📚 Reference:")
358
  st.markdown("[SQL Agents w CrewAI & Llama 3 - Plaban Nayak](https://github.com/plaban1981/Agents/blob/main/SQL_Agents_with_CrewAI_and_Llama_3.ipynb)")
359
-
 
25
  import tempfile
26
 
27
  st.title("SQL-RAG Using CrewAI 🚀")
28
+ st.write("Analyze datasets using natural language queries.")
29
 
30
  # Initialize LLM
31
  llm = None
 
86
  st.subheader("📂 Dataset Preview")
87
  st.dataframe(st.session_state.df.head())
88
 
89
+
90
+ # Helper Function for Validation
91
+ def is_valid_suggestion(suggestion):
92
+ chart_type = suggestion.get("chart_type", "").lower()
93
+
94
+ if chart_type in ["bar", "line", "box", "scatter"]:
95
+ return all(k in suggestion for k in ["chart_type", "x_axis", "y_axis"])
96
+
97
+ elif chart_type == "pie":
98
+ return all(k in suggestion for k in ["chart_type", "x_axis"])
99
+
100
+ elif chart_type == "heatmap":
101
+ return all(k in suggestion for k in ["chart_type", "x_axis", "y_axis"])
102
+
103
+ else:
104
+ return False
105
+
106
+ def ask_gpt4o_for_visualization(query, df, llm, retries=2):
107
+ import json
108
+
109
+ # Identify numeric and categorical columns
110
+ numeric_columns = df.select_dtypes(include='number').columns.tolist()
111
+ categorical_columns = df.select_dtypes(exclude='number').columns.tolist()
112
+
113
+ # Prompt with Dataset-Specific, Query-Based Examples
114
+ prompt = f"""
115
+ Analyze the following query and suggest the most suitable visualization(s) using the dataset.
116
+ **Query:** "{query}"
117
+ **Dataset Overview:**
118
+ - **Numeric Columns (for Y-axis):** {', '.join(numeric_columns) if numeric_columns else 'None'}
119
+ - **Categorical Columns (for X-axis or grouping):** {', '.join(categorical_columns) if categorical_columns else 'None'}
120
+ Suggest visualizations in this exact JSON format:
121
+ [
122
+ {{
123
+ "chdart_type": "bar/box/line/scatter/pie/heatmap",
124
+ "x_axis": "categorical_or_time_column",
125
+ "y_axis": "numeric_column",
126
+ "group_by": "optional_column_for_grouping",
127
+ "title": "Title of the chart",
128
+ "description": "Why this chart is suitable"
129
+ }}
130
+ ]
131
+ **Query-Based Examples:**
132
+ - **Query:** "What is the salary distribution across different job titles?"
133
+ **Suggested Visualization:**
134
+ {{
135
+ "chart_type": "box",
136
+ "x_axis": "job_title",
137
+ "y_axis": "salary_in_usd",
138
+ "group_by": "experience_level",
139
+ "title": "Salary Distribution by Job Title and Experience",
140
+ "description": "A box plot to show how salaries vary across different job titles and experience levels."
141
+ }}
142
+ - **Query:** "Show the average salary by company size and employment type."
143
+ **Suggested Visualizations:**
144
+ [
145
+ {{
146
+ "chart_type": "bar",
147
+ "x_axis": "company_size",
148
+ "y_axis": "salary_in_usd",
149
+ "group_by": "employment_type",
150
+ "title": "Average Salary by Company Size and Employment Type",
151
+ "description": "A grouped bar chart comparing average salaries across company sizes and employment types."
152
+ }},
153
+ {{
154
+ "chart_type": "heatmap",
155
+ "x_axis": "company_size",
156
+ "y_axis": "salary_in_usd",
157
+ "group_by": "employment_type",
158
+ "title": "Salary Heatmap by Company Size and Employment Type",
159
+ "description": "A heatmap showing salary concentration across company sizes and employment types."
160
+ }}
161
+ ]
162
+ - **Query:** "How has the average salary changed over the years?"
163
+ **Suggested Visualization:**
164
+ {{
165
+ "chart_type": "line",
166
+ "x_axis": "work_year",
167
+ "y_axis": "salary_in_usd",
168
+ "group_by": "experience_level",
169
+ "title": "Average Salary Trend Over Years",
170
+ "description": "A line chart showing how the average salary has changed across different experience levels over the years."
171
+ }}
172
+ - **Query:** "What is the employee distribution by company location?"
173
+ **Suggested Visualization:**
174
+ {{
175
+ "chart_type": "pie",
176
+ "x_axis": "company_location",
177
+ "y_axis": null,
178
+ "group_by": null,
179
+ "title": "Employee Distribution by Company Location",
180
+ "description": "A pie chart showing the distribution of employees across company locations."
181
+ }}
182
+ - **Query:** "Is there a relationship between remote work ratio and salary?"
183
+ **Suggested Visualization:**
184
+ {{
185
+ "chart_type": "scatter",
186
+ "x_axis": "remote_ratio",
187
+ "y_axis": "salary_in_usd",
188
+ "group_by": "experience_level",
189
+ "title": "Remote Work Ratio vs Salary",
190
+ "description": "A scatter plot to analyze the relationship between remote work ratio and salary."
191
+ }}
192
+ - **Query:** "Which job titles have the highest salaries across regions?"
193
+ **Suggested Visualization:**
194
+ {{
195
+ "chart_type": "heatmap",
196
+ "x_axis": "job_title",
197
+ "y_axis": "employee_residence",
198
+ "group_by": null,
199
+ "title": "Salary Heatmap by Job Title and Region",
200
+ "description": "A heatmap showing the concentration of high-paying job titles across regions."
201
+ }}
202
+ Only suggest visualizations that logically match the query and dataset.
203
+ """
204
+
205
+ for attempt in range(retries + 1):
206
+ try:
207
+ response = llm.generate(prompt)
208
+ suggestions = json.loads(response)
209
+
210
+ if isinstance(suggestions, list):
211
+ valid_suggestions = [s for s in suggestions if is_valid_suggestion(s)]
212
+ if valid_suggestions:
213
+ return valid_suggestions
214
+ else:
215
+ st.warning("⚠️ GPT-4o did not suggest valid visualizations.")
216
+ return None
217
+
218
+ elif isinstance(suggestions, dict):
219
+ if is_valid_suggestion(suggestions):
220
+ return [suggestions]
221
+ else:
222
+ st.warning("⚠️ GPT-4o's suggestion is incomplete or invalid.")
223
+ return None
224
+
225
+ except json.JSONDecodeError:
226
+ st.warning(f"⚠️ Attempt {attempt + 1}: GPT-4o returned invalid JSON.")
227
+ except Exception as e:
228
+ st.error(f"⚠️ Error during GPT-4o call: {e}")
229
+
230
+ if attempt < retries:
231
+ st.info("🔄 Retrying visualization suggestion...")
232
+
233
+ st.error("❌ Failed to generate a valid visualization after multiple attempts.")
234
+ return None
235
+
236
+
237
+ def add_stats_to_figure(fig, df, y_axis, chart_type):
238
+ """
239
+ Add relevant statistical annotations to the visualization
240
+ based on the chart type.
241
+ """
242
+ # Check if the y-axis column is numeric
243
+ if not pd.api.types.is_numeric_dtype(df[y_axis]):
244
+ st.warning(f"⚠️ Cannot compute statistics for non-numeric column: {y_axis}")
245
+ return fig
246
+
247
+ # Compute statistics for numeric data
248
+ min_val = df[y_axis].min()
249
+ max_val = df[y_axis].max()
250
+ avg_val = df[y_axis].mean()
251
+ median_val = df[y_axis].median()
252
+ std_dev_val = df[y_axis].std()
253
+
254
+ # Format the stats for display
255
+ stats_text = (
256
+ f"📊 **Statistics**\n\n"
257
+ f"- **Min:** ${min_val:,.2f}\n"
258
+ f"- **Max:** ${max_val:,.2f}\n"
259
+ f"- **Average:** ${avg_val:,.2f}\n"
260
+ f"- **Median:** ${median_val:,.2f}\n"
261
+ f"- **Std Dev:** ${std_dev_val:,.2f}"
262
+ )
263
+
264
+ # Apply stats only to relevant chart types
265
+ if chart_type in ["bar", "line"]:
266
+ # Add annotation box for bar and line charts
267
+ fig.add_annotation(
268
+ text=stats_text,
269
+ xref="paper", yref="paper",
270
+ x=1.02, y=1,
271
+ showarrow=False,
272
+ align="left",
273
+ font=dict(size=12, color="black"),
274
+ bordercolor="gray",
275
+ borderwidth=1,
276
+ bgcolor="rgba(255, 255, 255, 0.85)"
277
+ )
278
+
279
+ # Add horizontal reference lines
280
+ fig.add_hline(y=min_val, line_dash="dot", line_color="red", annotation_text="Min", annotation_position="bottom right")
281
+ fig.add_hline(y=median_val, line_dash="dash", line_color="orange", annotation_text="Median", annotation_position="top right")
282
+ fig.add_hline(y=avg_val, line_dash="dashdot", line_color="green", annotation_text="Avg", annotation_position="top right")
283
+ fig.add_hline(y=max_val, line_dash="dot", line_color="blue", annotation_text="Max", annotation_position="top right")
284
+
285
+ elif chart_type == "scatter":
286
+ # Add stats annotation only, no lines for scatter plots
287
+ fig.add_annotation(
288
+ text=stats_text,
289
+ xref="paper", yref="paper",
290
+ x=1.02, y=1,
291
+ showarrow=False,
292
+ align="left",
293
+ font=dict(size=12, color="black"),
294
+ bordercolor="gray",
295
+ borderwidth=1,
296
+ bgcolor="rgba(255, 255, 255, 0.85)"
297
+ )
298
+
299
+ elif chart_type == "box":
300
+ # Box plots inherently show distribution; no extra stats needed
301
+ pass
302
+
303
+ elif chart_type == "pie":
304
+ # Pie charts represent proportions, not suitable for stats
305
+ st.info("📊 Pie charts represent proportions. Additional stats are not applicable.")
306
+
307
+ elif chart_type == "heatmap":
308
+ # Heatmaps already reflect data intensity
309
+ st.info("📊 Heatmaps inherently reflect distribution. No additional stats added.")
310
+
311
+ else:
312
+ st.warning(f"⚠️ No statistical overlays applied for unsupported chart type: '{chart_type}'.")
313
+
314
+ return fig
315
+
316
+
317
+ # Dynamically generate Plotly visualizations based on GPT-4o suggestions
318
+ def generate_visualization(suggestion, df):
319
+ """
320
+ Generate a Plotly visualization based on GPT-4o's suggestion.
321
+ If the Y-axis is missing, infer it intelligently.
322
+ """
323
+ chart_type = suggestion.get("chart_type", "bar").lower()
324
+ x_axis = suggestion.get("x_axis")
325
+ y_axis = suggestion.get("y_axis")
326
+ group_by = suggestion.get("group_by")
327
+
328
+ # Step 1: Infer Y-axis if not provided
329
+ if not y_axis:
330
+ numeric_columns = df.select_dtypes(include='number').columns.tolist()
331
+
332
+ # Avoid using the same column for both axes
333
+ if x_axis in numeric_columns:
334
+ numeric_columns.remove(x_axis)
335
+
336
+ # Smart guess: prioritize salary or relevant metrics if available
337
+ priority_columns = ["salary_in_usd", "income", "earnings", "revenue"]
338
+ for col in priority_columns:
339
+ if col in numeric_columns:
340
+ y_axis = col
341
+ break
342
+
343
+ # Fallback to the first numeric column if no priority columns exist
344
+ if not y_axis and numeric_columns:
345
+ y_axis = numeric_columns[0]
346
+
347
+ # Step 2: Validate axes
348
+ if not x_axis or not y_axis:
349
+ st.warning("⚠️ Unable to determine appropriate columns for visualization.")
350
+ return None
351
+
352
+ # Step 3: Dynamically select the Plotly function
353
+ plotly_function = getattr(px, chart_type, None)
354
+ if not plotly_function:
355
+ st.warning(f"⚠️ Unsupported chart type '{chart_type}' suggested by GPT-4o.")
356
+ return None
357
+
358
+ # Step 4: Prepare dynamic plot arguments
359
+ plot_args = {"data_frame": df, "x": x_axis, "y": y_axis}
360
+ if group_by and group_by in df.columns:
361
+ plot_args["color"] = group_by
362
+
363
+ try:
364
+ # Step 5: Generate the visualization
365
+ fig = plotly_function(**plot_args)
366
+ fig.update_layout(
367
+ title=f"{chart_type.title()} Plot of {y_axis.replace('_', ' ').title()} by {x_axis.replace('_', ' ').title()}",
368
+ xaxis_title=x_axis.replace('_', ' ').title(),
369
+ yaxis_title=y_axis.replace('_', ' ').title(),
370
+ )
371
+
372
+ # Step 6: Apply statistics intelligently
373
+ fig = add_statistics_to_visualization(fig, df, y_axis, chart_type)
374
+
375
+ return fig
376
+
377
+ except Exception as e:
378
+ st.error(f"⚠️ Failed to generate visualization: {e}")
379
+ return None
380
+
381
+
382
+ def generate_multiple_visualizations(suggestions, df):
383
+ """
384
+ Generates one or more visualizations based on GPT-4o's suggestions.
385
+ Handles both single and multiple suggestions.
386
+ """
387
+ visualizations = []
388
+
389
+ for suggestion in suggestions:
390
+ fig = generate_visualization(suggestion, df)
391
+ if fig:
392
+ # Apply chart-specific statistics
393
+ fig = add_stats_to_figure(fig, df, suggestion["y_axis"], suggestion["chart_type"])
394
+ visualizations.append(fig)
395
+
396
+ if not visualizations and suggestions:
397
+ st.warning("⚠️ No valid visualization found. Displaying the most relevant one.")
398
+ best_suggestion = suggestions[0]
399
+ fig = generate_visualization(best_suggestion, df)
400
+ fig = add_stats_to_figure(fig, df, best_suggestion["y_axis"], best_suggestion["chart_type"])
401
+ visualizations.append(fig)
402
+
403
+ return visualizations
404
+
405
+
406
+ def handle_visualization_suggestions(suggestions, df):
407
+ """
408
+ Determines whether to generate a single or multiple visualizations.
409
+ """
410
+ visualizations = []
411
+
412
+ # If multiple suggestions, generate multiple plots
413
+ if isinstance(suggestions, list) and len(suggestions) > 1:
414
+ visualizations = generate_multiple_visualizations(suggestions, df)
415
+
416
+ # If only one suggestion, generate a single plot
417
+ elif isinstance(suggestions, dict) or (isinstance(suggestions, list) and len(suggestions) == 1):
418
+ suggestion = suggestions[0] if isinstance(suggestions, list) else suggestions
419
+ fig = generate_visualization(suggestion, df)
420
+ if fig:
421
+ visualizations.append(fig)
422
+
423
+ # Handle cases when no visualization could be generated
424
+ if not visualizations:
425
+ st.warning("⚠️ Unable to generate any visualization based on the suggestion.")
426
+
427
+ # Display all generated visualizations
428
+ for fig in visualizations:
429
+ st.plotly_chart(fig, use_container_width=True)
430
+
431
 
432
  def escape_markdown(text):
433
  # Ensure text is a string
 
561
  st.markdown(report_result if report_result else "⚠️ No Report Generated.")
562
 
563
  # Step 4: Generate Visualizations
 
 
 
 
 
564
 
 
 
 
 
 
 
 
 
 
 
565
 
566
  # Step 5: Insert Visual Insights
567
  st.markdown("### Visual Insights")
568
+
 
569
 
570
  # Step 6: Display Concise Conclusion
571
  #st.markdown("#### Conclusion")
 
573
  safe_conclusion = escape_markdown(conclusion_result if conclusion_result else "⚠️ No Conclusion Generated.")
574
  st.markdown(safe_conclusion)
575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
 
577
  # Sidebar Reference
578
  with st.sidebar:
579
  st.header("📚 Reference:")
580
  st.markdown("[SQL Agents w CrewAI & Llama 3 - Plaban Nayak](https://github.com/plaban1981/Agents/blob/main/SQL_Agents_with_CrewAI_and_Llama_3.ipynb)")