jjkim commited on
Commit
cc070df
·
1 Parent(s): 346d7a2

change candidate to prediction

Browse files
Files changed (1) hide show
  1. code_eval.py +14 -14
code_eval.py CHANGED
@@ -157,9 +157,9 @@ class CodeEval(evaluate.Metric):
157
 
158
  def _compute(
159
  self,
160
- candidates,
161
- cand_key,
162
- cand_template,
163
  references,
164
  ref_key,
165
  ref_template,
@@ -179,27 +179,27 @@ class CodeEval(evaluate.Metric):
179
  "This metric is currently not supported on Windows."
180
  )
181
 
182
- candidates = sorted(candidates, key=lambda x: x["id"])
183
  references = sorted(references, key=lambda x: x["id"])
184
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
185
  results = {}
186
- for cand_d, ref_d in zip(candidates, references):
187
- assert cand_d["id"] == ref_d["id"]
188
- tid = cand_d["id"]
189
 
190
  results[tid] = []
191
- cand = cand_d[cand_key]
192
  ref = ref_d[ref_key]
193
- for cid, c in enumerate(cand):
194
- result = Result(task_id=tid, completion_id=cid)
195
- body = Template(cand_template).safe_substitute(candidate=c)
196
  for r in ref:
197
  assert isinstance(r, str)
198
  test = Template(ref_template).safe_substitute(ref_key=r)
199
- test = Template(test).safe_substitute(reference=c)
200
 
201
  test_program = body + "\n" + test
202
- args = (test_program, timeout, tid, cid)
203
  future = executor.submit(check_correctness, *args)
204
  result.add(future)
205
  results[tid].append(result)
@@ -266,7 +266,7 @@ def estimate_pass_at_k(num_samples, num_correct, k):
266
 
267
  class Result(BaseModel):
268
  task_id: str
269
- completion_id: int
270
 
271
  passed: Optional[bool] = None
272
  result: List[str] = []
 
157
 
158
  def _compute(
159
  self,
160
+ predictions,
161
+ pred_key,
162
+ pred_template,
163
  references,
164
  ref_key,
165
  ref_template,
 
179
  "This metric is currently not supported on Windows."
180
  )
181
 
182
+ predictions = sorted(predictions, key=lambda x: x["id"])
183
  references = sorted(references, key=lambda x: x["id"])
184
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
185
  results = {}
186
+ for pred_d, ref_d in zip(predictions, references):
187
+ assert pred_d["id"] == ref_d["id"]
188
+ tid = pred_d["id"]
189
 
190
  results[tid] = []
191
+ pred = pred_d[pred_key]
192
  ref = ref_d[ref_key]
193
+ for pid, p in enumerate(pred):
194
+ result = Result(task_id=tid, prediction_id=pid)
195
+ body = Template(pred_template).safe_substitute(prediction=p)
196
  for r in ref:
197
  assert isinstance(r, str)
198
  test = Template(ref_template).safe_substitute(ref_key=r)
199
+ test = Template(test).safe_substitute(prediction=p)
200
 
201
  test_program = body + "\n" + test
202
+ args = (test_program, timeout, tid, pid)
203
  future = executor.submit(check_correctness, *args)
204
  result.add(future)
205
  results[tid].append(result)
 
266
 
267
  class Result(BaseModel):
268
  task_id: str
269
+ prediction_id: int
270
 
271
  passed: Optional[bool] = None
272
  result: List[str] = []