Elron commited on
Commit
e81c49a
1 Parent(s): e62296d

Upload operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. operators.py +47 -21
operators.py CHANGED
@@ -35,7 +35,6 @@ General Operaotrs List:
35
  import collections
36
  import copy
37
  import operator
38
- import os
39
  import uuid
40
  import zipfile
41
  from abc import abstractmethod
@@ -418,8 +417,6 @@ class InstanceFieldOperator(StreamInstanceOperator):
418
  raise ValueError(
419
  f"Failed to process '{from_field}' from {instance} due to : {e}"
420
  ) from e
421
- if is_subpath(from_field, to_field) or is_subpath(to_field, from_field):
422
- dict_delete(instance, from_field)
423
  dict_set(
424
  instance,
425
  to_field,
@@ -471,18 +468,7 @@ class RenameFields(FieldOperator):
471
  if (not is_subpath(from_field, to_field)) and (
472
  not is_subpath(to_field, from_field)
473
  ):
474
- dict_delete(res, from_field)
475
- if self.use_query:
476
- from_field_components = list(
477
- os.path.normpath(from_field).split(os.path.sep)
478
- )
479
- while len(from_field_components) > 1:
480
- from_field_components.pop()
481
- parent = dict_get(res, os.path.sep.join(from_field_components))
482
- if isinstance(parent, dict) and not parent:
483
- dict_delete(res, os.path.sep.join(from_field_components))
484
- else:
485
- break
486
 
487
  return res
488
 
@@ -1480,10 +1466,6 @@ class RemoveValues(FieldOperator):
1480
 
1481
  def verify(self):
1482
  super().verify()
1483
- if self.process_every_value:
1484
- raise ValueError(
1485
- "'process_every_value=True' is not supported in RemoveValues operator"
1486
- )
1487
 
1488
  if not isinstance(self.unallowed_values, list):
1489
  raise ValueError(
@@ -1712,7 +1694,7 @@ class EncodeLabels(StreamInstanceOperator):
1712
  {"a": "blue", "b": ["green"], "c":"water"}] will yield the
1713
  output stream = [{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]
1714
 
1715
- Note: dpath is applied here, and hence, fields that are lists, should be included in
1716
  input 'fields' with the appendix "/*" as in the above example.
1717
 
1718
  """
@@ -1728,14 +1710,21 @@ class EncodeLabels(StreamInstanceOperator):
1728
  ) -> Dict[str, Any]:
1729
  for field_name in self.fields:
1730
  values = dict_get(instance, field_name, use_dpath=True)
 
1731
  if not isinstance(values, list):
1732
  values = [values]
1733
  for value in values:
1734
  if value not in self.encoder:
1735
  self.encoder[value] = len(self.encoder)
1736
  new_values = [self.encoder[value] for value in values]
 
 
1737
  dict_set(
1738
- instance, field_name, new_values, use_dpath=True, set_multiple=True
 
 
 
 
1739
  )
1740
 
1741
  return instance
@@ -1904,3 +1893,40 @@ class ExtractZipFile(SideEffectOperator):
1904
  def process(self):
1905
  with zipfile.ZipFile(self.zip_file) as zf:
1906
  zf.extractall(self.target_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  import collections
36
  import copy
37
  import operator
 
38
  import uuid
39
  import zipfile
40
  from abc import abstractmethod
 
417
  raise ValueError(
418
  f"Failed to process '{from_field}' from {instance} due to : {e}"
419
  ) from e
 
 
420
  dict_set(
421
  instance,
422
  to_field,
 
468
  if (not is_subpath(from_field, to_field)) and (
469
  not is_subpath(to_field, from_field)
470
  ):
471
+ dict_delete(res, from_field, remove_empty_ancestors=True)
 
 
 
 
 
 
 
 
 
 
 
472
 
473
  return res
474
 
 
1466
 
1467
  def verify(self):
1468
  super().verify()
 
 
 
 
1469
 
1470
  if not isinstance(self.unallowed_values, list):
1471
  raise ValueError(
 
1694
  {"a": "blue", "b": ["green"], "c":"water"}] will yield the
1695
  output stream = [{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]
1696
 
1697
+ Note: qpath is applied here, and hence, fields that are lists, should be included in
1698
  input 'fields' with the appendix "/*" as in the above example.
1699
 
1700
  """
 
1710
  ) -> Dict[str, Any]:
1711
  for field_name in self.fields:
1712
  values = dict_get(instance, field_name, use_dpath=True)
1713
+ values_was_a_list = isinstance(values, list)
1714
  if not isinstance(values, list):
1715
  values = [values]
1716
  for value in values:
1717
  if value not in self.encoder:
1718
  self.encoder[value] = len(self.encoder)
1719
  new_values = [self.encoder[value] for value in values]
1720
+ if not values_was_a_list:
1721
+ new_values = new_values[0]
1722
  dict_set(
1723
+ instance,
1724
+ field_name,
1725
+ new_values,
1726
+ use_dpath=True,
1727
+ set_multiple="*" in field_name,
1728
  )
1729
 
1730
  return instance
 
1893
  def process(self):
1894
  with zipfile.ZipFile(self.zip_file) as zf:
1895
  zf.extractall(self.target_dir)
1896
+
1897
+
1898
+ class DuplicateInstances(SingleStreamOperator):
1899
+ """Operator which duplicates each instance in stream a given number of times.
1900
+
1901
+ Attributes:
1902
+ num_duplications (int): How many times each instance should be duplicated (1 means no duplication).
1903
+ duplication_index_field (Optional[str]):
1904
+ If given, then additional field with specified name is added to each duplicated instance,
1905
+ which contains id of a given duplication. Defaults to None, so no field is added.
1906
+ """
1907
+
1908
+ num_duplications: int
1909
+ duplication_index_field: Optional[str] = None
1910
+
1911
+ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1912
+ for instance in stream:
1913
+ for idx in range(self.num_duplications):
1914
+ duplicate = deepcopy(instance)
1915
+ if self.duplication_index_field:
1916
+ duplicate.update({self.duplication_index_field: idx})
1917
+ yield duplicate
1918
+
1919
+ def verify(self):
1920
+ if not isinstance(self.num_duplications, int) or self.num_duplications < 1:
1921
+ raise ValueError(
1922
+ f"num_duplications must be an integer equal to or greater than 1. "
1923
+ f"Got: {self.num_duplications}."
1924
+ )
1925
+
1926
+ if self.duplication_index_field is not None and not isinstance(
1927
+ self.duplication_index_field, str
1928
+ ):
1929
+ raise ValueError(
1930
+ f"If given, duplication_index_field must be a string. "
1931
+ f"Got: {self.duplication_index_field}"
1932
+ )