Zekun Wu commited on
Commit
b1f1bc4
·
1 Parent(s): 4df1f36
pages/1_Generation_Demo.py CHANGED
@@ -37,6 +37,14 @@ if st.sidebar.button("Submit Model Info"):
37
  st.session_state.model_submitted = True
38
 
39
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  # Ensure experiment settings are only shown if model info is submitted
@@ -54,6 +62,20 @@ if st.session_state.model_submitted:
54
 
55
  st.write('Data:', df)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
58
  st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
59
  st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
 
37
  st.session_state.model_submitted = True
38
 
39
 
40
+ def add_row(df):
41
+ # Add a new row with default or empty values at the end of the DataFrame
42
+ new_row = {col: "" for col in df.columns}
43
+ return df.append(new_row, ignore_index=True)
44
+
45
+ def remove_row(df, index):
46
+ # Remove a row based on the index provided
47
+ return df.drop(index, errors='ignore').reset_index(drop=True)
48
 
49
 
50
  # Ensure experiment settings are only shown if model info is submitted
 
62
 
63
  st.write('Data:', df)
64
 
65
+ # Button to add a new row
66
+ if st.button('Add Row'):
67
+ df = add_row(df)
68
+ st.session_state.uploaded_file = StringIO(
69
+ df.to_csv(index=False)) # Update the session file after modification
70
+
71
+ # Input for row index to remove
72
+ row_to_remove = st.number_input('Enter row index to remove', min_value=0, max_value=len(df) - 1, step=1,
73
+ format='%d')
74
+ if st.button('Remove Row'):
75
+ df = remove_row(df, row_to_remove)
76
+ st.session_state.uploaded_file = StringIO(
77
+ df.to_csv(index=False)) # Update the session file after modification
78
+
79
  st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
80
  st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
81
  st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
