erntkn committed on
Commit 8c2c847
• 1 Parent(s): 68c2f62
Files changed (1)
  1. dice_coefficient.py +222 -50
dice_coefficient.py CHANGED
@@ -11,85 +11,257 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """TODO: Add a description here."""

  import evaluate
  import datasets


- # TODO: Add BibTeX citation
- _CITATION = """\
- @InProceedings{huggingface:module,
- title = {A great new module},
- authors={huggingface, Inc.},
- year={2020}
- }
- """
-
- # TODO: Add description of the module here
  _DESCRIPTION = """\
- This new module is designed to solve this great ML task and is crafted with a lot of care.
  """


- # TODO: Add description of the arguments of the module here
  _KWARGS_DESCRIPTION = """
- Calculates how good are predictions given some references, using certain scores
  Args:
-     predictions: list of predictions to score. Each predictions
-         should be a string with tokens separated by spaces.
-     references: list of reference for each prediction. Each
-         reference should be a string with tokens separated by spaces.
  Returns:
-     accuracy: description of the first score,
-     another_score: description of the second score,
  Examples:
-     Examples should be written in doctest format, and should illustrate how
-     to use the function.
-
-     >>> my_new_module = evaluate.load("my_new_module")
-     >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
      >>> print(results)
-     {'accuracy': 1.0}
  """

- # TODO: Define external resources urls if needed
- BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"


  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
  class DiceCoefficient(evaluate.Metric):
-     """TODO: Short description of my evaluation module."""
-
      def _info(self):
-         # TODO: Specifies the evaluate.EvaluationModuleInfo object
          return evaluate.MetricInfo(
-             # This is the description that will appear on the modules page.
              module_type="metric",
              description=_DESCRIPTION,
              citation=_CITATION,
              inputs_description=_KWARGS_DESCRIPTION,
-             # This defines the format of each prediction and reference
              features=datasets.Features({
                  'predictions': datasets.Value('int64'),
                  'references': datasets.Value('int64'),
              }),
-             # Homepage of the module for documentation
-             homepage="http://module.homepage",
-             # Additional links to the codebase or references
-             codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
-             reference_urls=["http://path.to.reference.url/new_module"]
          )

-     def _download_and_prepare(self, dl_manager):
-         """Optional: download external resources useful to compute the scores"""
-         # TODO: Download external resources if needed
-         pass
-
-     def _compute(self, predictions, references):
-         """Returns the scores"""
-         # TODO: Compute the different scores of the module
-         accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-         return {
-             "accuracy": accuracy,
-         }
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ """Dice Coefficient Metric."""

+ from typing import Dict, Optional
+ import numpy as np
  import evaluate
  import datasets


  _DESCRIPTION = """\
+ Dice coefficient is 2 times the area of overlap divided by the total number of pixels in both segmentation maps.
  """


  _KWARGS_DESCRIPTION = """
  Args:
+     predictions (`List[ndarray]`):
+         List of predicted segmentation maps, each of shape (height, width). Each segmentation map can be of a different size.
+     references (`List[ndarray]`):
+         List of ground truth segmentation maps, each of shape (height, width). Each segmentation map can be of a different size.
+     num_labels (`int`):
+         Number of classes (categories).
+     ignore_index (`int`):
+         Index that will be ignored during evaluation.
+     nan_to_num (`int`, *optional*):
+         If specified, NaN values will be replaced by the number defined by the user.
+     label_map (`dict`, *optional*):
+         If specified, dictionary mapping old label indices to new label indices.
+     reduce_labels (`bool`, *optional*, defaults to `False`):
+         Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is used for background,
+         and background itself is not included in all classes of a dataset (e.g. ADE20k). The background label will be replaced by 255.
  Returns:
+     `Dict[str, float | ndarray]` comprising various elements:
+     - *dice_score* (`float`):
+         Dice Coefficient.
  Examples:
+     >>> import numpy as np
+     >>> dice = evaluate.load("DiceCoefficient")
+     >>> # suppose one has 3 different segmentation maps predicted
+     >>> predicted_1 = np.array([[1, 2], [3, 4], [5, 255]])
+     >>> actual_1 = np.array([[0, 3], [5, 4], [6, 255]])
+     >>> predicted_2 = np.array([[2, 7], [9, 2], [3, 6]])
+     >>> actual_2 = np.array([[1, 7], [9, 2], [3, 6]])
+     >>> predicted_3 = np.array([[2, 2, 3], [8, 2, 4], [3, 255, 2]])
+     >>> actual_3 = np.array([[1, 2, 2], [8, 2, 1], [3, 255, 1]])
+     >>> predicted = [predicted_1, predicted_2, predicted_3]
+     >>> ground_truth = [actual_1, actual_2, actual_3]
+     >>> results = dice.compute(predictions=predicted, references=ground_truth, num_labels=10, ignore_index=255, reduce_labels=False)
      >>> print(results)
+     {'dice_score': 0.4775}
  """

+ _CITATION = """\
+ @software{MMSegmentation_Contributors_OpenMMLab_Semantic_Segmentation_2020,
+     author = {{MMSegmentation Contributors}},
+     license = {Apache-2.0},
+     month = {7},
+     title = {{OpenMMLab Semantic Segmentation Toolbox and Benchmark}},
+     url = {https://github.com/open-mmlab/mmsegmentation},
+     year = {2020}
+ }"""
+
+
+ def intersect_and_union(
+     pred_label,
+     label,
+     num_labels,
+     ignore_index: int,
+     label_map: Optional[Dict[int, int]] = None,
+     reduce_labels: bool = False,
+ ):
+     """Calculate intersection and union.
+     Args:
+         pred_label (`ndarray`):
+             Prediction segmentation map of shape (height, width).
+         label (`ndarray`):
+             Ground truth segmentation map of shape (height, width).
+         num_labels (`int`):
+             Number of categories.
+         ignore_index (`int`):
+             Index that will be ignored during evaluation.
+         label_map (`dict`, *optional*):
+             Mapping old label indices to new label indices.
+         reduce_labels (`bool`, *optional*, defaults to `False`):
+             Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is used for background,
+             and background itself is not included in all classes of a dataset (e.g. ADE20k). The background label will be replaced by 255.
+     Returns:
+         area_intersect (`ndarray`):
+             The intersection of prediction and ground truth histogram on all classes.
+         area_union (`ndarray`):
+             The union of prediction and ground truth histogram on all classes.
+         area_pred_label (`ndarray`):
+             The prediction histogram on all classes.
+         area_label (`ndarray`):
+             The ground truth histogram on all classes.
+     """
+     if label_map is not None:
+         for old_id, new_id in label_map.items():
+             label[label == old_id] = new_id
+
+     # turn into NumPy arrays
+     pred_label = np.array(pred_label)
+     label = np.array(label)
+
+     if reduce_labels:
+         label[label == 0] = 255
+         label = label - 1
+         label[label == 254] = 255
+
+     mask = np.not_equal(label, ignore_index)
+     pred_label = pred_label[mask]
+     label = label[mask]
+
+     intersect = pred_label[pred_label == label]
+
+     area_intersect = np.histogram(intersect, bins=num_labels, range=(0, num_labels - 1))[0]
+     area_pred_label = np.histogram(pred_label, bins=num_labels, range=(0, num_labels - 1))[0]
+     area_label = np.histogram(label, bins=num_labels, range=(0, num_labels - 1))[0]

+     area_union = area_pred_label + area_label - area_intersect
+
+     return area_intersect, area_union, area_pred_label, area_label
+
+
+ def total_intersect_and_union(
+     results,
+     gt_seg_maps,
+     num_labels,
+     ignore_index: int,
+     label_map: Optional[Dict[int, int]] = None,
+     reduce_labels: bool = False,
+ ):
+     """Calculate total intersection and union by calculating `intersect_and_union` for each (predicted, ground truth) pair.
+     Args:
+         results (`List[ndarray]`):
+             List of prediction segmentation maps, each of shape (height, width).
+         gt_seg_maps (`List[ndarray]`):
+             List of ground truth segmentation maps, each of shape (height, width).
+         num_labels (`int`):
+             Number of categories.
+         ignore_index (`int`):
+             Index that will be ignored during evaluation.
+         label_map (`dict`, *optional*):
+             Mapping old label indices to new label indices.
+         reduce_labels (`bool`, *optional*, defaults to `False`):
+             Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is used for background,
+             and background itself is not included in all classes of a dataset (e.g. ADE20k). The background label will be replaced by 255.
+     Returns:
+         total_area_intersect (`ndarray`):
+             The intersection of prediction and ground truth histogram on all classes.
+         total_area_union (`ndarray`):
+             The union of prediction and ground truth histogram on all classes.
+         total_area_pred_label (`ndarray`):
+             The prediction histogram on all classes.
+         total_area_label (`ndarray`):
+             The ground truth histogram on all classes.
+     """
+     total_area_intersect = np.zeros((num_labels,), dtype=np.float64)
+     total_area_union = np.zeros((num_labels,), dtype=np.float64)
+     total_area_pred_label = np.zeros((num_labels,), dtype=np.float64)
+     total_area_label = np.zeros((num_labels,), dtype=np.float64)
+     for result, gt_seg_map in zip(results, gt_seg_maps):
+         area_intersect, area_union, area_pred_label, area_label = intersect_and_union(
+             result, gt_seg_map, num_labels, ignore_index, label_map, reduce_labels
+         )
+         total_area_intersect += area_intersect
+         total_area_union += area_union
+         total_area_pred_label += area_pred_label
+         total_area_label += area_label
+     return total_area_intersect, total_area_union, total_area_pred_label, total_area_label
+
+
+ def dice_coef(
+     results,
+     gt_seg_maps,
+     num_labels,
+     ignore_index: int,
+     nan_to_num: Optional[int] = None,
+     label_map: Optional[Dict[int, int]] = None,
+     reduce_labels: bool = False,
+ ):
+     """Calculate Mean Dice Coefficient (mDSC).
+     Args:
+         results (`List[ndarray]`):
+             List of prediction segmentation maps, each of shape (height, width).
+         gt_seg_maps (`List[ndarray]`):
+             List of ground truth segmentation maps, each of shape (height, width).
+         num_labels (`int`):
+             Number of categories.
+         ignore_index (`int`):
+             Index that will be ignored during evaluation.
+         nan_to_num (`int`, *optional*):
+             If specified, NaN values will be replaced by the number defined by the user.
+         label_map (`dict`, *optional*):
+             Mapping old label indices to new label indices.
+         reduce_labels (`bool`, *optional*, defaults to `False`):
+             Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is used for background,
+             and background itself is not included in all classes of a dataset (e.g. ADE20k). The background label will be replaced by 255.
+     Returns:
+         `Dict[str, float | ndarray]` comprising various elements:
+         - *dice_score* (`float`):
+             Mean Dice Coefficient (DSC averaged over all categories).
+     """
+     total_area_intersect, _, total_area_pred_label, total_area_label = total_intersect_and_union(
+         results, gt_seg_maps, num_labels, ignore_index, label_map, reduce_labels
+     )
+
+     result = dict()
+     dice = 2 * total_area_intersect / (total_area_pred_label + total_area_label)
+     result["dice_score"] = np.nanmean(dice)
+
+     if nan_to_num is not None:
+         result = {metric: np.nan_to_num(metric_value, nan=nan_to_num) for metric, metric_value in result.items()}
+
+     return result
+

  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
  class DiceCoefficient(evaluate.Metric):
      def _info(self):
          return evaluate.MetricInfo(
              module_type="metric",
              description=_DESCRIPTION,
              citation=_CITATION,
              inputs_description=_KWARGS_DESCRIPTION,
              features=datasets.Features({
                  'predictions': datasets.Sequence(datasets.Sequence(datasets.Value('uint16'))),
                  'references': datasets.Sequence(datasets.Sequence(datasets.Value('uint16'))),
              }),
+             reference_urls=["https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/core/evaluation/metrics.py"]
          )

+     def _compute(
+         self,
+         predictions,
+         references,
+         num_labels: int,
+         ignore_index: int,
+         nan_to_num: Optional[int] = None,
+         label_map: Optional[Dict[int, int]] = None,
+         reduce_labels: bool = False,
+     ):
+         dice = dice_coef(
+             results=predictions,
+             gt_seg_maps=references,
+             num_labels=num_labels,
+             ignore_index=ignore_index,
+             nan_to_num=nan_to_num,
+             label_map=label_map,
+             reduce_labels=reduce_labels,
+         )
+         return dice
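
For reference, the committed metric can be exercised end-to-end with the same data as the docstring example. Below is a minimal usage sketch, assuming the script above is saved locally as ./dice_coefficient.py and passed to evaluate.load as a local path (loading by the module's Hub repo id would work the same way):

import evaluate
import numpy as np

# Assumption: the committed script was saved as ./dice_coefficient.py in the
# working directory; evaluate.load resolves local script paths.
dice = evaluate.load("./dice_coefficient.py")

# The three (prediction, ground truth) pairs from the docstring example.
predicted = [
    np.array([[1, 2], [3, 4], [5, 255]]),
    np.array([[2, 7], [9, 2], [3, 6]]),
    np.array([[2, 2, 3], [8, 2, 4], [3, 255, 2]]),
]
ground_truth = [
    np.array([[0, 3], [5, 4], [6, 255]]),
    np.array([[1, 7], [9, 2], [3, 6]]),
    np.array([[1, 2, 2], [8, 2, 1], [3, 255, 1]]),
]

# Pixels labeled 255 are dropped by the ignore_index mask before the
# per-class histograms are built; dice_score is the class-wise Dice
# (2 * intersection / (pred area + label area)) averaged over the
# num_labels classes with np.nanmean.
results = dice.compute(
    predictions=predicted,
    references=ground_truth,
    num_labels=10,
    ignore_index=255,
    reduce_labels=False,
)
print(results["dice_score"])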