jjkim
commited on
Commit
·
346d7a2
1
Parent(s):
9cbcfb8
change inputs
Browse files- code_eval.py +23 -10
code_eval.py
CHANGED
@@ -22,6 +22,7 @@ from collections import Counter, defaultdict
|
|
22 |
from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
|
23 |
from typing import List, Optional
|
24 |
import time
|
|
|
25 |
|
26 |
import datasets
|
27 |
import evaluate
|
@@ -156,9 +157,12 @@ class CodeEval(evaluate.Metric):
|
|
156 |
|
157 |
def _compute(
|
158 |
self,
|
159 |
-
|
|
|
|
|
160 |
references,
|
161 |
-
|
|
|
162 |
k=[1, 10, 100],
|
163 |
num_workers=4,
|
164 |
timeout=3.0,
|
@@ -174,18 +178,27 @@ class CodeEval(evaluate.Metric):
|
|
174 |
raise NotImplementedError(
|
175 |
"This metric is currently not supported on Windows."
|
176 |
)
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
181 |
results = {}
|
182 |
-
for
|
|
|
|
|
|
|
183 |
results[tid] = []
|
184 |
-
|
|
|
|
|
185 |
result = Result(task_id=tid, completion_id=cid)
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
189 |
args = (test_program, timeout, tid, cid)
|
190 |
future = executor.submit(check_correctness, *args)
|
191 |
result.add(future)
|
|
|
22 |
from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
|
23 |
from typing import List, Optional
|
24 |
import time
|
25 |
+
from string import Template
|
26 |
|
27 |
import datasets
|
28 |
import evaluate
|
|
|
157 |
|
158 |
def _compute(
|
159 |
self,
|
160 |
+
candidates,
|
161 |
+
cand_key,
|
162 |
+
cand_template,
|
163 |
references,
|
164 |
+
ref_key,
|
165 |
+
ref_template,
|
166 |
k=[1, 10, 100],
|
167 |
num_workers=4,
|
168 |
timeout=3.0,
|
|
|
178 |
raise NotImplementedError(
|
179 |
"This metric is currently not supported on Windows."
|
180 |
)
|
181 |
+
|
182 |
+
candidates = sorted(candidates, key=lambda x: x["id"])
|
183 |
+
references = sorted(references, key=lambda x: x["id"])
|
184 |
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
185 |
results = {}
|
186 |
+
for cand_d, ref_d in zip(candidates, references):
|
187 |
+
assert cand_d["id"] == ref_d["id"]
|
188 |
+
tid = cand_d["id"]
|
189 |
+
|
190 |
results[tid] = []
|
191 |
+
cand = cand_d[cand_key]
|
192 |
+
ref = ref_d[ref_key]
|
193 |
+
for cid, c in enumerate(cand):
|
194 |
result = Result(task_id=tid, completion_id=cid)
|
195 |
+
body = Template(cand_template).safe_substitute(candidate=c)
|
196 |
+
for r in ref:
|
197 |
+
assert isinstance(r, str)
|
198 |
+
test = Template(ref_template).safe_substitute(ref_key=r)
|
199 |
+
test = Template(test).safe_substitute(reference=c)
|
200 |
+
|
201 |
+
test_program = body + "\n" + test
|
202 |
args = (test_program, timeout, tid, cid)
|
203 |
future = executor.submit(check_correctness, *args)
|
204 |
result.add(future)
|