aisyahhrazak commited on
Commit
c070f8a
1 Parent(s): 67c4342

Upload flagging.py

Browse files
Files changed (1) hide show
  1. flagging.py +498 -0
flagging.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import datetime
5
+ import json
6
+ import os
7
+ import time
8
+ import uuid
9
+ from abc import ABC, abstractmethod
10
+ from collections import OrderedDict
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any, Sequence
13
+
14
+ import filelock
15
+ import huggingface_hub
16
+ from gradio_client import utils as client_utils
17
+ from gradio_client.documentation import document
18
+
19
+ import gradio as gr
20
+ from gradio import utils
21
+
22
+ if TYPE_CHECKING:
23
+ from gradio.components import Component
24
+
25
+
26
+ class FlaggingCallback(ABC):
27
+ """
28
+ An abstract class for defining the methods that any FlaggingCallback should have.
29
+ """
30
+
31
+ @abstractmethod
32
+ def setup(self, components: Sequence[Component], flagging_dir: str):
33
+ """
34
+ This method should be overridden and ensure that everything is set up correctly for flag().
35
+ This method gets called once at the beginning of the Interface.launch() method.
36
+ Parameters:
37
+ components: Set of components that will provide flagged data.
38
+ flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
39
+ """
40
+ pass
41
+
42
+ @abstractmethod
43
+ def flag(
44
+ self,
45
+ flag_data: list[Any],
46
+ flag_option: str = "",
47
+ username: str | None = None,
48
+ ) -> int:
49
+ """
50
+ This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
51
+ This gets called every time the <flag> button is pressed.
52
+ Parameters:
53
+ interface: The Interface object that is being used to launch the flagging interface.
54
+ flag_data: The data to be flagged.
55
+ flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
56
+ username (optional): The username of the user that is flagging the data, if logged in.
57
+ Returns:
58
+ (int) The total number of samples that have been flagged.
59
+ """
60
+ pass
61
+
62
+
63
+ @document()
64
+ class SimpleCSVLogger(FlaggingCallback):
65
+ """
66
+ A simplified implementation of the FlaggingCallback abstract class
67
+ provided for illustrative purposes. Each flagged sample (both the input and output data)
68
+ is logged to a CSV file on the machine running the gradio app.
69
+ Example:
70
+ import gradio as gr
71
+ def image_classifier(inp):
72
+ return {'cat': 0.3, 'dog': 0.7}
73
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
74
+ flagging_callback=SimpleCSVLogger())
75
+ """
76
+
77
+ def __init__(self):
78
+ pass
79
+
80
+ def setup(self, components: Sequence[Component], flagging_dir: str | Path):
81
+ self.components = components
82
+ self.flagging_dir = flagging_dir
83
+ os.makedirs(flagging_dir, exist_ok=True)
84
+
85
+ def flag(
86
+ self,
87
+ flag_data: list[Any],
88
+ flag_option: str = "", # noqa: ARG002
89
+ username: str | None = None, # noqa: ARG002
90
+ ) -> int:
91
+ flagging_dir = self.flagging_dir
92
+ log_filepath = Path(flagging_dir) / "log.csv"
93
+
94
+ csv_data = []
95
+ for component, sample in zip(self.components, flag_data):
96
+ save_dir = Path(
97
+ flagging_dir
98
+ ) / client_utils.strip_invalid_filename_characters(component.label or "")
99
+ save_dir.mkdir(exist_ok=True)
100
+ csv_data.append(
101
+ component.flag(
102
+ sample,
103
+ save_dir,
104
+ )
105
+ )
106
+
107
+ with open(log_filepath, "a", encoding="utf-8", newline="") as csvfile:
108
+ writer = csv.writer(csvfile)
109
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
110
+
111
+ with open(log_filepath, encoding="utf-8") as csvfile:
112
+ line_count = len(list(csv.reader(csvfile))) - 1
113
+ return line_count
114
+
115
+
116
+ @document()
117
+ class CSVLogger(FlaggingCallback):
118
+ """
119
+ The default implementation of the FlaggingCallback abstract class. Each flagged
120
+ sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app.
121
+ Example:
122
+ import gradio as gr
123
+ def image_classifier(inp):
124
+ return {'cat': 0.3, 'dog': 0.7}
125
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
126
+ flagging_callback=CSVLogger())
127
+ Guides: using-flagging
128
+ """
129
+
130
+ def __init__(self, simplify_file_data: bool = True):
131
+ self.simplify_file_data = simplify_file_data
132
+
133
+ def setup(
134
+ self,
135
+ components: Sequence[Component],
136
+ flagging_dir: str | Path,
137
+ ):
138
+ self.components = components
139
+ self.flagging_dir = flagging_dir
140
+ os.makedirs(flagging_dir, exist_ok=True)
141
+
142
+ def flag(
143
+ self,
144
+ flag_data: list[Any],
145
+ flag_option: str = "",
146
+ username: str | None = None,
147
+ ) -> int:
148
+ flagging_dir = self.flagging_dir
149
+ log_filepath = Path(flagging_dir) / "log.csv"
150
+ is_new = not Path(log_filepath).exists()
151
+ headers = [
152
+ getattr(component, "label", None) or f"component {idx}"
153
+ for idx, component in enumerate(self.components)
154
+ ] + [
155
+ "flag",
156
+ "username",
157
+ "timestamp",
158
+ ]
159
+
160
+ csv_data = []
161
+ for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
162
+ save_dir = Path(
163
+ flagging_dir
164
+ ) / client_utils.strip_invalid_filename_characters(
165
+ getattr(component, "label", None) or f"component {idx}"
166
+ )
167
+ if utils.is_prop_update(sample):
168
+ csv_data.append(str(sample))
169
+ else:
170
+ data = (
171
+ component.flag(sample, flag_dir=save_dir)
172
+ if sample is not None
173
+ else ""
174
+ )
175
+ if self.simplify_file_data:
176
+ data = utils.simplify_file_data_in_str(data)
177
+ csv_data.append(data)
178
+ csv_data.append(flag_option)
179
+ csv_data.append(username if username is not None else "")
180
+ csv_data.append(str(datetime.datetime.now()))
181
+
182
+ with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
183
+ writer = csv.writer(csvfile)
184
+ if is_new:
185
+ writer.writerow(utils.sanitize_list_for_csv(headers))
186
+ writer.writerow(utils.sanitize_list_for_csv(csv_data))
187
+
188
+ with open(log_filepath, encoding="utf-8") as csvfile:
189
+ line_count = len(list(csv.reader(csvfile))) - 1
190
+ return line_count
191
+
192
+
193
+ @document()
194
+ class HuggingFaceDatasetSaver(FlaggingCallback):
195
+ """
196
+ A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.
197
+
198
+ Example:
199
+ import gradio as gr
200
+ hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
201
+ def image_classifier(inp):
202
+ return {'cat': 0.3, 'dog': 0.7}
203
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
204
+ allow_flagging="manual", flagging_callback=hf_writer)
205
+ Guides: using-flagging
206
+ """
207
+
208
+ def __init__(
209
+ self,
210
+ hf_token: str,
211
+ dataset_name: str,
212
+ private: bool = False,
213
+ info_filename: str = "dataset_info.json",
214
+ separate_dirs: bool = False,
215
+ ):
216
+ """
217
+ Parameters:
218
+ hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset (defaults to the registered one).
219
+ dataset_name: The repo_id of the dataset to save the data to, e.g. "image-classifier-1" or "username/image-classifier-1".
220
+ private: Whether the dataset should be private (defaults to False).
221
+ info_filename: The name of the file to save the dataset info (defaults to "dataset_infos.json").
222
+ separate_dirs: If True, each flagged item will be saved in a separate directory. This makes the flagging more robust to concurrent editing, but may be less convenient to use.
223
+ """
224
+ self.hf_token = hf_token
225
+ self.dataset_id = dataset_name # TODO: rename parameter (but ensure backward compatibility somehow)
226
+ self.dataset_private = private
227
+ self.info_filename = info_filename
228
+ self.separate_dirs = separate_dirs
229
+
230
+ def setup(self, components: Sequence[Component], flagging_dir: str):
231
+ """
232
+ Params:
233
+ flagging_dir (str): local directory where the dataset is cloned,
234
+ updated, and pushed from.
235
+ """
236
+ # Setup dataset on the Hub
237
+ self.dataset_id = huggingface_hub.create_repo(
238
+ repo_id=self.dataset_id,
239
+ token=self.hf_token,
240
+ private=self.dataset_private,
241
+ repo_type="dataset",
242
+ exist_ok=True,
243
+ ).repo_id
244
+ path_glob = "**/*.jsonl" if self.separate_dirs else "data.csv"
245
+ huggingface_hub.metadata_update(
246
+ repo_id=self.dataset_id,
247
+ repo_type="dataset",
248
+ metadata={
249
+ "configs": [
250
+ {
251
+ "config_name": "default",
252
+ "data_files": [{"split": "train", "path": path_glob}],
253
+ }
254
+ ]
255
+ },
256
+ overwrite=True,
257
+ token=self.hf_token,
258
+ )
259
+
260
+ # Setup flagging dir
261
+ self.components = components
262
+ self.dataset_dir = (
263
+ Path(flagging_dir).absolute() / self.dataset_id.split("/")[-1]
264
+ )
265
+ self.dataset_dir.mkdir(parents=True, exist_ok=True)
266
+ self.infos_file = self.dataset_dir / self.info_filename
267
+
268
+ # Download remote files to local
269
+ remote_files = [self.info_filename]
270
+ if not self.separate_dirs:
271
+ # No separate dirs => means all data is in the same CSV file => download it to get its current content
272
+ remote_files.append("data.csv")
273
+
274
+ for filename in remote_files:
275
+ try:
276
+ huggingface_hub.hf_hub_download(
277
+ repo_id=self.dataset_id,
278
+ repo_type="dataset",
279
+ filename=filename,
280
+ local_dir=self.dataset_dir,
281
+ token=self.hf_token,
282
+ )
283
+ except huggingface_hub.utils.EntryNotFoundError:
284
+ pass
285
+
286
+ def flag(
287
+ self,
288
+ flag_data: list[Any],
289
+ flag_option: str = "",
290
+ username: str | None = None,
291
+ ) -> int:
292
+ if self.separate_dirs:
293
+ # JSONL files to support dataset preview on the Hub
294
+ unique_id = str(uuid.uuid4())
295
+ components_dir = self.dataset_dir / unique_id
296
+ data_file = components_dir / "metadata.jsonl"
297
+ path_in_repo = unique_id # upload in sub folder (safer for concurrency)
298
+ else:
299
+ # Unique CSV file
300
+ components_dir = self.dataset_dir
301
+ data_file = components_dir / "data.csv"
302
+ path_in_repo = None # upload at root level
303
+
304
+ return self._flag_in_dir(
305
+ data_file=data_file,
306
+ components_dir=components_dir,
307
+ path_in_repo=path_in_repo,
308
+ flag_data=flag_data,
309
+ flag_option=flag_option,
310
+ username=username or "",
311
+ )
312
+
313
+ def _flag_in_dir(
314
+ self,
315
+ data_file: Path,
316
+ components_dir: Path,
317
+ path_in_repo: str | None,
318
+ flag_data: list[Any],
319
+ flag_option: str = "",
320
+ username: str = "",
321
+ ) -> int:
322
+ # Deserialize components (write images/audio to files)
323
+ features, row = self._deserialize_components(
324
+ components_dir, flag_data, flag_option, username
325
+ )
326
+
327
+ # Write generic info to dataset_infos.json + upload
328
+ with filelock.FileLock(str(self.infos_file) + ".lock"):
329
+ if not self.infos_file.exists():
330
+ self.infos_file.write_text(
331
+ json.dumps({"flagged": {"features": features}})
332
+ )
333
+
334
+ huggingface_hub.upload_file(
335
+ repo_id=self.dataset_id,
336
+ repo_type="dataset",
337
+ token=self.hf_token,
338
+ path_in_repo=self.infos_file.name,
339
+ path_or_fileobj=self.infos_file,
340
+ )
341
+
342
+ headers = list(features.keys())
343
+
344
+ if not self.separate_dirs:
345
+ with filelock.FileLock(components_dir / ".lock"):
346
+ sample_nb = self._save_as_csv(data_file, headers=headers, row=row)
347
+ sample_name = str(sample_nb)
348
+ huggingface_hub.upload_folder(
349
+ repo_id=self.dataset_id,
350
+ repo_type="dataset",
351
+ commit_message=f"Flagged sample #{sample_name}",
352
+ path_in_repo=path_in_repo,
353
+ ignore_patterns="*.lock",
354
+ folder_path=components_dir,
355
+ token=self.hf_token,
356
+ )
357
+ else:
358
+ sample_name = self._save_as_jsonl(data_file, headers=headers, row=row)
359
+ sample_nb = len(
360
+ [path for path in self.dataset_dir.iterdir() if path.is_dir()]
361
+ )
362
+ huggingface_hub.upload_folder(
363
+ repo_id=self.dataset_id,
364
+ repo_type="dataset",
365
+ commit_message=f"Flagged sample #{sample_name}",
366
+ path_in_repo=path_in_repo,
367
+ ignore_patterns="*.lock",
368
+ folder_path=components_dir,
369
+ token=self.hf_token,
370
+ )
371
+
372
+ return sample_nb
373
+
374
+ @staticmethod
375
+ def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
376
+ """Save data as CSV and return the sample name (row number)."""
377
+ is_new = not data_file.exists()
378
+
379
+ with data_file.open("a", newline="", encoding="utf-8") as csvfile:
380
+ writer = csv.writer(csvfile)
381
+
382
+ # Write CSV headers if new file
383
+ if is_new:
384
+ writer.writerow(utils.sanitize_list_for_csv(headers))
385
+
386
+ # Write CSV row for flagged sample
387
+ writer.writerow(utils.sanitize_list_for_csv(row))
388
+
389
+ with data_file.open(encoding="utf-8") as csvfile:
390
+ return sum(1 for _ in csv.reader(csvfile)) - 1
391
+
392
+ @staticmethod
393
+ def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
394
+ """Save data as JSONL and return the sample name (uuid)."""
395
+ Path.mkdir(data_file.parent, parents=True, exist_ok=True)
396
+ with open(data_file, "w", encoding="utf-8") as f:
397
+ json.dump(dict(zip(headers, row)), f)
398
+ return data_file.parent.name
399
+
400
+ def _deserialize_components(
401
+ self,
402
+ data_dir: Path,
403
+ flag_data: list[Any],
404
+ flag_option: str = "",
405
+ username: str = "",
406
+ ) -> tuple[dict[Any, Any], list[Any]]:
407
+ """Deserialize components and return the corresponding row for the flagged sample.
408
+
409
+ Images/audio are saved to disk as individual files.
410
+ """
411
+ # Components that can have a preview on dataset repos
412
+ file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
413
+
414
+ # Generate the row corresponding to the flagged sample
415
+ features = OrderedDict()
416
+ row = []
417
+ for component, sample in zip(self.components, flag_data):
418
+ # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
419
+ label = component.label or ""
420
+ save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
421
+ save_dir.mkdir(exist_ok=True, parents=True)
422
+ deserialized = utils.simplify_file_data_in_str(
423
+ component.flag(sample, save_dir)
424
+ )
425
+
426
+ # Add deserialized object to row
427
+ features[label] = {"dtype": "string", "_type": "Value"}
428
+ try:
429
+ deserialized_path = Path(deserialized)
430
+ if not deserialized_path.exists():
431
+ raise FileNotFoundError(f"File {deserialized} not found")
432
+ row.append(str(deserialized_path.relative_to(self.dataset_dir)))
433
+ except (FileNotFoundError, TypeError, ValueError, OSError):
434
+ deserialized = "" if deserialized is None else str(deserialized)
435
+ row.append(deserialized)
436
+
437
+ # If component is eligible for a preview, add the URL of the file
438
+ # Be mindful that images and audio can be None
439
+ if isinstance(component, tuple(file_preview_types)): # type: ignore
440
+ for _component, _type in file_preview_types.items():
441
+ if isinstance(component, _component):
442
+ features[label + " file"] = {"_type": _type}
443
+ break
444
+ if deserialized:
445
+ path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
446
+ Path(deserialized).relative_to(self.dataset_dir)
447
+ ).replace("\\", "/")
448
+ row.append(
449
+ huggingface_hub.hf_hub_url(
450
+ repo_id=self.dataset_id,
451
+ filename=path_in_repo,
452
+ repo_type="dataset",
453
+ )
454
+ )
455
+ else:
456
+ row.append("")
457
+ features["flag"] = {"dtype": "string", "_type": "Value"}
458
+ features["username"] = {"dtype": "string", "_type": "Value"}
459
+ row.append(flag_option)
460
+ row.append(username)
461
+ return features, row
462
+
463
+
464
+ class FlagMethod:
465
+ """
466
+ Helper class that contains the flagging options and calls the flagging method. Also
467
+ provides visual feedback to the user when flag is clicked.
468
+ """
469
+
470
+ def __init__(
471
+ self,
472
+ flagging_callback: FlaggingCallback,
473
+ label: str,
474
+ value: str,
475
+ visual_feedback: bool = True,
476
+ ):
477
+ self.flagging_callback = flagging_callback
478
+ self.label = label
479
+ self.value = value
480
+ self.__name__ = "Flag"
481
+ self.visual_feedback = visual_feedback
482
+
483
+ def __call__(self, request: gr.Request, *flag_data):
484
+ try:
485
+ self.flagging_callback.flag(
486
+ list(flag_data), flag_option=self.value, username=request.username
487
+ )
488
+ except Exception as e:
489
+ print(f"Error while flagging: {e}")
490
+ if self.visual_feedback:
491
+ return "Error!"
492
+ if not self.visual_feedback:
493
+ return
494
+ time.sleep(0.8) # to provide enough time for the user to observe button change
495
+ return self.reset()
496
+
497
+ def reset(self):
498
+ return gr.Button(value=self.label, interactive=True)