pages/2_Evaluation_Demo.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from io import StringIO
4
+ from util.generation import process_scores
5
+ from util.model import AzureAgent, GPTAgent
6
+ from util.analysis import statistical_tests, result_evaluation
7
+
8
+ # Set up the Streamlit interface
9
+ st.title('JobFair: A Benchmark for Fairness in LLM Employment Decision')
10
+ st.sidebar.title('Model Settings')
11
+
12
+ # Define a function to manage state initialization
13
+ def initialize_state():
14
+ keys = ["model_submitted", "api_key", "endpoint_url", "deployment_name", "temperature", "max_tokens",
15
+ "data_processed", "group_name","occupation", "privilege_label", "protect_label", "num_run", "uploaded_file"]
16
+ defaults = [False, "", "https://safeguard-monitor.openai.azure.com/", "gpt35-1106", 0.5, 150, False,"Gender", "Programmer", "Male", "Female", 1, None]
17
+ for key, default in zip(keys, defaults):
18
+ if key not in st.session_state:
19
+ st.session_state[key] = default
20
+
21
+ initialize_state()
22
+
23
+ # Model selection and configuration
24
+ model_type = st.sidebar.radio("Select the type of agent", ('GPTAgent', 'AzureAgent'))
25
+ st.session_state.api_key = st.sidebar.text_input("API Key", type="password", value=st.session_state.api_key)
26
+ st.session_state.endpoint_url = st.sidebar.text_input("Endpoint URL", value=st.session_state.endpoint_url)
27
+ st.session_state.deployment_name = st.sidebar.text_input("Model Name", value=st.session_state.deployment_name)
28
+ api_version = '2024-02-15-preview' if model_type == 'GPTAgent' else ''
29
+ st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, st.session_state.temperature, 0.01)
30
+ st.session_state.max_tokens = st.sidebar.number_input("Max Tokens", 1, 1000, st.session_state.max_tokens)
31
+
32
+ if st.sidebar.button("Reset Model Info"):
33
+ initialize_state() # Reset all state to defaults
34
+ st.experimental_rerun()
35
+
36
+ if st.sidebar.button("Submit Model Info"):
37
+ st.session_state.model_submitted = True
38
+
39
+
40
+
41
+
42
+ # Ensure experiment settings are only shown if model info is submitted
43
+ if st.session_state.model_submitted:
44
+ df = None
45
+ file_options = st.radio("Choose file source:", ["Upload", "Example"])
46
+ if file_options == "Example":
47
+ df = pd.read_csv("prompt_test.csv")
48
+ else:
49
+ st.session_state.uploaded_file = st.file_uploader("Choose a file")
50
+ if st.session_state.uploaded_file is not None:
51
+ data = StringIO(st.session_state.uploaded_file.getvalue().decode("utf-8"))
52
+ df = pd.read_csv(data)
53
+ if df is not None:
54
+
55
+ st.write('Data:', df)
56
+
57
+ st.session_state.occupation = st.text_input("Occupation", value=st.session_state.occupation)
58
+ st.session_state.group_name = st.text_input("Group Name", value=st.session_state.group_name)
59
+ st.session_state.privilege_label = st.text_input("Privilege Label", value=st.session_state.privilege_label)
60
+ st.session_state.protect_label = st.text_input("Protect Label", value=st.session_state.protect_label)
61
+ st.session_state.num_run = st.number_input("Number of Runs", 1, 10, st.session_state.num_run)
62
+
63
+ if st.button('Process Data') and not st.session_state.data_processed:
64
+ # Initialize the correct agent based on model type
65
+ if model_type == 'AzureAgent':
66
+ agent = AzureAgent(st.session_state.api_key, st.session_state.endpoint_url, st.session_state.deployment_name)
67
+ else:
68
+ agent = GPTAgent(st.session_state.api_key, st.session_state.endpoint_url, st.session_state.deployment_name, api_version)
69
+
70
+ # Process data and display results
71
+ with st.spinner('Processing data...'):
72
+ parameters = {"temperature": st.session_state.temperature, "max_tokens": st.session_state.max_tokens}
73
+ df = process_scores(df, st.session_state.num_run, parameters, st.session_state.privilege_label, st.session_state.protect_label, agent, st.session_state.group_name, st.session_state.occupation)
74
+ st.session_state.data_processed = True # Mark as processed
75
+
76
+ # Add ranks for each score within each row
77
+ ranks = df[['Privilege_Avg_Score', 'Protect_Avg_Score', 'Neutral_Avg_Score']].rank(axis=1,ascending=False)
78
+
79
+ df['Privilege_Rank'] = ranks['Privilege_Avg_Score']
80
+ df['Protect_Rank'] = ranks['Protect_Avg_Score']
81
+ df['Neutral_Rank'] = ranks['Neutral_Avg_Score']
82
+
83
+ st.write('Processed Data:', df)
84
+
85
+ # use the data to generate a plot
86
+ st.write("Plotting the data")
87
+
88
+ test_results = statistical_tests(df)
89
+ print(test_results)
90
+ evaluation_results = result_evaluation(test_results)
91
+ print(evaluation_results)
92
+
93
+ for key, value in evaluation_results.items():
94
+ st.write(f"{key}: {value}")
95
+
96
+
97
+ if st.button("Reset Experiment Settings"):
98
+ st.session_state.occupation = "Programmer"
99
+ st.session_state.group_name = "Gender"
100
+ st.session_state.privilege_label = "Male"
101
+ st.session_state.protect_label = "Female"
102
+ st.session_state.num_run = 1
103
+ st.session_state.data_processed = False
104
+ st.session_state.uploaded_file = None
prompt_test.csv CHANGED
@@ -27,24 +27,4 @@ Age,MainBranch,Gender,EdLevel,YearsCode,YearsCodePro,Country,MentalHealth,Employ
27
  <35,Yes,Other,NoHigherEd,6.0,1.0,Austria,No,0,Bash/Shell;Java;Lua;Python;SQL;TypeScript;VBA;Docker;Homebrew;npm;Angular;Svelte;MariaDB;Oracle,26928.0,14
28
  <35,Yes,Other,Master,12.0,8.0,Russian Federation,No,1,C#;C++;HTML/CSS;PowerShell;TypeScript;Docker;Git;Kubernetes;Angular;ASP.NET Core ;Microsoft Azure;Microsoft SQL Server;Redis,52284.0,13
