jjkim commited on
Commit
747a38b
·
1 Parent(s): 96c951b

fix metric info

Browse files
Files changed (1) hide show
  1. code_eval.py +7 -10
code_eval.py CHANGED
@@ -145,8 +145,11 @@ class CodeEval(evaluate.Metric):
145
  # This defines the format of each prediction and reference
146
  features=datasets.Features(
147
  {
148
- "predictions": defaultdict(lambda: datasets.Value("string")),
149
- "references": defaultdict(lambda: datasets.Value("string")),
 
 
 
150
  }
151
  ),
152
  homepage="https://github.com/openai/human-eval",
@@ -157,11 +160,10 @@ class CodeEval(evaluate.Metric):
157
 
158
  def _compute(
159
  self,
 
160
  predictions,
161
- pred_key,
162
  pred_template,
163
  references,
164
- ref_key,
165
  ref_template,
166
  k=[1, 10, 100],
167
  num_workers=4,
@@ -179,14 +181,9 @@ class CodeEval(evaluate.Metric):
179
  "This metric is currently not supported on Windows."
180
  )
181
 
182
- predictions = sorted(predictions, key=lambda x: x["id"])
183
- references = sorted(references, key=lambda x: x["id"])
184
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
185
  results = {}
186
- for pred_d, ref_d in zip(predictions, references):
187
- assert pred_d["id"] == ref_d["id"]
188
- tid = pred_d["id"]
189
-
190
  results[tid] = []
191
  pred = pred_d[pred_key]
192
  ref = ref_d[ref_key]
 
145
  # This defines the format of each prediction and reference
146
  features=datasets.Features(
147
  {
148
+ "ids": datasets.Value("string"),
149
+ "predictions": datasets.Sequence(datasets.Value("string")),
150
+ "pred_template": datasets.Value("string"),
151
+ "references": datasets.Sequence(datasets.Value("string")),
152
+ "ref_template": datasets.Value("string"),
153
  }
154
  ),
155
  homepage="https://github.com/openai/human-eval",
 
160
 
161
  def _compute(
162
  self,
163
+ ids,
164
  predictions,
 
165
  pred_template,
166
  references,
 
167
  ref_template,
168
  k=[1, 10, 100],
169
  num_workers=4,
 
181
  "This metric is currently not supported on Windows."
182
  )
183
 
 
 
184
  with ThreadPoolExecutor(max_workers=num_workers) as executor:
185
  results = {}
186
+ for tid, pred_d, ref_d in zip(ids, predictions, references):
 
 
 
187
  results[tid] = []
188
  pred = pred_d[pred_key]
189
  ref = ref_d[ref_key]