Upload operators.py with huggingface_hub

operators.py  +709 -162  CHANGED

--- operators.py (before)
+++ operators.py (after)
@@ -1,8 +1,8 @@
 import collections
 import importlib
-import inspect
 import uuid
 from abc import abstractmethod
 from copy import deepcopy
 from dataclasses import field
 from itertools import zip_longest
@@ -19,7 +19,7 @@ from typing import (
 )

 from .artifact import Artifact, fetch_artifact
-from .dataclass import NonPositionalField
 from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
 from .operator import (
     MultiStream,
@@ -32,15 +32,14 @@ from .operator import (
     StreamInstanceOperator,
     StreamSource,
 )
-from .random_utils import …
-from .stream import …
 from .text_utils import nested_tuple_to_string
 from .utils import flatten_dict


 class FromIterables(StreamInitializerOperator):
-    """
-    Creates a MultiStream from iterables.

     Args:
         iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
@@ -70,35 +69,83 @@ class MapInstanceValues(StreamInstanceOperator):
     strict (bool): If True, the mapping is applied strictly. That means if a value
         does not exist in the mapper, it will raise a KeyError. If False, values
         that are not present in the mapper are kept as they are.
     """

     mappers: Dict[str, Dict[str, str]]
     strict: bool = True
-    use_query = False

     def verify(self):
         # make sure the mappers are valid
         for key, mapper in self.mappers.items():
-            assert isinstance(…
-            …
         for key, mapper in self.mappers.items():
             value = dict_get(instance, key, use_dpath=self.use_query)
             if value is not None:
-                …
                 if value in mapper:
                     dict_set(instance, key, mapper[value], use_dpath=self.use_query)
         return instance
@@ -108,23 +155,42 @@ class FlattenInstances(StreamInstanceOperator):
     parent_key: str = ""
     sep: str = "_"

-    def process(…
         return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)


 class AddFields(StreamInstanceOperator):
-    """
-    Adds specified fields to each instance in a stream.

     Args:
         fields (Dict[str, object]): The fields to add to each instance.
     """

     fields: Dict[str, object]
     use_query: bool = False
     use_deepcopy: bool = False

-    def process(…
         if self.use_query:
             for key, value in self.fields.items():
                 if self.use_deepcopy:
@@ -138,30 +204,31 @@ class AddFields(StreamInstanceOperator):


 class RemoveFields(StreamInstanceOperator):
-    """
-    Adds specified fields to each instance in a stream.

     Args:
-        fields (…
     """

     fields: List[str]

-    def process(…
-        …
         return instance


 class FieldOperator(StreamInstanceOperator):
-    """
-
     Args:
         field (Optional[str]): The field to process, if only a single one is passed Defaults to None
         to_field (Optional[str]): Field name to save, if only one field is to be saved, if None is passed the operator would happen in-place and replace "field" Defaults to None
         field_to_field (Optional[Union[List[Tuple[str, str]], Dict[str, str]]]): Mapping from fields to process to their names after this process, duplicates are allowed. Defaults to None
         process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
-        use_query (bool): Whether to use dpath style queries. Defaults to False
     """

     field: Optional[str] = None
@@ -175,14 +242,18 @@ class FieldOperator(StreamInstanceOperator):
     def verify(self):
         super().verify()

-        assert …
         assert (
             self.to_field is None or self.field_to_field is None
         ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
         assert (
             self.field is None or self.field_to_field is None
         ), f"Can not apply operator both on {self.field} and on the mapping from fields to fields {self.field_to_field}"
-        assert …

     @abstractmethod
     def process_value(self, value: Any) -> Any:
@@ -195,11 +266,13 @@ class FieldOperator(StreamInstanceOperator):
             self._field_to_field = [(self.field, self.to_field)]
         else:
             try:
-                self._field_to_field = …
             except AttributeError:
                 self._field_to_field = self.field_to_field

-    def process(…
         for from_field, to_field in self._field_to_field:
             try:
                 old_value = dict_get(
@@ -209,27 +282,40 @@ class FieldOperator(StreamInstanceOperator):
                     default=self.get_default,
                     not_exist_ok=self.not_exist_ok,
                 )
-            except …
-                raise …
-                …
         if self.use_query and is_subpath(from_field, to_field):
             dict_delete(instance, from_field)
-        dict_set(…
         return instance


 class RenameFields(FieldOperator):
-    """
-    Renames fields
-    """

     def process_value(self, value: Any) -> Any:
         return value

-    def process(…
         res = super().process(instance=instance, stream_name=stream_name)
         vals = [x[1] for x in self._field_to_field]
         for key, _ in self._field_to_field:
@@ -241,32 +327,202 @@ class RenameFields(FieldOperator):


 class AddConstant(FieldOperator):
     """
-    …
     Args:
-    …
     """

-    …

     def process_value(self, value: Any) -> Any:
-        …

-    …
     """
-    …
     """

     def process_value(self, value: Any) -> Any:
         res = list(value)
-        …
         return res


 class JoinStr(FieldOperator):
-    """
-
     Args:
         separator (str): text to put between values
     """
@@ -278,6 +534,25 @@ class JoinStr(FieldOperator):


 class Apply(StreamInstanceOperator):
     __allow_unexpected_arguments__ = True
     function: Callable = NonPositionalField(required=True)
     to_field: str = NonPositionalField(required=True)
@@ -292,25 +567,23 @@ class Apply(StreamInstanceOperator):
         else:
             parts.append(function.__name__)

-        …
-        return result

     def str_to_function(self, function_str: str) -> Callable:
         splitted = function_str.split(".", 1)
         if len(splitted) == 1:
-            return __builtins__[…
         else:
-            …
-            obj = globals()[module_name]
-            else:
-                obj = importlib.import_module(module_name)
-            for part in function_name.split("."):
-                obj = getattr(obj, part)
-            return obj

     def prepare(self):
         super().prepare()
@@ -318,7 +591,9 @@ class Apply(StreamInstanceOperator):
         self.function = self.str_to_function(self.function)
         self._init_dict["function"] = self.function_to_str(self.function)

-    def process(…
         argv = [instance[arg] for arg in self._argv]
         kwargs = {key: instance[val] for key, val in self._kwargs}
@@ -329,36 +604,36 @@ class Apply(StreamInstanceOperator):


 class ListFieldValues(StreamInstanceOperator):
-    """
-    Concatanates values of multiple fields into a list to list(fields)
-    """

-    fields: str
     to_field: str
     use_query: bool = False

-    def process(…
         values = []
-        for …
-            values.append(dict_get(instance, …
         instance[self.to_field] = values
         return instance


 class ZipFieldValues(StreamInstanceOperator):
-    """
-    Zips values of multiple fields similar to list(zip(*fields))
-    """

     fields: str
     to_field: str
     longest: bool = False
     use_query: bool = False

-    def process(…
         values = []
-        for …
-            values.append(dict_get(instance, …
         if self.longest:
             zipped = zip_longest(*values)
         else:
@@ -368,16 +643,16 @@ class ZipFieldValues(StreamInstanceOperator):


 class IndexOf(StreamInstanceOperator):
-    """
-    Finds the location of one value in another (iterable) value similar to to_field=search_in.index(index_of)
-    """

     search_in: str
     index_of: str
     to_field: str
     use_query: bool = False

-    def process(…
         lst = dict_get(instance, self.search_in, use_dpath=self.use_query)
         item = dict_get(instance, self.index_of, use_dpath=self.use_query)
         instance[self.to_field] = lst.index(item)
@@ -385,9 +660,7 @@ class IndexOf(StreamInstanceOperator):


 class TakeByField(StreamInstanceOperator):
-    """
-    Takes value from one field based on another field similar to field[index]
-    """

     field: str
     index: str
@@ -398,7 +671,9 @@ class TakeByField(StreamInstanceOperator):
         if self.to_field is None:
             self.to_field = self.field

-    def process(…
         value = dict_get(instance, self.field, use_dpath=self.use_query)
         index_value = dict_get(instance, self.index, use_dpath=self.use_query)
         instance[self.to_field] = value[index_value]
@@ -406,8 +681,7 @@ class TakeByField(StreamInstanceOperator):


 class CopyFields(FieldOperator):
-    """
-    Copies specified fields from one field to another.

     Args:
         field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
@@ -421,14 +695,15 @@ class CopyFields(FieldOperator):
 class AddID(StreamInstanceOperator):
     id_field_name: str = "id"

-    def process(…
         instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
         return instance


 class CastFields(StreamInstanceOperator):
-    """
-    Casts specified fields to specified types.

     Args:
         types (Dict[str, str]): A dictionary mapping fields to their new types.
@@ -451,24 +726,28 @@ class CastFields(StreamInstanceOperator):
     def _cast_single(self, value, type, field):
         try:
             return self.types[type](value)
-        except:
             if field not in self.failure_defaults:
                 raise ValueError(
                     f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
-                )
             return self.failure_defaults[field]

     def _cast_multiple(self, values, type, field):
         values = [self._cast_single(value, type, field) for value in values]

-    def process(…
         if self.cast_multiple:
-            casted_value = self._cast_multiple(value, type, …
         else:
-            casted_value = self._cast_single(value, type, …
-        dict_set(…
         return instance
@@ -491,13 +770,14 @@ class DivideAllFieldsBy(StreamInstanceOperator):
     strict: bool = False
     recursive: bool = True

-    def process(…
         return recursive_divide(instance, self.divisor, strict=self.strict)


 class ArtifactFetcherMixin:
-    """
-    Provides a way to fetch and cache artifacts in the system.

     Args:
         cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
@@ -514,8 +794,7 @@ class ArtifactFetcherMixin:


 class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
-    """
-    Applies value operators to each instance in a stream based on specified fields.

     Args:
         value_field (str): The field containing the value to be operated on.
@@ -529,7 +808,9 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
     default_operators: List[str] = None
     fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)

-    def process(…
         operator_names = instance.get(self.operators_field)
         if operator_names is None:
             assert (
@@ -542,35 +823,228 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):

         for name in operator_names:
             operator = self.get_artifact(name)
-            for …
-                value = instance[…
-                if …
-                    instance[…
             else:
-                instance[…

         return instance


 class FilterByValues(SingleStreamOperator):
     """
-    …
     Args:
-    …
     """

-    …

-    def process(self, stream: Stream, stream_name: str = None) -> Generator:
         for instance in stream:
-            …
             yield instance


-class Unique(SingleStreamReducer):
     """
-    …

     Args:
         fields (List[str]): The fields that should be unique in each instance.
@@ -581,8 +1055,8 @@ class Unique(SingleStreamReducer):
     @staticmethod
     def to_tuple(instance: dict, fields: List[str]) -> tuple:
         result = []
-        for …
-            value = instance[…
         if isinstance(value, list):
             value = tuple(value)
         result.append(value)
@@ -598,8 +1072,7 @@ class Unique(SingleStreamReducer):


 class SplitByValue(MultiStreamOperator):
-    """
-    Splits a MultiStream into multiple streams based on unique values in specified fields.

     Args:
         fields (List[str]): The fields to use when splitting the MultiStream.
@@ -615,17 +1088,20 @@ class SplitByValue(MultiStreamOperator):
     for stream_name, stream in multi_stream.items():
         stream_unique_values = uniques[stream_name]
         for unique_values in stream_unique_values:
-            filtering_values = …
-            filtered_streams = FilterByValues(…
             result[filtered_stream_name] = filtered_streams

     return MultiStream(result)


 class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
-    """
-    Applies stream operators to a stream based on specified fields in each instance.

     Args:
         field (str): The field containing the operators to be applied.
@@ -635,7 +1111,7 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
     field: str
     reversed: bool = False

-    def process(self, stream: Stream, stream_name: str = None) -> Generator:
         first_instance = stream.peak()

         operators = first_instance.get(self.field, [])
@@ -647,16 +1123,67 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):

         for operator_name in operators:
             operator = self.get_artifact(operator_name)
-            assert isinstance(…

             stream = operator(MultiStream({"tmp": stream}))["tmp"]

         yield from stream


-class AddFieldNamePrefix(StreamInstanceOperator):
     """
-    …

     Args:
         prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
@@ -667,13 +1194,17 @@ class AddFieldNamePrefix(StreamInstanceOperator):
     def prepare(self):
         return super().prepare()

-    def process(…


 class MergeStreams(MultiStreamOperator):
-    """
-    Merges multiple streams into a single stream.

     Args:
         new_stream_name (str): The name of the new stream resulting from the merge.
@@ -681,37 +1212,43 @@ class MergeStreams(MultiStreamOperator):
     origin_stream_name_field_name (str): The field name for the origin stream name.
     """

     new_stream_name: str = "all"
     add_origin_stream_name: bool = True
     origin_stream_name_field_name: str = "origin"

     def merge(self, multi_stream):
         for stream_name, stream in multi_stream.items():
-            …

     def process(self, multi_stream: MultiStream) -> MultiStream:
-        return MultiStream(…


 class Shuffle(PagedStreamOperator):
-    """
-    Shuffles the order of instances in each page of a stream.

     Args:
         page_size (int): The size of each page in the stream. Defaults to 1000.
     """

-    def process(self, page: List[Dict], stream_name: str = None) -> Generator:
-        …
         yield from page


 class EncodeLabels(StreamInstanceOperator):
-    """
-    Encodes labels of specified fields together into integers.

     Args:
         fields (List[str]): The fields to encode together.
@@ -723,16 +1260,20 @@ class EncodeLabels(StreamInstanceOperator):
         self.encoder = {}
         return super()._process_multi_stream(multi_stream)

-    def process(…
         if not isinstance(values, list):
             values = [values]
         for value in values:
             if value not in self.encoder:
                 self.encoder[value] = len(self.encoder)
         new_values = [self.encoder[value] for value in values]
-        dict_set(…

         return instance
@@ -740,7 +1281,7 @@ class EncodeLabels(StreamInstanceOperator):
 class StreamRefiner(SingleStreamOperator):
     max_instances: int = None

-    def process(self, stream: Stream, stream_name: str = None) -> Generator:
         if self.max_instances is not None:
             yield from stream.take(self.max_instances)
         else:
@@ -748,8 +1289,7 @@ class StreamRefiner(SingleStreamOperator):


 class DeterministicBalancer(StreamRefiner):
-    """
-    A class used to balance streams deterministically.

     Attributes:
         fields (List[str]): A list of field names to be used in determining the signature of an instance.
@@ -763,19 +1303,26 @@ class DeterministicBalancer(StreamRefiner):
     fields: List[str]

     def signature(self, instance):
-        return str(…

-    def process(self, stream: Stream, stream_name: str = None) -> Generator:
         counter = collections.Counter()

         for instance in stream:
             counter[self.signature(instance)] += 1

         lowest_count = counter.most_common()[-1][-1]

         max_total_instances_per_sign = lowest_count
         if self.max_instances is not None:
-            max_total_instances_per_sign = min(…

         counter = collections.Counter()
@@ -791,8 +1338,8 @@ class LengthBalancer(DeterministicBalancer):

     def signature(self, instance):
         total_len = 0
-        for …
-            total_len += len(dict_get(instance, …
         for i, val in enumerate(self.segments_boundaries):
             if total_len < val:
                 return i
operators.py (full updated file):

import collections
import importlib
import uuid
from abc import abstractmethod
from collections import Counter
from copy import deepcopy
from dataclasses import field
from itertools import zip_longest
...
)

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .operator import (
    MultiStream,
...
    StreamInstanceOperator,
    StreamSource,
)
from .random_utils import get_random, nested_seed
from .stream import Stream
from .text_utils import nested_tuple_to_string
from .utils import flatten_dict


class FromIterables(StreamInitializerOperator):
    """Creates a MultiStream from iterables.

    Args:
        iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
...

class MapInstanceValues(StreamInstanceOperator):
    """
    ...
        strict (bool): If True, the mapping is applied strictly. That means if a value
            does not exist in the mapper, it will raise a KeyError. If False, values
            that are not present in the mapper are kept as they are.
        process_every_value (bool): If True, all fields to be mapped should be lists, and the mapping
            is to be applied to their individual elements. If False, mapping is only applied to a field
            containing a single value.

    Examples:
        MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})
        replaces '1' with 'hi' and '2' with 'bye' in field 'a' in all instances of all streams:
        instance {"a": "1", "b": 2} becomes {"a": "hi", "b": 2}.

        MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)
        Assuming field 'a' is a list of values, potentially including "1"-s and "2"-s, this replaces
        each such "1" with "hi" and each "2" with "bye" in all instances of all streams:
        instance {"a": ["1", "2"], "b": 2} becomes {"a": ["hi", "bye"], "b": 2}.

        MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)
        To ensure that all values of field 'a' are mapped in every instance, use strict=True.
        Input instance {"a": "3", "b": 2} will raise an exception per the above call,
        because "3" is not a key in the mapper of "a".
    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True
    use_query: bool = False
    process_every_value: bool = False

    def verify(self):
        # make sure the mappers are valid
        for key, mapper in self.mappers.items():
            assert isinstance(
                mapper, dict
            ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            for k in mapper.keys():
                assert isinstance(
                    k, str
                ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, mapper in self.mappers.items():
            value = dict_get(instance, key, use_dpath=self.use_query)
            if value is not None:
                if (self.process_every_value is True) and (not isinstance(value, list)):
                    raise ValueError(
                        f"'process_every_value' == True is allowed only when all fields that have mappers, i.e., {list(self.mappers.keys())}, are lists. Instance = {instance}"
                    )
                if isinstance(value, list):
                    if self.process_every_value:
                        for i, val in enumerate(value):
                            val = str(val)  # make sure the value is a string
                            if self.strict and (val not in mapper):
                                raise KeyError(
                                    f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
                                )
                            if val in mapper:
                                # replace just that member of value (value is a list)
                                value[i] = mapper[val]
                        dict_set(instance, key, value, use_dpath=self.use_query)
                    else:  # field is a list, and process_every_value == False
                        if self.strict:  # whole lists can not be mapped by a string-to-something mapper
                            raise KeyError(
                                f"A whole list ({value}) in the instance can not be mapped by a field mapper."
                            )
                else:  # value is not a list, implying process_every_value == False
                    value = str(value)  # make sure the value is a string
                    if self.strict and (value not in mapper):
                        raise KeyError(
                            f"value '{value}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
                        )
                    if value in mapper:
                        dict_set(instance, key, mapper[value], use_dpath=self.use_query)

        return instance


class FlattenInstances(StreamInstanceOperator):
    """Flattens each instance in a stream, making nested dictionary entries into top-level entries.

    Args:
        parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
...
    parent_key: str = ""
    sep: str = "_"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)


class AddFields(StreamInstanceOperator):
    """Adds specified fields to each instance in a given stream or in all streams (default). If a field exists, updates it.

    Args:
        fields (Dict[str, object]): The fields to add to each instance.
        use_query (bool): Use '/' to access inner fields.
        use_deepcopy (bool): Deep copy the input value to avoid later modifications.

    Examples:
        # Add a 'classes' field holding the list ["positive", "negative"] to all streams
        AddFields(fields={"classes": ["positive", "negative"]})

        # Add a 'start' field under the 'span' field with a value of 0 to all streams
        AddFields(fields={"span/start": 0})

        # Add a 'classes' field holding the list ["positive", "negative"] to the 'train' stream
        AddFields(fields={"classes": ["positive", "negative"]}, apply_to_stream=["train"])

        # Add a 'classes' field holding a given list, deep-copying it so that later
        # modification of the original list does not change the instance.
        AddFields(fields={"classes": alist}, use_deepcopy=True)
    """

    fields: Dict[str, object]
    use_query: bool = False
    use_deepcopy: bool = False

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if self.use_query:
            for key, value in self.fields.items():
                if self.use_deepcopy:
...

class RemoveFields(StreamInstanceOperator):
    """Removes specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to remove from each instance.
    """

    fields: List[str]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            del instance[field_name]
        return instance

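
# A minimal usage sketch of the instance-level operators above (not part of the
# commit). It assumes the operators can be constructed with the keyword
# arguments shown in their docstrings, that they can be applied to a single
# instance via process(), and that the non-query branch of AddFields updates
# the instance dict directly.
#
#   op = MapInstanceValues(mappers={"label": {"0": "negative", "1": "positive"}})
#   print(op.process({"label": "1"}))  # -> {"label": "positive"}
#
#   flat = FlattenInstances(sep="_")
#   print(flat.process({"span": {"start": 0, "end": 4}}))
#   # -> {"span_start": 0, "span_end": 4}
#
#   add = AddFields(fields={"classes": ["positive", "negative"]})
#   rm = RemoveFields(fields=["classes"])
#   print(rm.process(add.process({})))  # -> {}
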
class FieldOperator(StreamInstanceOperator):
    """A general stream operator that processes the values of a field (or of multiple fields).

    Args:
        field (Optional[str]): The field to process, if only a single one is passed. Defaults to None.
        to_field (Optional[str]): Field name to save the result to, if only one field is to be saved. If None is passed, the operator runs in-place and replaces "field". Defaults to None.
        field_to_field (Optional[Union[List[Tuple[str, str]], Dict[str, str]]]): Mapping from fields to process to their names after this process; duplicates are allowed. Defaults to None.
        process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False.
        use_query (bool): Whether to use dpath style queries. Defaults to False.
    """

    field: Optional[str] = None
...
    def verify(self):
        super().verify()

        assert (
            self.field is not None or self.field_to_field is not None
        ), "Must supply a field to work on"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
        assert (
            self.field is None or self.field_to_field is None
        ), f"Can not apply operator both on {self.field} and on the mapping from fields to fields {self.field_to_field}"
        assert (
            self._field_to_field
        ), f"the from and to fields must be defined, got: {self._field_to_field}"

    @abstractmethod
    def process_value(self, value: Any) -> Any:
...
            self._field_to_field = [(self.field, self.to_field)]
        else:
            try:
                self._field_to_field = list(self.field_to_field.items())
            except AttributeError:
                self._field_to_field = self.field_to_field

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for from_field, to_field in self._field_to_field:
            try:
                old_value = dict_get(
...
                    default=self.get_default,
                    not_exist_ok=self.not_exist_ok,
                )
            except Exception as e:
                raise ValueError(
                    f"Failed to get '{from_field}' from {instance} due to: {e}"
                ) from e
            try:
                if self.process_every_value:
                    new_value = [self.process_value(value) for value in old_value]
                else:
                    new_value = self.process_value(old_value)
            except Exception as e:
                raise ValueError(
                    f"Failed to process '{from_field}' from {instance} due to: {e}"
                ) from e
            if self.use_query and is_subpath(from_field, to_field):
                dict_delete(instance, from_field)
            dict_set(
                instance,
                to_field,
                new_value,
                use_dpath=self.use_query,
                not_exist_ok=True,
            )
        return instance

|
311 |
+
"""Renames fields."""
|
|
|
|
|
312 |
|
313 |
def process_value(self, value: Any) -> Any:
|
314 |
return value
|
315 |
|
316 |
+
def process(
|
317 |
+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
318 |
+
) -> Dict[str, Any]:
|
319 |
res = super().process(instance=instance, stream_name=stream_name)
|
320 |
vals = [x[1] for x in self._field_to_field]
|
321 |
for key, _ in self._field_to_field:
|
|
|
327 |
|
328 |
|
329 |
class AddConstant(FieldOperator):
|
330 |
+
"""Adds a value, similar to add + field.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
add: sum to add.
|
334 |
"""
|
335 |
+
|
336 |
+
add: Any
|
337 |
+
|
338 |
+
def process_value(self, value: Any) -> Any:
|
339 |
+
return self.add + value
|
340 |
+
|
341 |
+
|
342 |
+
class Augmentor(StreamInstanceOperator):
    """A stream operator that augments either the task input fields before rendering with the template, or the input passed to the model after rendering of the template.

    Args:
        augment_model_input: Whether to augment the input to the model.
        augment_task_input: Whether to augment the task input fields. The specific fields are defined in the FormTask operator.
    """

    augment_task_input: bool = False
    augment_model_input: bool = False

    def verify(self):
        assert not (
            self.augment_task_input and self.augment_model_input
        ), "Augmentor must set either 'augment_task_input' or 'augment_model_input', but not both"
        assert (
            self.augment_task_input or self.augment_model_input
        ), "Augmentor must set either 'augment_task_input' or 'augment_model_input'"

        super().verify()

    @abstractmethod
    def process_value(self, value: Any) -> Any:
        pass

    def prepare(self):
        pass

    def set_task_input_fields(self, task_input_fields: List[str]):
        self._task_input_fields = [
            "inputs/" + task_input_field for task_input_field in task_input_fields
        ]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if self.augment_task_input:
            assert (
                len(self._task_input_fields) > 0
            ), "No augmentable input fields were defined in FormTask, and augmentation was requested. Specify the fields to augment in the 'augmentable_inputs' attribute of the FormTask."
            fields = self._task_input_fields
            assert not self.augment_model_input

        if self.augment_model_input:
            fields = ["source"]
            assert not self.augment_task_input

        for field_name in fields:
            try:
                old_value = dict_get(
                    instance,
                    field_name,
                    use_dpath=True,
                    default="",
                    not_exist_ok=False,
                )
            except TypeError as e:
                raise TypeError(f"Failed to get {field_name} from {instance}") from e

            # We are setting a nested seed based on the value processed, to ensure that
            # the augmentation randomizations do not affect other randomization choices and
            # to make the augmentation randomization choices different for each text.
            with nested_seed(str(hash(old_value))):
                try:
                    new_value = self.process_value(old_value)
                except Exception as e:
                    raise RuntimeError(
                        f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
                    ) from e
            dict_set(instance, field_name, new_value, use_dpath=True, not_exist_ok=True)
        return instance


class NullAugmentor(Augmentor):
    def verify(self):
        pass

    def process_value(self, value: Any) -> Any:
        return value


class AugmentWhitespace(Augmentor):
    """Augments the inputs by replacing existing whitespace with other whitespace.

    Currently each whitespace sequence is replaced by a random choice of 1-3 whitespace characters (space, tab, newline).
    """

    def process_value(self, value: Any) -> Any:
        import re

        words = re.split(r"(\s+)", value)
        new_value = ""

        for word in words:
            if word.isspace():
                new_value += get_random().choice(
                    ["\n", "\t", " "]
                ) * get_random().randint(1, 3)
            else:
                new_value += word
        return new_value

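
# A sketch (not part of the commit): because process() above seeds the RNG with
# a hash of the value being augmented, the same input text yields the same
# augmented text across runs, assuming nested_seed behaves as described.
#
#   aug = AugmentWhitespace(augment_model_input=True)
#   print(aug.process({"source": "hello world"}))
#   # e.g. -> {"source": "hello\t\tworld"}; stable across runs for the same input
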
class AugmentSuffix(Augmentor):
    r"""Augments the input by appending to it a randomly selected (typically, whitespace) pattern.

    Args:
        suffixes: the potential (typically, whitespace) patterns to select from.
            The dictionary version allows specifying relative weights for the different patterns.
        remove_existing_trailing_whitespaces: allows first stripping any existing trailing whitespace.
            The selected pattern is then appended to the (potentially trimmed) input.

    Examples:
        To append a '\n' or a '\t' to the end of the input, employ
        AugmentSuffix(augment_model_input=True, suffixes=['\n', '\t'])
        If '\n' is preferred over '\t', at a 2:1 ratio, employ
        AugmentSuffix(augment_model_input=True, suffixes={'\n': 2, '\t': 1})
        which will append '\n' twice as often as '\t'.
    """

    suffixes: Optional[Union[List[str], Dict[str, int]]] = [" ", "\n", "\t"]
    remove_existing_trailing_whitespaces: Optional[bool] = False

    def verify(self):
        assert isinstance(
            self.suffixes, (list, dict)
        ), f"Argument 'suffixes' should be either a list or a dictionary, whereas it is of type {type(self.suffixes)}"

        if isinstance(self.suffixes, dict):
            for k, v in self.suffixes.items():
                assert isinstance(
                    k, str
                ), f"suffixes should map strings, whereas key {k!s} is of type {type(k)}"
                assert isinstance(
                    v, int
                ), f"suffixes should map to ints, whereas value {v!s} is of type {type(v)}"
        else:
            for k in self.suffixes:
                assert isinstance(
                    k, str
                ), f"suffixes should be a list of strings, whereas member {k!s} is of type {type(k)}"

        self.pats = (
            self.suffixes
            if isinstance(self.suffixes, list)
            else [k for k, v in self.suffixes.items()]
        )
        total_weight = (
            len(self.pats)
            if isinstance(self.suffixes, list)
            else sum([v for k, v in self.suffixes.items()])
        )
        self.weights = (
            [1.0 / total_weight] * len(self.pats)
            if isinstance(self.suffixes, list)
            else [float(self.suffixes[p]) / total_weight for p in self.pats]
        )
        super().verify()

    def process_value(self, value: Any) -> Any:
        assert value is not None, "input value should not be None"
        new_value = str(value)
        if self.remove_existing_trailing_whitespaces:
            new_value = new_value.rstrip()
        new_value += get_random().choices(self.pats, self.weights, k=1)[0]

        return new_value

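
# verify() above normalizes the dictionary weights into selection
# probabilities: for suffixes={'\n': 2, '\t': 1} the total weight is 3, so the
# probabilities are 2/3 and 1/3, and random.choices() samples '\n' twice as
# often. A sketch of the resulting state (not part of the commit):
#
#   aug = AugmentSuffix(augment_model_input=True, suffixes={"\n": 2, "\t": 1})
#   aug.verify()
#   print(aug.pats)     # -> ['\n', '\t']
#   print(aug.weights)  # -> [0.666..., 0.333...]
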
class ShuffleFieldValues(FieldOperator):
    """Shuffles the values of an iterable field."""

    def process_value(self, value: Any) -> Any:
        res = list(value)
        get_random().shuffle(res)
        return res


class JoinStr(FieldOperator):
    """Joins a list of strings (contents of a field), similar to str.join().

    Args:
        separator (str): text to put between values
    """
...

class Apply(StreamInstanceOperator):
    """A class used to apply a python function and store the result in a field.

    Args:
        function (str): name of the function.
        to_field (str): the field to store the result in.
        additional arguments are field names whose values will be passed to the function.

    Examples:
        Store in field "b" the uppercase string of the value in field "a":
        Apply("a", function=str.upper, to_field="b")

        Dump the json representation of field "t" and store it back in the same field:
        Apply("t", function=json.dumps, to_field="t")

        Set the time in field 'b':
        Apply(function=time.time, to_field="b")
    """

    __allow_unexpected_arguments__ = True
    function: Callable = NonPositionalField(required=True)
    to_field: str = NonPositionalField(required=True)
...
        else:
            parts.append(function.__name__)

        return ".".join(parts)

    def str_to_function(self, function_str: str) -> Callable:
        splitted = function_str.split(".", 1)
        if len(splitted) == 1:
            return __builtins__[splitted[0]]

        module_name, function_name = splitted
        if module_name in __builtins__:
            obj = __builtins__[module_name]
        elif module_name in globals():
            obj = globals()[module_name]
        else:
            obj = importlib.import_module(module_name)
        for part in function_name.split("."):
            obj = getattr(obj, part)
        return obj

    def prepare(self):
        super().prepare()
...
        self.function = self.str_to_function(self.function)
        self._init_dict["function"] = self.function_to_str(self.function)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        argv = [instance[arg] for arg in self._argv]
        kwargs = {key: instance[val] for key, val in self._kwargs}
...
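
# str_to_function() above resolves a dotted name by first checking builtins and
# globals, then importing the named module and walking its attributes, so
# "json.dumps" resolves via importlib.import_module("json"). A usage sketch
# mirroring the docstring examples (not part of the commit; it assumes
# prepare() has run, as it would inside the pipeline):
#
#   op = Apply("a", function="str.upper", to_field="b")
#   print(op.process({"a": "hi"}))  # -> {"a": "hi", "b": "HI"}
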
class ListFieldValues(StreamInstanceOperator):
    """Concatenates values of multiple fields into a list, and assigns it to a new field."""

    fields: List[str]
    to_field: str
    use_query: bool = False

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name, use_dpath=self.use_query))
        instance[self.to_field] = values
        return instance


class ZipFieldValues(StreamInstanceOperator):
    """Zips values of multiple fields, similar to list(zip(*fields))."""

    fields: str
    to_field: str
    longest: bool = False
    use_query: bool = False

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name, use_dpath=self.use_query))
        if self.longest:
            zipped = zip_longest(*values)
        else:
...

class IndexOf(StreamInstanceOperator):
    """Finds the location of one value in another (iterable) value, similar to to_field = search_in.index(index_of)."""

    search_in: str
    index_of: str
    to_field: str
    use_query: bool = False

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        lst = dict_get(instance, self.search_in, use_dpath=self.use_query)
        item = dict_get(instance, self.index_of, use_dpath=self.use_query)
        instance[self.to_field] = lst.index(item)
...

class TakeByField(StreamInstanceOperator):
    """Takes a value from one field based on another field, similar to field[index]."""

    field: str
    index: str
...
        if self.to_field is None:
            self.to_field = self.field

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        value = dict_get(instance, self.field, use_dpath=self.use_query)
        index_value = dict_get(instance, self.index, use_dpath=self.use_query)
        instance[self.to_field] = value[index_value]
...

class CopyFields(FieldOperator):
    """Copies specified fields from one field to another.

    Args:
        field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
...

class AddID(StreamInstanceOperator):
    id_field_name: str = "id"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
        return instance

class CastFields(StreamInstanceOperator):
    """Casts specified fields to specified types.

    Args:
        types (Dict[str, str]): A dictionary mapping fields to their new types.
...
    def _cast_single(self, value, type, field):
        try:
            return self.types[type](value)
        except Exception as e:
            if field not in self.failure_defaults:
                raise ValueError(
                    f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
                ) from e
            return self.failure_defaults[field]

    def _cast_multiple(self, values, type, field):
        values = [self._cast_single(value, type, field) for value in values]
        return values  # the original code dropped this return; added, as the result is used below

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name, type in self.fields.items():
            value = dict_get(instance, field_name, use_dpath=self.use_nested_query)
            if self.cast_multiple:
                casted_value = self._cast_multiple(value, type, field_name)
            else:
                casted_value = self._cast_single(value, type, field_name)
            dict_set(
                instance, field_name, casted_value, use_dpath=self.use_nested_query
            )
        return instance

...
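
# A sketch of the casting flow above (not part of the commit). It assumes
# 'fields' maps field names to type names, that self.types maps a name like
# "float" to the float constructor, and that cast_multiple defaults to False:
#
#   op = CastFields(fields={"score": "float"}, failure_defaults={"score": 0.0})
#   print(op.process({"score": "3.14"}))  # -> {"score": 3.14}
#   print(op.process({"score": "n/a"}))   # falls back to the default -> {"score": 0.0}
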
    strict: bool = False
    recursive: bool = True

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return recursive_divide(instance, self.divisor, strict=self.strict)


class ArtifactFetcherMixin:
    """Provides a way to fetch and cache artifacts in the system.

    Args:
        cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
...

class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
    """Applies value operators to each instance in a stream based on specified fields.

    Args:
        value_field (str): The field containing the value to be operated on.
...
    default_operators: List[str] = None
    fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            assert (
...
        for name in operator_names:
            operator = self.get_artifact(name)
            for field_name in self.inputs_fields:
                value = instance[field_name]
                if field_name in self.fields_to_treat_as_list:
                    instance[field_name] = [operator.process(v) for v in value]
                else:
                    instance[field_name] = operator.process(instance[field_name])

        return instance

class FilterByValues(SingleStreamOperator):
    """Filters a stream, yielding only instances that match specified values in the provided fields.

    Args:
        required_values (Dict[str, Any]): For each field, the value that instances should match to be included in the output.
    """

    required_values: Dict[str, Any]

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for instance in stream:
            filter = False
            for key, value in self.required_values.items():
                if key not in instance:
                    raise ValueError(
                        f"Required filter field ('{key}') in FilterByValues is not found in {instance}"
                    )
                if instance[key] != value:
                    filter = True
            if not filter:
                yield instance

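
# A sketch (not part of the commit): keep only instances whose 'label' field
# equals "positive"; any other value sets the filter flag and the instance is
# skipped.
#
#   op = FilterByValues(required_values={"label": "positive"})
#   kept = list(op.process(stream))  # 'stream' stands for any iterable of instances
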
class ExtractFieldValues(MultiStreamOperator):
    """Extracts the unique values of a field ('field') of a given stream ('stream_name') and stores (the most frequent of) them as a list in a new field ('to_field') in all streams.

    More specifically, sorts all the unique values encountered in field 'field' by decreasing order of frequency.
    When 'overall_top_frequency_percent' is smaller than 100, trims the list from the bottom, so that the total frequency of
    the remaining values makes 'overall_top_frequency_percent' of the total number of instances in the stream.
    When 'min_frequency_percent' is larger than 0, removes from the list any value whose relative frequency makes
    less than 'min_frequency_percent' of the total number of instances in the stream.
    At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value.

    Examples:
        ExtractFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
        field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
        every instance in all streams.

        ExtractFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
        in case field 'labels' contains a list of values (and not a single value) - tracks the occurrences of all the possible
        value members in these lists, and reports the most frequent values.
        If process_every_value=False, tracks the most frequent whole lists, and reports those (as a list of lists) in field
        'to_field' of each instance of all streams.

        ExtractFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
        extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
        and stores them in field 'classes' of each instance of all streams.

        ExtractFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
        extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
        Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
    """

    field: str
    stream_name: str
    overall_top_frequency_percent: Optional[int] = 100
    min_frequency_percent: Optional[int] = 0
    to_field: str
    process_every_value: Optional[bool] = False

    def verify(self):
        assert (
            self.overall_top_frequency_percent <= 100
            and self.overall_top_frequency_percent >= 0
        ), "'overall_top_frequency_percent' must be between 0 and 100"
        assert (
            self.min_frequency_percent <= 100 and self.min_frequency_percent >= 0
        ), "'min_frequency_percent' must be between 0 and 100"
        assert not (
            self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
        ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value"
        super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        stream = multi_stream[self.stream_name]
        all_values = []
        for instance in stream:
            if (not isinstance(instance[self.field], list)) and (
                self.process_every_value is True
            ):
                raise ValueError(
                    "'process_every_value' is allowed to change to 'True' only for fields whose contents are lists"
                )
            if (not isinstance(instance[self.field], list)) or (
                self.process_every_value is False
            ):
                # either not a list, or a list with process_every_value == False: view the contents of 'field' as one entity whose occurrences are counted.
                all_values.append(
                    (*instance[self.field],)
                    if isinstance(instance[self.field], list)
                    else instance[self.field]
                )  # convert to a tuple if list, to enable the use of Counter, which would not accept
                # a list as an entity to count occurrences of
            else:
                # content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
                all_values.extend(instance[self.field])
        counter = Counter(
            all_values
        )  # here all_values is a list of individual values, or tuples; hence, Counter is feasible
        values_and_counts = counter.most_common()
        if self.overall_top_frequency_percent < 100:
            top_frequency = len(all_values) * self.overall_top_frequency_percent / 100.0
            sum_counts = 0
            for _i, p in enumerate(values_and_counts):
                sum_counts += p[1]
                if sum_counts >= top_frequency:
                    break
            values_and_counts = counter.most_common(_i + 1)
        if self.min_frequency_percent > 0:
            min_frequency = self.min_frequency_percent * len(all_values) / 100.0
            while values_and_counts[-1][1] < min_frequency:
                values_and_counts.pop()
        values_to_keep = [
            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
            for ele in values_and_counts
        ]
        for name in multi_stream:
            for instance in multi_stream[name]:
                instance[self.to_field] = values_to_keep
        return multi_stream

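
# The overall_top_frequency_percent cutoff above walks the counts in decreasing
# order until their running sum reaches the threshold. A worked sketch of that
# arithmetic with plain Counter data:
#
#   from collections import Counter
#
#   counter = Counter({"a": 6, "b": 3, "c": 1})  # 10 observations in total
#   top_frequency = 10 * 80 / 100.0  # keep values covering at least 80% -> 8.0
#   sum_counts = 0
#   for i, (value, count) in enumerate(counter.most_common()):
#       sum_counts += count
#       if sum_counts >= top_frequency:
#           break                                # 6, then 9 >= 8.0: stop at i == 1
#   print(counter.most_common(i + 1))            # -> [('a', 6), ('b', 3)]
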
class FilterByListsOfValues(SingleStreamOperator):
|
962 |
+
"""Filters a stream, yielding only instances that whose field values are included in the specified value lists.
|
963 |
|
964 |
Args:
|
965 |
+
required_values (Dict[str, List]): For each field, the list of values that instances should match to be included in the output.
|
966 |
"""
|
967 |
|
968 |
+
required_values: Dict[str, List]
|
969 |
+
|
970 |
+
def verify(self):
|
971 |
+
super().verify()
|
972 |
+
for key, value in self.required_values.items():
|
973 |
+
if not isinstance(value, list):
|
974 |
+
raise ValueError(
|
975 |
+
f"The filter for key ('{key}') in FilterByListsOfValues is not a list but '{value}'"
|
976 |
+
)
|
977 |
|
978 |
+
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
979 |
for instance in stream:
|
980 |
+
filter = False
|
981 |
+
for key, value in self.required_values.items():
|
982 |
+
if key not in instance:
|
983 |
+
raise ValueError(
|
984 |
+
f"Required filter field ('{key}') in FilterByListsOfValues is not found in {instance}"
|
985 |
+
)
|
986 |
+
if instance[key] not in value:
|
987 |
+
filter = True
|
988 |
+
if not filter:
|
989 |
yield instance
|
990 |
|
991 |
|
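
# Hedged usage sketch (not from the diff): process() only iterates its input,
# so a plain list can stand in for a Stream here; the field name and values
# are hypothetical.
def _example_filter_by_lists_of_values() -> list:
    op = FilterByListsOfValues(required_values={"label": ["pos", "neg"]})
    instances = [{"label": "pos"}, {"label": "other"}, {"label": "neg"}]
    return list(op.process(instances))  # -> [{"label": "pos"}, {"label": "neg"}]
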
class Intersect(FieldOperator):
    """Intersects the value of a field, which must be a list, with a given list.

    Args:
        allowed_values (list): the list to intersect with.
    """

    allowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in Intersect operator"
            )

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{self.allowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]
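
# Hedged sketch (not from the diff; the field name and values are hypothetical):
def _example_intersect() -> list:
    op = Intersect(field="labels", allowed_values=["cat", "dog"])
    # keeps only the elements that appear in allowed_values, preserving order
    return op.process_value(["cat", "bird", "dog"])  # -> ["cat", "dog"]
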
class RemoveValues(FieldOperator):
    """Removes elements from a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list): the values to remove.
    """

    unallowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in RemoveValues operator"
            )

        if not isinstance(self.unallowed_values, list):
            raise ValueError(
                f"The unallowed_values is not a list but '{self.unallowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e not in self.unallowed_values]
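
# Hedged sketch (not from the diff; the field name and values are hypothetical):
def _example_remove_values() -> list:
    op = RemoveValues(field="labels", unallowed_values=["none", "n/a"])
    # drops every element that appears in unallowed_values
    return op.process_value(["cat", "none", "dog"])  # -> ["cat", "dog"]
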
class Unique(SingleStreamReducer):
    """Reduces a stream to unique instances based on specified fields.

    Args:
        fields (List[str]): The fields that should be unique in each instance.
    """

    fields: List[str]

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        result = []
        for field_name in fields:
            value = instance[field_name]
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    ...
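
# Hedged sketch (not from the diff) of why to_tuple converts lists to tuples:
# uniqueness is tracked by hashing, and lists are unhashable while tuples are.
# The instance below is hypothetical.
def _example_unique_signature() -> tuple:
    instance = {"label": "pos", "tags": ["a", "b"]}
    signature = Unique.to_tuple(instance, ["label", "tags"])
    return signature  # -> ("pos", ("a", "b")), usable as a set or dict key
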
class SplitByValue(MultiStreamOperator):
    """Splits a MultiStream into multiple streams based on unique values in specified fields.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
    """

    ...

    def process(self, multi_stream: MultiStream) -> MultiStream:
        ...  # collapsed in the diff: `uniques` (the per-stream unique values) and `result` are initialized here
        for stream_name, stream in multi_stream.items():
            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByValues(
                    required_values=filtering_values
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)
                )
                result[filtered_stream_name] = filtered_streams

        return MultiStream(result)
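
# Hedged sketch (not from the diff): splitting on a hypothetical "language"
# field. Each input stream named e.g. "train" yields one output stream per
# observed value, named "train_" plus that value.
def _example_split_by_value(multi_stream: MultiStream) -> MultiStream:
    return SplitByValue(fields=["language"]).process(multi_stream)
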
class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
    """Applies stream operators to a stream based on specified fields in each instance.

    Args:
        field (str): The field containing the operators to be applied.
        ...
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peak()

        operators = first_instance.get(self.field, [])
        ...

        for operator_name in operators:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, StreamingOperator
            ), f"Operator {operator_name} must be a StreamingOperator"

            stream = operator(MultiStream({"tmp": stream}))["tmp"]

        yield from stream
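
# Hedged sketch (not from the diff): instances carry the names of catalog
# operators to run; the field and artifact names below are hypothetical
# placeholders.
def _example_apply_stream_operators_field(stream: Stream) -> Generator:
    # each instance is expected to hold e.g. {"postprocessors": ["processors.to_string"]}
    op = ApplyStreamOperatorsField(field="postprocessors")
    return op.process(stream)
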
class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
    """Applies metric operators to a stream based on a metric field specified in each instance.

    Args:
        metric_field (str): The field containing the metrics to be applied.
        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
    """

    metric_field: str
    calc_confidence_intervals: bool

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricPipeline, MetricWithConfidenceInterval

        first_instance = stream.peak()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
            raise RuntimeError(
                f"Missing metric names in field '{self.metric_field}' of instance '{first_instance}'."
            )

        if isinstance(metric_names, str):
            metric_names = [metric_names]

        # Each metric operator computes its score and then sets the main score,
        # overwriting the previous main score value (if any). So, we need to
        # reverse the order of the listed metrics. This will cause the first
        # listed metric to run last, and the main score will be set by the
        # first listed metric (as desired).
        metric_names = list(reversed(metric_names))

        for metric_name in metric_names:
            metric = self.get_artifact(metric_name)
            assert isinstance(
                metric, Metric
            ), f"Operator {metric_name} must be a Metric"

            if not self.calc_confidence_intervals:
                if isinstance(metric, MetricWithConfidenceInterval):
                    metric.disable_confidence_interval_calculation()
                elif isinstance(metric, MetricPipeline) and isinstance(
                    metric.metric, MetricWithConfidenceInterval
                ):
                    metric.metric.disable_confidence_interval_calculation()

            stream = metric(MultiStream({"tmp": stream}))["tmp"]

        yield from stream
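
# Hedged sketch (not from the diff): the field and metric artifact names are
# hypothetical. The first listed metric determines the main score.
def _example_apply_metric(stream: Stream) -> Generator:
    # each instance is expected to hold e.g. {"metrics": ["metrics.accuracy"]}
    op = ApplyMetric(metric_field="metrics", calc_confidence_intervals=False)
    return op.process(stream)
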
class AddFieldNamePrefix(StreamInstanceOperator):
    """Adds a prefix to each field name in each instance of a stream.

    Args:
        prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
    """

    prefix_dict: Dict[str, str]

    def prepare(self):
        return super().prepare()

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return {
            self.prefix_dict[stream_name] + key: value
            for key, value in instance.items()
        }
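
# Hedged sketch (not from the diff; the stream and field names are hypothetical):
def _example_add_field_name_prefix() -> dict:
    op = AddFieldNamePrefix(prefix_dict={"train": "src_"})
    return op.process({"text": "hello"}, stream_name="train")
    # -> {"src_text": "hello"}
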
class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.

    Args:
        new_stream_name (str): The name of the new stream resulting from the merge.
        ...
        origin_stream_name_field_name (str): The field name for the origin stream name.
    """

    streams_to_merge: List[str] = None
    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream):
        for stream_name, stream in multi_stream.items():
            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
                for instance in stream:
                    if self.add_origin_stream_name:
                        instance[self.origin_stream_name_field_name] = stream_name
                    yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        return MultiStream(
            {
                self.new_stream_name: Stream(
                    self.merge, gen_kwargs={"multi_stream": multi_stream}
                )
            }
        )
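
# Hedged sketch (not from the diff): merge() only needs an object with
# .items() whose values are iterable, so a plain dict stands in for a
# MultiStream here; the stream names and fields are hypothetical.
def _example_merge_streams() -> list:
    op = MergeStreams()
    merged = list(op.merge({"train": [{"x": 1}], "test": [{"x": 2}]}))
    return merged  # -> [{"x": 1, "origin": "train"}, {"x": 2, "origin": "test"}]
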
class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args:
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        get_random().shuffle(page)
        yield from page
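
# Hedged sketch (not from the diff): shuffling happens within each page, so
# with the default page_size of 1000 only instances at most 1000 apart can
# swap positions. The page below is hypothetical.
def _example_shuffle() -> list:
    page = [{"i": n} for n in range(5)]
    return list(Shuffle().process(page))  # the same instances, reordered
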
class EncodeLabels(StreamInstanceOperator):
    """Encodes the labels of the specified fields together into integers.

    Args:
        fields (List[str]): The fields to encode together.
    """

    fields: List[str]

    def _process_multi_stream(self, multi_stream: MultiStream) -> MultiStream:
        self.encoder = {}
        return super()._process_multi_stream(multi_stream)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            values = dict_get(instance, field_name, use_dpath=True)
            if not isinstance(values, list):
                values = [values]
            for value in values:
                if value not in self.encoder:
                    self.encoder[value] = len(self.encoder)
            new_values = [self.encoder[value] for value in values]
            dict_set(
                instance, field_name, new_values, use_dpath=True, set_multiple=True
            )

        return instance
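
# Hedged sketch (not from the diff): a shared encoder maps labels from all
# listed fields to consecutive integers, so the same label gets the same id
# across fields. Field names and labels are hypothetical; the encoder is set
# by hand here because _process_multi_stream normally initializes it.
def _example_encode_labels() -> tuple:
    op = EncodeLabels(fields=["label"])
    op.encoder = {}  # normally initialized in _process_multi_stream
    first = op.process({"label": "pos"})   # roughly -> {"label": 0}
    second = op.process({"label": "neg"})  # roughly -> {"label": 1}
    return first, second
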
class StreamRefiner(SingleStreamOperator):
    max_instances: int = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:
            yield from stream.take(self.max_instances)
        else:
            yield from stream
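
# Hedged sketch (not from the diff): cap a stream at a fixed number of
# instances; with max_instances=None the stream passes through unchanged.
def _example_stream_refiner(stream: Stream) -> Generator:
    return StreamRefiner(max_instances=100).process(stream)
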
class DeterministicBalancer(StreamRefiner):
    """A class used to balance streams deterministically.

    Attributes:
        fields (List[str]): A list of field names to be used in determining the signature of an instance.
        ...
    """

    fields: List[str]

    def signature(self, instance):
        return str(
            tuple(dict_get(instance, field, use_dpath=True) for field in self.fields)
        )

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        counter = collections.Counter()

        for instance in stream:
            counter[self.signature(instance)] += 1

        if len(counter) == 0:
            return

        lowest_count = counter.most_common()[-1][-1]

        max_total_instances_per_sign = lowest_count
        if self.max_instances is not None:
            max_total_instances_per_sign = min(
                lowest_count, self.max_instances // len(counter)
            )

        counter = collections.Counter()
        ...
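
# Hedged sketch (not from the diff): balancing on a hypothetical "label"
# field keeps at most as many instances per label as the rarest label has,
# so all labels end up equally represented.
def _example_deterministic_balancer(stream: Stream) -> Generator:
    return DeterministicBalancer(fields=["label"]).process(stream)
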
    ...

    def signature(self, instance):
        total_len = 0
        for field_name in self.fields:
            total_len += len(dict_get(instance, field_name, use_dpath=True))
        for i, val in enumerate(self.segments_boundaries):
            if total_len < val:
                return i
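
# Hedged sketch (not from the diff) of the segment lookup above: with
# boundaries [100, 500], a total length of 40 maps to segment 0 and 200 to
# segment 1. The handling of lengths past the last boundary is collapsed in
# the diff, so the final line here is an assumption.
def _example_length_segment(total_len: int, boundaries: list) -> int:
    for i, val in enumerate(boundaries):
        if total_len < val:
            return i
    return len(boundaries)  # assumption: over-long instances get one extra segment
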