lvwerra HF staff commited on
Commit
dad91c7
·
1 Parent(s): 2cd389f

Update Space (evaluate main: c447fc8e)

Browse files
Files changed (2) hide show
  1. mse.py +3 -25
  2. requirements.txt +1 -1
mse.py CHANGED
@@ -13,9 +13,6 @@
13
  # limitations under the License.
14
  """MSE - Mean Squared Error Metric"""
15
 
16
- from dataclasses import dataclass
17
- from typing import List, Optional
18
-
19
  import datasets
20
  from sklearn.metrics import mean_squared_error
21
 
@@ -88,28 +85,13 @@ Examples:
88
  """
89
 
90
 
91
- @dataclass
92
- class MseConfig(evaluate.info.Config):
93
-
94
- name: str = "default"
95
-
96
- multioutput: str = "uniform_average"
97
- sample_weight: Optional[List[float]] = None
98
- squared: bool = True
99
-
100
-
101
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
102
  class Mse(evaluate.Metric):
103
-
104
- CONFIG_CLASS = MseConfig
105
- ALLOWED_CONFIG_NAMES = ["default", "multilist"]
106
-
107
- def _info(self, config):
108
  return evaluate.MetricInfo(
109
  description=_DESCRIPTION,
110
  citation=_CITATION,
111
  inputs_description=_KWARGS_DESCRIPTION,
112
- config=config,
113
  features=datasets.Features(self._get_feature_types()),
114
  reference_urls=[
115
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html"
@@ -128,14 +110,10 @@ class Mse(evaluate.Metric):
128
  "references": datasets.Value("float"),
129
  }
130
 
131
- def _compute(self, predictions, references):
132
 
133
  mse = mean_squared_error(
134
- references,
135
- predictions,
136
- sample_weight=self.config.sample_weight,
137
- multioutput=self.config.multioutput,
138
- squared=self.config.squared,
139
  )
140
 
141
  return {"mse": mse}
 
13
  # limitations under the License.
14
  """MSE - Mean Squared Error Metric"""
15
 
 
 
 
16
  import datasets
17
  from sklearn.metrics import mean_squared_error
18
 
 
85
  """
86
 
87
 
 
 
 
 
 
 
 
 
 
 
88
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
89
  class Mse(evaluate.Metric):
90
+ def _info(self):
 
 
 
 
91
  return evaluate.MetricInfo(
92
  description=_DESCRIPTION,
93
  citation=_CITATION,
94
  inputs_description=_KWARGS_DESCRIPTION,
 
95
  features=datasets.Features(self._get_feature_types()),
96
  reference_urls=[
97
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html"
 
110
  "references": datasets.Value("float"),
111
  }
112
 
113
+ def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
114
 
115
  mse = mean_squared_error(
116
+ references, predictions, sample_weight=sample_weight, multioutput=multioutput, squared=squared
 
 
 
 
117
  )
118
 
119
  return {"mse": mse}
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  sklearn
 
1
+ git+https://github.com/huggingface/evaluate@c447fc8eda9c62af501bfdc6988919571050d950
2
  sklearn