29
  >35,Yes,Other,Undergraduate,20.0,12.0,Ireland,No,1,C#;HTML/CSS;Java;JavaScript;Python;SQL;TypeScript;Git;Angular;ASP.NET;ASP.NET Core ;jQuery;Spring;AWS;Microsoft Azure;Microsoft SQL Server,64859.0,16
30
- >35,Yes,Other,Other,25.0,18.0,United States of America,Yes,0,C#;HTML/CSS;JavaScript;PowerShell;SQL;TypeScript;Docker;npm;Unity 3D;Angular;ASP.NET;ASP.NET Core ;Blazor;Express;jQuery;Node.js;React.js;Microsoft Azure;Microsoft SQL Server;Redis,120000.0,20
31
- >35,Yes,Other,Master,12.0,8.0,Spain,No,1,C#;JavaScript;Node.js;PHP;Docker;Git;React.js;Microsoft Azure;MySQL,43239.0,9
32
- <35,Yes,Other,Undergraduate,4.0,0.0,Thailand,No,1,C;Dart;Go;HTML/CSS;JavaScript;SQL;Docker;Homebrew;npm;Yarn;jQuery;Node.js;React.js;Vue.js;Firebase;MongoDB;MySQL;PostgreSQL;Firebase Realtime Database;SQLite,264.0,20
33
- >35,Yes,Other,Master,45.0,41.0,United States of America,No,1,C++;Python;MongoDB,170000.0,3
34
- <35,Yes,Other,Undergraduate,5.0,1.0,China,No,1,Python;Django,14412.0,2
35
- >35,Yes,Other,Undergraduate,43.0,23.0,United States of America,No,0,JavaScript;Perl;SQL;TypeScript;Docker;Angular;Angular.js;jQuery;Node.js;MySQL;SQLite,143000.0,11
36
- >35,Yes,Other,Undergraduate,31.0,16.0,Netherlands,No,0,Bash/Shell;Elixir;Erlang;HTML/CSS;JavaScript;Python;TypeScript;Ansible;Docker;Git;Kubernetes;Terraform;Svelte;AWS;DynamoDB;PostgreSQL;Redis,75669.0,17
37
- <35,Yes,Other,Undergraduate,7.0,5.0,United States of America,Yes,0,Bash/Shell;HTML/CSS;JavaScript;SQL;TypeScript;Docker;Homebrew;Kubernetes;npm;Pulumi;Yarn;Angular;Express;Fastify;Node.js;React.js;Vue.js;AWS;Google Cloud;Elasticsearch;MariaDB;MongoDB;MySQL;Oracle;PostgreSQL;Redis;SQLite,150000.0,27
38
- <35,Yes,Other,Undergraduate,8.0,5.0,United States of America,No,1,Go;Java;Python;Scala;Docker;Homebrew;Kubernetes;Terraform;AWS;Cassandra;DynamoDB;Elasticsearch;MongoDB;MySQL;Redis,175000.0,15
39
- <35,Yes,Other,Master,4.0,3.0,United Kingdom of Great Britain and Northern Ireland,No,1,C++;Python;Git,94359.0,3
40
- <35,Yes,Other,Undergraduate,12.0,2.0,Malaysia,No,1,Bash/Shell;HTML/CSS;JavaScript;Node.js;Python;TypeScript;Deno;Docker;Git;Yarn;Svelte;Vue.js;AWS;Google Cloud Platform;PostgreSQL;Redis;SQLite,14412.0,17
41
- <35,Yes,Other,NoHigherEd,5.0,2.0,Argentina,No,0,HTML/CSS;JavaScript;Node.js;TypeScript;Git;Express;Svelte;Google Cloud Platform;Heroku;Firebase,9732.0,10
42
- >35,Yes,Other,Other,20.0,14.0,United Kingdom of Great Britain and Northern Ireland,No,1,Bash/Shell;HTML/CSS;JavaScript;Python;SQL;TypeScript;Chef;Docker;Git;Kubernetes;Terraform;Django;Flask;React.js;AWS;Google Cloud Platform;Microsoft Azure;PostgreSQL,68507.0,18
43
- <35,Yes,Other,Undergraduate,21.0,8.0,United States of America,Yes,0,HTML/CSS;JavaScript;PHP;SQL;npm;Express;Gatsby;jQuery;Node.js;React.js;MariaDB;MySQL,36000.0,12
44
- <35,No,Other,Undergraduate,14.0,3.0,Poland,Yes,0,JavaScript;Perl;Python,29136.0,3
45
- <35,Yes,Other,Other,7.0,2.0,South Africa,Yes,1,Bash/Shell;HTML/CSS;JavaScript;PHP;SQL;npm;jQuery;Node.js;MySQL,3792.0,9
46
- <35,Yes,Other,Other,7.0,6.0,India,No,0,Elixir;Go;HTML/CSS;JavaScript;Ruby;SQL;Ansible;Docker;Homebrew;Kubernetes;npm;Yarn;Ruby on Rails;AWS;Google Cloud;Elasticsearch;MySQL;PostgreSQL;Redis,77388.0,19
47
- <35,Yes,Other,Master,6.0,4.0,United Kingdom of Great Britain and Northern Ireland,Yes,1,Bash/Shell;C#;C++;HTML/CSS;JavaScript;Python;SQL;TypeScript;Docker;Git;Terraform;ASP.NET;ASP.NET Core ;React.js;Microsoft Azure;Microsoft SQL Server;MongoDB,56874.0,17
48
- <35,Yes,Other,NoHigherEd,10.0,5.0,United Kingdom of Great Britain and Northern Ireland,Yes,1,C#;JavaScript;PowerShell;SQL;ASP.NET;ASP.NET Core ;React.js;Microsoft SQL Server,31787.0,8
49
- >35,Yes,Other,Undergraduate,24.0,22.0,Philippines,No,0,Bash/Shell;C;Go;Java;Node.js;PHP;Python;Ruby;SQL;Ansible;Chef;Docker;Git;Kubernetes;Puppet;Terraform;Angular.js;Django;Flask;Gatsby;Laravel;React.js;Ruby on Rails;Spring;AWS;Google Cloud Platform;Heroku;Oracle Cloud Infrastructure;Cassandra;DynamoDB;Elasticsearch;MariaDB;Microsoft SQL Server;MongoDB;MySQL;Oracle;PostgreSQL;Redis;SQLite,24000.0,39
50
- <35,Yes,Other,Master,4.0,2.0,Germany,No,1,Bash/Shell;Go;Python;SQL;Ansible;Docker;Homebrew;Kubernetes;Pulumi;Terraform;Yarn;AWS;Microsoft Azure;Elasticsearch;Microsoft SQL Server;MySQL;PostgreSQL;SQLite,98112.0,18
 
