jjkim committed on
Commit
2128ba2
·
1 Parent(s): f35f0d4

refactor & fix order bug & add early stop option

Browse files
Files changed (2) hide show
  1. code_eval.py +71 -58
  2. requirements.txt +2 -1
code_eval.py CHANGED
@@ -20,11 +20,14 @@ import itertools
20
  import os
21
  from collections import Counter, defaultdict
22
  from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
 
 
23
 
24
  import datasets
25
  import evaluate
26
  import numpy as np
27
  from tqdm import tqdm
 
28
 
29
  from .execute import check_correctness
30
 
@@ -155,9 +158,11 @@ class CodeEval(evaluate.Metric):
155
  self,
156
  predictions,
157
  references,
 
158
  k=[1, 10, 100],
159
  num_workers=4,
160
  timeout=3.0,
 
161
  ):
162
  """Returns the scores"""
163
 
@@ -169,69 +174,43 @@ class CodeEval(evaluate.Metric):
169
  "This metric is currently not supported on Windows."
170
  )
171
 
 
 
172
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
173
- futures = []
174
- future_dict = defaultdict(lambda: defaultdict(list))
175
- completion_id = Counter()
176
- results = defaultdict(list)
177
-
178
- for task_id, (candidates, test_case) in enumerate(
179
- zip(predictions, references)
180
- ):
181
- for candidate in candidates:
182
- for _test_case in test_case:
183
- assert isinstance(_test_case, str)
184
- test_program = candidate + "\n" + _test_case
185
- args = (
186
- test_program,
187
- timeout,
188
- task_id,
189
- completion_id[task_id],
190
- )
191
  future = executor.submit(check_correctness, *args)
192
- futures.append(future)
193
- future_dict[task_id][completion_id[task_id]].append(future)
194
- completion_id[task_id] += 1
195
-
196
- pbar = tqdm(total=len(futures))
197
- for future in as_completed(futures):
198
- try:
199
- result = future.result()
200
- except CancelledError:
201
- pbar.update(1)
202
- continue
203
-
204
- results[result["task_id"]].append((result["completion_id"], result))
205
- pbar.update(1)
206
-
207
- if not result["passed"]:
208
- future_list = future_dict[result["task_id"]][result["completion_id"]]
209
- for future in future_list:
210
- future.cancel()
211
-
212
- new_results = {}
213
- for key, result in results.items():
214
- new_result = []
215
- result.sort(key=lambda x: x[0])
216
- for completion_id, group in itertools.groupby(result, key=lambda x: x[0]):
217
- group = list(group)
218
- new_result.append(
219
- (
220
- completion_id,
221
- dict(
222
- task_id=key,
223
- passed=all(r[1]["passed"] for r in group),
224
- result=[r[1]["result"] for r in group],
225
- completion_id=completion_id,
226
- ),
227
- )
228
- )
229
- new_results[key] = new_result
230
- results = new_results
231
 
232
  total, correct = [], []
233
  for result in results.values():
234
- result.sort(key=lambda x: x[0])
235
  passed = [r[1]["passed"] for r in result]
236
  total.append(len(passed))
237
  correct.append(sum(passed))
@@ -266,3 +245,37 @@ def estimate_pass_at_k(num_samples, num_correct, k):
266
  return np.array(
267
  [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
268
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  import os
21
  from collections import Counter, defaultdict
22
  from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
23
+ from typing import List, Optional
24
+ import time
25
 
26
  import datasets
27
  import evaluate
28
  import numpy as np
29
  from tqdm import tqdm
30
+ from pydantic import BaseModel
31
 
32
  from .execute import check_correctness
33
 
 
158
  self,
159
  predictions,
160
  references,
161
+ task_ids=None,
162
  k=[1, 10, 100],
163
  num_workers=4,
164
  timeout=3.0,
165
+ early_stop=False,
166
  ):
167
  """Returns the scores"""
168
 
 
174
  "This metric is currently not supported on Windows."
175
  )
176
 
