meg-huggingface
committed on
Commit
•
86102e5
1
Parent(s):
f20cab2
Full dataset
Browse files
- .gitignore +6 -6
- main_backend_toxicity.py +2 -2
- src/backend/run_toxicity_eval.py +13 -13
.gitignore
CHANGED
@@ -7,9 +7,9 @@ __pycache__/
|
|
7 |
.vscode/
|
8 |
.idea/
|
9 |
|
10 |
-
eval-queue/
|
11 |
-
eval-results/
|
12 |
-
eval-queue-bk/
|
13 |
-
eval-results-bk/
|
14 |
-
logs/
|
15 |
-
output.log
|
|
|
7 |
.vscode/
|
8 |
.idea/
|
9 |
|
10 |
+
#eval-queue/
|
11 |
+
#eval-results/
|
12 |
+
#eval-queue-bk/
|
13 |
+
#eval-results-bk/
|
14 |
+
#logs/
|
15 |
+
#output.log
|
main_backend_toxicity.py
CHANGED
@@ -69,7 +69,7 @@ def run_auto_eval():
|
|
69 |
logger.info(f'Starting Evaluation of {eval_request.json_filepath} on Inference endpoints')
|
70 |
|
71 |
model_repository = eval_request.model
|
72 |
-
endpoint_name = re.sub("/", "-", model_repository.lower()) + "-toxicity-eval"
|
73 |
endpoint_url = create_endpoint(endpoint_name, model_repository)
|
74 |
logger.info("Created an endpoint url at %s" % endpoint_url)
|
75 |
results = main(endpoint_url, model_repository)
|
@@ -78,7 +78,7 @@ def run_auto_eval():
|
|
78 |
#local_dir = EVAL_RESULTS_PATH_BACKEND,
|
79 |
#limit=LIMIT
|
80 |
# )
|
81 |
-
|
82 |
|
83 |
|
84 |
if __name__ == "__main__":
|
|
|
69 |
logger.info(f'Starting Evaluation of {eval_request.json_filepath} on Inference endpoints')
|
70 |
|
71 |
model_repository = eval_request.model
|
72 |
+
endpoint_name = re.sub("/", "-", model_repository.lower()) + "-toxicity-eval"
|
73 |
endpoint_url = create_endpoint(endpoint_name, model_repository)
|
74 |
logger.info("Created an endpoint url at %s" % endpoint_url)
|
75 |
results = main(endpoint_url, model_repository)
|
|
|
78 |
#local_dir = EVAL_RESULTS_PATH_BACKEND,
|
79 |
#limit=LIMIT
|
80 |
# )
|
81 |
+
logger.info(f'Completed Evaluation of {eval_request.json_filepath}')
|
82 |
|
83 |
|
84 |
if __name__ == "__main__":
|
src/backend/run_toxicity_eval.py
CHANGED
@@ -50,7 +50,7 @@ def get_generation(endpoint_url, pred):
|
|
50 |
})
|
51 |
retries = 0
|
52 |
while output == {'error': '503 Service Unavailable'}:
|
53 |
-
|
54 |
if retries >= 10:
|
55 |
print("Did not work after 10 tries. Giving up.")
|
56 |
sys.exit()
|
@@ -152,17 +152,17 @@ def score_per_generation(endpoint_url, prompts, total_retries=5):
|
|
152 |
return req_att_scores
|
153 |
|
154 |
def main(endpoint_url, model_repository):
|
155 |
-
|
156 |
ds = load_dataset("allenai/real-toxicity-prompts")
|
157 |
-
prompts = [row['text'] for row in ds['train']['prompt']
|
158 |
att_scores_out = score_per_generation(endpoint_url, prompts)
|
159 |
-
|
160 |
-
|
161 |
average_att_scores = {}
|
162 |
for req_att in att_scores_out:
|
163 |
average_att_scores[req_att.lower()] = mean(att_scores_out[req_att])
|
164 |
-
|
165 |
-
|
166 |
|
167 |
results = {"results":{"realtoxicityprompts":{}}}
|
168 |
for att, score in average_att_scores.items():
|
@@ -177,13 +177,13 @@ def main(endpoint_url, model_repository):
|
|
177 |
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
178 |
with open(output_path, "w") as f:
|
179 |
f.write(dumped)
|
|
|
|
|
180 |
|
181 |
-
logger.
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
print("repo id")
|
186 |
-
print(RESULTS_REPO)
|
187 |
|
188 |
API.upload_file(
|
189 |
path_or_fileobj=output_path,
|
|
|
50 |
})
|
51 |
retries = 0
|
52 |
while output == {'error': '503 Service Unavailable'}:
|
53 |
+
logger.warning("Service unavailable.")
|
54 |
if retries >= 10:
|
55 |
print("Did not work after 10 tries. Giving up.")
|
56 |
sys.exit()
|
|
|
152 |
return req_att_scores
|
153 |
|
154 |
def main(endpoint_url, model_repository):
|
155 |
+
logger.info("Loading dataset")
|
156 |
ds = load_dataset("allenai/real-toxicity-prompts")
|
157 |
+
prompts = [row['text'] for row in ds['train']['prompt']]
|
158 |
att_scores_out = score_per_generation(endpoint_url, prompts)
|
159 |
+
logger.debug("Scores are:")
|
160 |
+
logger.debug(att_scores_out)
|
161 |
average_att_scores = {}
|
162 |
for req_att in att_scores_out:
|
163 |
average_att_scores[req_att.lower()] = mean(att_scores_out[req_att])
|
164 |
+
logger.debug("Final scores are:")
|
165 |
+
logger.debug(average_att_scores)
|
166 |
|
167 |
results = {"results":{"realtoxicityprompts":{}}}
|
168 |
for att, score in average_att_scores.items():
|
|
|
177 |
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
178 |
with open(output_path, "w") as f:
|
179 |
f.write(dumped)
|
180 |
+
logger.debug("Results:")
|
181 |
+
logger.debug(results)
|
182 |
|
183 |
+
logger.debug("Uploading to")
|
184 |
+
logger.debug(output_path)
|
185 |
+
logger.debug("repo id")
|
186 |
+
logger.debug(RESULTS_REPO)
|
|
|
|
|
187 |
|
188 |
API.upload_file(
|
189 |
path_or_fileobj=output_path,
|