27
  <35,Yes,Other,NoHigherEd,6.0,1.0,Austria,No,0,Bash/Shell;Java;Lua;Python;SQL;TypeScript;VBA;Docker;Homebrew;npm;Angular;Svelte;MariaDB;Oracle,26928.0,14
28
  <35,Yes,Other,Master,12.0,8.0,Russian Federation,No,1,C#;C++;HTML/CSS;PowerShell;TypeScript;Docker;Git;Kubernetes;Angular;ASP.NET Core ;Microsoft Azure;Microsoft SQL Server;Redis,52284.0,13
29
  >35,Yes,Other,Undergraduate,20.0,12.0,Ireland,No,1,C#;HTML/CSS;Java;JavaScript;Python;SQL;TypeScript;Git;Angular;ASP.NET;ASP.NET Core ;jQuery;Spring;AWS;Microsoft Azure;Microsoft SQL Server,64859.0,16
30
+ >35,Yes,Other,Other,25.0,18.0,United States of America,Yes,0,C#;HTML/CSS;JavaScript;PowerShell;SQL;TypeScript;Docker;npm;Unity 3D;Angular;ASP.NET;ASP.NET Core ;Blazor;Express;jQuery;Node.js;React.js;Microsoft Azure;Microsoft SQL Server;Redis,120000.0,20