jjkim commited on
Commit
346d7a2
·
1 Parent(s): 9cbcfb8

change inputs

Browse files
Files changed (1) hide show
  1. code_eval.py +23 -10
code_eval.py CHANGED
@@ -22,6 +22,7 @@ from collections import Counter, defaultdict
22
  from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
23
  from typing import List, Optional
24
  import time
 
25
 
26
  import datasets
27
  import evaluate
@@ -156,9 +157,12 @@ class CodeEval(evaluate.Metric):
156
 
157
  def _compute(
158
  self,
159
- predictions,
 
 
160
  references,
161
- task_ids=None,
 
162
  k=[1, 10, 100],
163
  num_workers=4,
164
  timeout=3.0,
@@ -174,18 +178,27 @@ class CodeEval(evaluate.Metric):
174
  raise NotImplementedError(
175
  "This metric is currently not supported on Windows."
176
  )
177
-
178
- task_ids = task_ids or list(range(len(predictions)))
179
-
180
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
181
  results = {}
182
- for tid, pred, ref in zip(task_ids, predictions, references):
 
 
 
183
  results[tid] = []
184
- for cid, candidate in enumerate(pred):
 
 
185
  result = Result(task_id=tid, completion_id=cid)
186
- for test_case in ref:
187
- assert isinstance(test_case, str)
188
- test_program = candidate + "\n" + test_case
 
 
 
 
189
  args = (test_program, timeout, tid, cid)
190
  future = executor.submit(check_correctness, *args)
191
  result.add(future)
 
22
  from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
23
  from typing import List, Optional
24
  import time
25
+ from string import Template
26
 
27
  import datasets
28
  import evaluate
 
157
 
158
  def _compute(
159
  self,
160
+ candidates,
161
+ cand_key,
162
+ cand_template,
163
  references,
164
+ ref_key,
165
+ ref_template,
166
  k=[1, 10, 100],
167
  num_workers=4,
168
  timeout=3.0,
 
178
  raise NotImplementedError(
179
  "This metric is currently not supported on Windows."
180
  )
181
+
182
+ candidates = sorted(candidates, key=lambda x: x["id"])
183
+ references = sorted(references, key=lambda x: x["id"])
184
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
185
  results = {}
186
+ for cand_d, ref_d in zip(candidates, references):
187
+ assert cand_d["id"] == ref_d["id"]
188
+ tid = cand_d["id"]
189
+
190
  results[tid] = []
191
+ cand = cand_d[cand_key]
192
+ ref = ref_d[ref_key]
193
+ for cid, c in enumerate(cand):
194
  result = Result(task_id=tid, completion_id=cid)
195
+ body = Template(cand_template).safe_substitute(candidate=c)
196
+ for r in ref:
197
+ assert isinstance(r, str)
198
+ test = Template(ref_template).safe_substitute(ref_key=r)
199
+ test = Template(test).safe_substitute(reference=c)
200
+
201
+ test_program = body + "\n" + test
202
  args = (test_program, timeout, tid, cid)
203
  future = executor.submit(check_correctness, *args)
204
  result.add(future)