Zekun Wu commited on
Commit
09c5f1e
·
1 Parent(s): cb1e0b7
Files changed (1) hide show
  1. util/injection.py +9 -4
util/injection.py CHANGED
@@ -84,6 +84,13 @@ def invoke_retry(prompt, agent, parameters,string_input=False):
84
  raise Exception("Failed to complete the API call after maximum retry attempts.")
85
 
86
 
 
 
 
 
 
 
 
87
  def process_scores_multiple(df, num_run, parameters, privilege_label, protect_label, agent, group_name, occupation):
88
 
89
  print(f"Processing {len(df)} entries with {num_run} runs each.")
@@ -118,7 +125,5 @@ def process_scores_multiple(df, num_run, parameters, privilege_label, protect_la
118
  print(f"Scores: {scores}")
119
 
120
  for category in ['Privilege_characteristics', 'Privilege_normal','Protect_characteristics', 'Protect_normal','Neutral_characteristics', 'Neutral_normal']:
121
- df[f'{category}_Scores'] = pd.Series([lst for lst in scores[category]])
122
- df[f'{category}_Avg_Score'] = df[f'{category}_Scores'].apply(
123
- lambda scores: sum(score for score in scores if score is not None) / len(scores) if scores else None
124
- )
 
84
  raise Exception("Failed to complete the API call after maximum retry attempts.")
85
 
86
 
87
+ def calculate_avg_score(score_list):
88
+ if isinstance(score_list, list) and score_list:
89
+ valid_scores = [score for score in score_list if score is not None]
90
+ if valid_scores:
91
+ avg_score = sum(valid_scores) / len(valid_scores)
92
+ return avg_score
93
+ return None
94
  def process_scores_multiple(df, num_run, parameters, privilege_label, protect_label, agent, group_name, occupation):
95
 
96
  print(f"Processing {len(df)} entries with {num_run} runs each.")
 
125
  print(f"Scores: {scores}")
126
 
127
  for category in ['Privilege_characteristics', 'Privilege_normal','Protect_characteristics', 'Protect_normal','Neutral_characteristics', 'Neutral_normal']:
128
+ df[f'{category}_Scores'] = pd.Series([lst if isinstance(lst, list) else [] for lst in scores[category]])
129
+ df[f'{category}_Avg_Score'] = df[f'{category}_Scores'].apply(calculate_avg_score)