177
+ task_ids = task_ids or list(range(len(predictions)))
178
+
179
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
180
+ results = {}
181
+ for tid, pred, ref in zip(task_ids, predictions, references):
182
+ results[tid] = []
183
+ for candidate in pred:
184
+ result = Result(task_id=tid, completion_id=len(results))
185
+ for test_case in ref:
186
+ assert isinstance(test_case, str)
187
+ test_program = candidate + "\n" + test_case
188
+ args = (test_program, timeout, tid)
 
 
 
 
 
 
 
 
 
189
  future = executor.submit(check_correctness, *args)
190
+ result.add(future)
191
+ results[tid].append(result)
192
+
193
+ pbar = tqdm(total=len(results))
194
+ prev_done_count = 0
195
+ while not all(r.done() for r in results.values()):
196
+ cur_done_count = 0
197
+ for result in results.values():
198
+ for r in result:
199
+ if not r.done():
200
+ r.refresh(early_stop)
201
+ else:
202
+ cur_done_count += 1
203
+ pbar.update(cur_done_count - prev_done_count)
204
+ prev_done_count = cur_done_count
205
+ time.sleep(1)
206
+
207
+ results = {
208
+ task_id: [(r.completion_id, r.dict(exclude={"futures"})) for r in result]
209
+ for task_id, result in results.items()
210
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  total, correct = [], []
213
  for result in results.values():
 
214
  passed = [r[1]["passed"] for r in result]
215
  total.append(len(passed))
216
  correct.append(sum(passed))
 
245
  return np.array(
246
  [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
247
  )
248
+
249
+
250
class Result(BaseModel):
    """Aggregated execution outcome of one candidate completion.

    One future (one sandboxed ``check_correctness`` run) is tracked per test
    case.  ``refresh`` polls the futures, records finished ones, and — once
    every slot is filled — derives the overall ``passed`` verdict.

    Attributes:
        task_id: Index of the problem this candidate belongs to.
        completion_id: Index of this candidate within its problem.
        passed: ``None`` while test cases are still running; afterwards
            ``True`` iff every test case passed.
        result: One slot per test case; ``None`` (pending), the result dict
            returned by ``check_correctness``, or an error string
            (``"Early Stopped"`` / exception message) on cancellation/crash.
        futures: The pending/completed futures, parallel to ``result``.
    """

    task_id: int
    completion_id: int

    passed: Optional[bool] = None
    # NOTE(review): entries are None, dicts, or error strings — the original
    # ``List[str]`` annotation was wrong (pydantic only validates defaults,
    # so it never failed, but it misled readers).
    result: List[object] = []
    futures: List[object] = []

    def add(self, future):
        """Register one test-case future and reserve its result slot."""
        self.futures.append(future)
        self.result.append(None)

    def refresh(self, early_stop=False):
        """Poll pending futures and update ``passed`` when all are done.

        Args:
            early_stop: When True, cancel the remaining test cases of this
                candidate as soon as one test case fails — a failed candidate
                cannot pass overall, so the rest of its runs are wasted work.
        """
        for i, future in enumerate(self.futures):
            if self.result[i] is None and future.done():
                try:
                    self.result[i] = future.result()
                except CancelledError:
                    self.result[i] = "Early Stopped"
                except Exception as e:
                    self.result[i] = str(e)

                # Bug fix: only short-circuit when this test case actually
                # FAILED.  The previous version cancelled the siblings on
                # *every* completion (even a pass), which threw away the
                # still-pending runs of perfectly good candidates.
                test_failed = not (
                    isinstance(self.result[i], dict)
                    and self.result[i].get("passed")
                )
                if early_stop and test_failed:
                    for pending in self.futures[i + 1 :]:
                        pending.cancel()

        if all(r is not None for r in self.result):
            # Bug fix: string entries (cancelled / crashed runs) previously
            # raised ``TypeError`` under ``r["passed"]``; they now simply
            # count as failures.
            self.passed = all(
                isinstance(r, dict) and bool(r.get("passed"))
                for r in self.result
            )

    def done(self):
        """Return True once every test case has a recorded outcome."""
        return self.passed is not None
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@af3c30561d840b83e54fc5f7150ea58046d6af69
 
 
1
+ pydantic
2
+ numpy