jjkim
committed on
Commit
·
2128ba2
1
Parent(s):
f35f0d4
refactor & fix order bug & add early stop option
Browse files- code_eval.py +71 -58
- requirements.txt +2 -1
code_eval.py
CHANGED
@@ -20,11 +20,14 @@ import itertools
|
|
20 |
import os
|
21 |
from collections import Counter, defaultdict
|
22 |
from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
|
|
|
|
|
23 |
|
24 |
import datasets
|
25 |
import evaluate
|
26 |
import numpy as np
|
27 |
from tqdm import tqdm
|
|
|
28 |
|
29 |
from .execute import check_correctness
|
30 |
|
@@ -155,9 +158,11 @@ class CodeEval(evaluate.Metric):
|
|
155 |
self,
|
156 |
predictions,
|
157 |
references,
|
|
|
158 |
k=[1, 10, 100],
|
159 |
num_workers=4,
|
160 |
timeout=3.0,
|
|
|
161 |
):
|
162 |
"""Returns the scores"""
|
163 |
|
@@ -169,69 +174,43 @@ class CodeEval(evaluate.Metric):
|
|
169 |
"This metric is currently not supported on Windows."
|
170 |
)
|
171 |
|
|
|
|
|
172 |
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
for _test_case in test_case:
|
183 |
-
assert isinstance(_test_case, str)
|
184 |
-
test_program = candidate + "\n" + _test_case
|
185 |
-
args = (
|
186 |
-
test_program,
|
187 |
-
timeout,
|
188 |
-
task_id,
|
189 |
-
completion_id[task_id],
|
190 |
-
)
|
191 |
future = executor.submit(check_correctness, *args)
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
for
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
pbar.update(
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
for key, result in results.items():
|
214 |
-
new_result = []
|
215 |
-
result.sort(key=lambda x: x[0])
|
216 |
-
for completion_id, group in itertools.groupby(result, key=lambda x: x[0]):
|
217 |
-
group = list(group)
|
218 |
-
new_result.append(
|
219 |
-
(
|
220 |
-
completion_id,
|
221 |
-
dict(
|
222 |
-
task_id=key,
|
223 |
-
passed=all(r[1]["passed"] for r in group),
|
224 |
-
result=[r[1]["result"] for r in group],
|
225 |
-
completion_id=completion_id,
|
226 |
-
),
|
227 |
-
)
|
228 |
-
)
|
229 |
-
new_results[key] = new_result
|
230 |
-
results = new_results
|
231 |
|
232 |
total, correct = [], []
|
233 |
for result in results.values():
|
234 |
-
result.sort(key=lambda x: x[0])
|
235 |
passed = [r[1]["passed"] for r in result]
|
236 |
total.append(len(passed))
|
237 |
correct.append(sum(passed))
|
@@ -266,3 +245,37 @@ def estimate_pass_at_k(num_samples, num_correct, k):
|
|
266 |
return np.array(
|
267 |
[estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
|
268 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
import os
|
21 |
from collections import Counter, defaultdict
|
22 |
from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
|
23 |
+
from typing import List, Optional
|
24 |
+
import time
|
25 |
|
26 |
import datasets
|
27 |
import evaluate
|
28 |
import numpy as np
|
29 |
from tqdm import tqdm
|
30 |
+
from pydantic import BaseModel
|
31 |
|
32 |
from .execute import check_correctness
|
33 |
|
|
|
158 |
self,
|
159 |
predictions,
|
160 |
references,
|
161 |
+
task_ids=None,
|
162 |
k=[1, 10, 100],
|
163 |
num_workers=4,
|
164 |
timeout=3.0,
|
165 |
+
early_stop=False,
|
166 |
):
|
167 |
"""Returns the scores"""
|
168 |
|
|
|
174 |
"This metric is currently not supported on Windows."
|
175 |
)
|
176 |
|
177 |
+
task_ids = task_ids or list(range(len(predictions)))
|
178 |
+
|
179 |
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
180 |
+
results = {}
|
181 |
+
for tid, pred, ref in zip(task_ids, predictions, references):
|
182 |
+
results[tid] = []
|
183 |
+
for candidate in pred:
|
184 |
+
result = Result(task_id=tid, completion_id=len(results))
|
185 |
+
for test_case in ref:
|
186 |
+
assert isinstance(test_case, str)
|
187 |
+
test_program = candidate + "\n" + test_case
|
188 |
+
args = (test_program, timeout, tid)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
future = executor.submit(check_correctness, *args)
|
190 |
+
result.add(future)
|
191 |
+
results[tid].append(result)
|
192 |
+
|
193 |
+
pbar = tqdm(total=len(results))
|
194 |
+
prev_done_count = 0
|
195 |
+
while not all(r.done() for r in results.values()):
|
196 |
+
cur_done_count = 0
|
197 |
+
for result in results.values():
|
198 |
+
for r in result:
|
199 |
+
if not r.done():
|
200 |
+
r.refresh(early_stop)
|
201 |
+
else:
|
202 |
+
cur_done_count += 1
|
203 |
+
pbar.update(cur_done_count - prev_done_count)
|
204 |
+
prev_done_count = cur_done_count
|
205 |
+
time.sleep(1)
|
206 |
+
|
207 |
+
results = {
|
208 |
+
task_id: [(r.completion_id, r.dict(exclude={"futures"})) for r in result]
|
209 |
+
for task_id, result in results.items()
|
210 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
total, correct = [], []
|
213 |
for result in results.values():
|
|
|
214 |
passed = [r[1]["passed"] for r in result]
|
215 |
total.append(len(passed))
|
216 |
correct.append(sum(passed))
|
|
|
245 |
return np.array(
|
246 |
[estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
|
247 |
)
|
248 |
+
|
249 |
+
|
250 |
+
def _outcome_passed(outcome):
    """Return True only when a test-case outcome is a result dict marked passed.

    Outcomes stored by ``Result.refresh`` are either the value returned by the
    submitted callable (read elsewhere in this file as a dict with a
    ``"passed"`` key) or a plain string ("Early Stopped" / an exception
    message). Strings always count as a failure.
    """
    return isinstance(outcome, dict) and bool(outcome.get("passed"))


class Result(BaseModel):
    """Aggregated outcome of one candidate completion over all its test cases.

    One future is registered per test case through ``add``; ``refresh`` polls
    the futures and, once every test case has an outcome, sets ``passed``.
    ``done`` reports whether that final verdict is available.
    """

    task_id: int
    completion_id: int

    # None until every test case has an outcome, then the overall verdict.
    passed: Optional[bool] = None
    # One slot per test case: None while pending, afterwards either the value
    # returned by the future or a string describing why no value is available.
    # NOTE: pydantic copies these list defaults per instance, so they are not
    # shared between Result objects.
    result: List[object] = []
    # Futures in the same order as ``result``; excluded by the caller on export.
    futures: List[object] = []

    def add(self, future):
        """Register the future of one test case and reserve its result slot."""
        self.futures.append(future)
        self.result.append(None)

    def refresh(self, early_stop=False):
        """Collect outcomes of finished futures and update ``passed``.

        Args:
            early_stop: when true, the first failing (or errored) test case
                cancels every other future of this completion that has not
                started yet, since the completion can no longer pass.
        """
        for i, future in enumerate(self.futures):
            if self.result[i] is not None or not future.done():
                continue

            try:
                self.result[i] = future.result()
            except CancelledError:
                self.result[i] = "Early Stopped"
            except Exception as e:
                self.result[i] = str(e)

            # Only a failure makes the remaining tests pointless; cancelling
            # after a *passing* test would wrongly fail a good completion.
            if early_stop and not _outcome_passed(self.result[i]):
                for other in self.futures:
                    if other is not future:
                        # No-op (returns False) for running/finished futures.
                        other.cancel()

        if all(r is not None for r in self.result):
            # String outcomes (cancelled / raised) count as not passed instead
            # of crashing on ``r["passed"]``.
            self.passed = all(_outcome_passed(r) for r in self.result)

    def done(self):
        """Return True once the overall verdict has been computed."""
        return self.passed is not None
|
requirements.txt
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
|
|
|
|
1 |
+
pydantic
|
2 |
+
numpy
|