jjkim commited on
Commit
cb0919a
·
1 Parent(s): fe7364e

add early termination

Browse files
Files changed (1) hide show
  1. code_eval.py +27 -10
code_eval.py CHANGED
@@ -19,7 +19,7 @@ described in the paper "Evaluating Large Language Models Trained on Code"
19
  import itertools
20
  import os
21
  from collections import Counter, defaultdict
22
- from concurrent.futures import ThreadPoolExecutor, as_completed
23
 
24
  import datasets
25
  import evaluate
@@ -171,6 +171,7 @@ class CodeEval(evaluate.Metric):
171
 
172
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
173
  futures = []
 
174
  completion_id = Counter()
175
  results = defaultdict(list)
176
 
@@ -189,31 +190,47 @@ class CodeEval(evaluate.Metric):
189
  )
190
  future = executor.submit(check_correctness, *args)
191
  futures.append(future)
 
192
  completion_id[task_id] += 1
193
 
194
  pbar = tqdm(total=len(futures))
195
  for future in as_completed(futures):
196
- result = future.result()
 
 
 
 
 
197
  results[result["task_id"]].append((result["completion_id"], result))
198
  pbar.update(1)
199
 
200
- for result in results.values():
 
 
 
 
 
 
201
  new_result = []
202
  for completion_id, group in itertools.groupby(result, key=lambda x: x[0]):
203
  group = list(group)
204
  new_result.append(
205
- dict(
206
- task_id=group[0][0],
207
- passed=all(r[1]["passed"] for r in group),
208
- result=[r[1]["result"] for r in group],
209
- completion_id=completion_id,
 
 
 
210
  )
211
  )
212
- result = new_result
 
213
 
214
  total, correct = [], []
215
  for result in results.values():
216
- result.sort()
217
  passed = [r[1]["passed"] for r in result]
218
  total.append(len(passed))
219
  correct.append(sum(passed))
 
19
  import itertools
20
  import os
21
  from collections import Counter, defaultdict
22
+ from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
23
 
24
  import datasets
25
  import evaluate
 
171
 
172
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
173
  futures = []
174
+ future_dict = defaultdict(lambda: defaultdict(list))
175
  completion_id = Counter()
176
  results = defaultdict(list)
177
 
 
190
  )
191
  future = executor.submit(check_correctness, *args)
192
  futures.append(future)
193
+ future_dict[task_id][completion_id[task_id]].append(future)
194
  completion_id[task_id] += 1
195
 
196
  pbar = tqdm(total=len(futures))
197
  for future in as_completed(futures):
198
+ try:
199
+ result = future.result()
200
+ except CancelledError:
201
+ pbar.update(1)
202
+ continue
203
+
204
  results[result["task_id"]].append((result["completion_id"], result))
205
  pbar.update(1)
206
 
207
+ if not result["passed"]:
208
+ future_list = future_dict[result["task_id"]][result["completion_id"]]
209
+ for future in future_list:
210
+ future.cancel()
211
+
212
+ new_results = {}
213
+ for key, result in results.items():
214
  new_result = []
215
  for completion_id, group in itertools.groupby(result, key=lambda x: x[0]):
216
  group = list(group)
217
  new_result.append(
218
+ (
219
+ group[0][0],
220
+ dict(
221
+ task_id=group[0][0],
222
+ passed=all(r[1]["passed"] for r in group),
223
+ result=[r[1]["result"] for r in group],
224
+ completion_id=completion_id,
225
+ ),
226
  )
227
  )
228
+ new_results[key] = new_result
229
+ results = new_results
230
 
231
  total, correct = [], []
232
  for result in results.values():
233
+ result.sort(key=lambda x: x[0])
234
  passed = [r[1]["passed"] for r in result]
235
  total.append(len(passed))
236
  correct.append(sum(passed))