Elron commited on
Commit
04d2454
1 Parent(s): eee0bf8

Upload splitters.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. splitters.py +72 -87
splitters.py CHANGED
@@ -1,19 +1,10 @@
1
  import itertools
2
  from abc import abstractmethod
3
- from dataclasses import field
4
- from typing import Dict, List, Optional
5
 
6
  from .artifact import Artifact
7
- from .generator_utils import ReusableGenerator
8
- from .operator import InstanceOperatorWithGlobalAccess, MultiStreamOperator
9
- from .stream import MultiStream
10
-
11
-
12
- class Splitter(MultiStreamOperator):
13
- pass
14
-
15
-
16
- from .random_utils import random
17
  from .split_utils import (
18
  parse_random_mix_string,
19
  parse_slices_string,
@@ -21,6 +12,11 @@ from .split_utils import (
21
  rename_split,
22
  slice_streams,
23
  )
 
 
 
 
 
24
 
25
 
26
  class RenameSplits(Splitter):
@@ -41,8 +37,8 @@ class SplitRandomMix(Splitter):
41
 
42
 
43
  class SeparateSplit(Splitter):
44
- """
45
- Separates a split (e.g. train) into several splits (e.g. train1, train2)
46
  sizes must indicate the size of every split except the last. If no size is give for the last split,
47
  it includes all the examples not allocated to any split.
48
  """
@@ -59,9 +55,15 @@ class SeparateSplit(Splitter):
59
  return super().verify()
60
 
61
  def process(self, multi_stream: MultiStream) -> MultiStream:
62
- mapping = {key: {key: [(None, None)]} for key in multi_stream.keys() if key != self.from_split}
 
 
 
 
63
  so_far = 0
64
- for name, size in itertools.zip_longest(self.to_split_names, self.to_split_sizes):
 
 
65
  mapping[name] = {self.from_split: [(so_far, size)]}
66
  if size:
67
  so_far += size
@@ -87,19 +89,25 @@ class Sampler(Artifact):
87
 
88
  def set_size(self, size):
89
  if isinstance(size, str):
90
- assert size.isdigit(), f"sample_size must be a natural number, got {self.sample_size}"
 
 
91
  size = int(size)
92
  self.sample_size = size
93
 
94
  @abstractmethod
95
- def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
 
 
96
  pass
97
 
98
 
99
  class RandomSampler(Sampler):
100
- def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
 
 
101
  instances_pool = list(instances_pool)
102
- return random.sample(instances_pool, self.sample_size)
103
 
104
 
105
  class DiverseLabelsSampler(Sampler):
@@ -110,14 +118,29 @@ class DiverseLabelsSampler(Sampler):
110
  self.labels = None
111
 
112
  def examplar_repr(self, examplar):
113
- assert (
114
- "inputs" in examplar and self.choices in examplar["inputs"]
115
- ), f"DiverseLabelsSampler assumes each examplar has {self.choices} field in it input"
 
 
 
 
 
 
 
 
 
 
116
  examplar_outputs = next(iter(examplar["outputs"].values()))
117
- return str([choice for choice in examplar["inputs"][self.choices] if choice in examplar_outputs])
 
 
 
 
 
118
 
119
  def divide_by_repr(self, examplars_pool):
120
- labels = dict()
121
  for examplar in examplars_pool:
122
  label_repr = self.examplar_repr(examplar)
123
  if label_repr not in labels:
@@ -125,11 +148,13 @@ class DiverseLabelsSampler(Sampler):
125
  labels[label_repr].append(examplar)
126
  return labels
127
 
128
- def sample(self, instances_pool: List[Dict[str, object]]) -> List[Dict[str, object]]:
 
 
129
  if self.labels is None:
130
  self.labels = self.divide_by_repr(instances_pool)
131
  all_labels = list(self.labels.keys())
132
- random.shuffle(all_labels)
133
  from collections import Counter
134
 
135
  total_allocated = 0
@@ -146,22 +171,21 @@ class DiverseLabelsSampler(Sampler):
146
 
147
  result = []
148
  for label, allocation in allocations.items():
149
- sample = random.sample(self.labels[label], allocation)
150
  result.extend(sample)
151
 
152
- random.shuffle(result)
153
  return result
154
 
155
 
156
- class SpreadSplit(InstanceOperatorWithGlobalAccess):
157
  source_stream: str = None
158
  target_field: str = None
159
  sampler: Sampler = None
160
 
161
  def prepare(self):
162
- self.accessible_streams = [self.source_stream]
163
- self.cache_accessible_streams = True
164
  self.local_cache = None
 
165
 
166
  def verify(self):
167
  assert self.source_stream is not None, "Source stream must be specified"
@@ -169,58 +193,19 @@ class SpreadSplit(InstanceOperatorWithGlobalAccess):
169
  assert self.sampler is not None, "Sampler must be specified"
170
  return super().verify()
171
 
172
- def process(self, instance: Dict[str, object], multi_stream: MultiStream) -> Dict[str, object]:
173
- if self.local_cache is None:
174
- self.local_cache = list(multi_stream[self.source_stream])
175
-
176
- source_stream = self.local_cache
177
-
178
- sampled_instances = self.sampler.sample(source_stream)
179
- instance[self.target_field] = sampled_instances
180
- return instance
181
-
182
-
183
- if __name__ == "__main__":
184
- # some tests
185
- import random
186
-
187
- random.seed(0)
188
- splitter = SplitRandomMix(
189
- mix={
190
- "train": "train[90%]+validation[50%]",
191
- "validation": "train[10%]+validation[50%]",
192
- "test": "test",
193
- }
194
- )
195
-
196
- def generator(name, size):
197
- for i in range(size):
198
- yield {"text": f"{name}_{i}"}
199
-
200
- stream = MultiStream.from_generators(
201
- {
202
- "train": ReusableGenerator(generator, gen_kwargs={"name": "train", "size": 10}),
203
- "validation": ReusableGenerator(generator, gen_kwargs={"name": "validation", "size": 10}),
204
- "test": ReusableGenerator(generator, gen_kwargs={"name": "test", "size": 10}),
205
- }
206
- )
207
-
208
- ds = splitter(stream)
209
- for key, value in ds.items():
210
- print(key)
211
- for item in value:
212
- print(item)
213
-
214
- splitter = SliceSplit(
215
- slices={
216
- "train": "train[:2]+train[2:4]",
217
- "validation": "train[4:6]",
218
- "test": "train[6:]+test",
219
- }
220
- )
221
-
222
- ds = splitter(stream)
223
- for key, value in ds.items():
224
- print(key)
225
- for item in value:
226
- print(item)
 
1
  import itertools
2
  from abc import abstractmethod
3
+ from typing import Dict, List
 
4
 
5
  from .artifact import Artifact
6
+ from .operator import InstanceOperatorWithMultiStreamAccess, MultiStreamOperator
7
+ from .random_utils import get_random
 
 
 
 
 
 
 
 
8
  from .split_utils import (
9
  parse_random_mix_string,
10
  parse_slices_string,
 
12
  rename_split,
13
  slice_streams,
14
  )
15
+ from .stream import MultiStream
16
+
17
+
18
+ class Splitter(MultiStreamOperator):
19
+ pass
20
 
21
 
22
  class RenameSplits(Splitter):
 
37
 
38
 
39
  class SeparateSplit(Splitter):
40
+ """Separates a split (e.g. train) into several splits (e.g. train1, train2).
41
+
42
  sizes must indicate the size of every split except the last. If no size is give for the last split,
43
  it includes all the examples not allocated to any split.
44
  """
 
55
  return super().verify()
56
 
57
  def process(self, multi_stream: MultiStream) -> MultiStream:
58
+ mapping = {
59
+ key: {key: [(None, None)]}
60
+ for key in multi_stream.keys()
61
+ if key != self.from_split
62
+ }
63
  so_far = 0
64
+ for name, size in itertools.zip_longest(
65
+ self.to_split_names, self.to_split_sizes
66
+ ):
67
  mapping[name] = {self.from_split: [(so_far, size)]}
68
  if size:
69
  so_far += size
 
89
 
90
  def set_size(self, size):
91
  if isinstance(size, str):
92
+ assert (
93
+ size.isdigit()
94
+ ), f"sample_size must be a natural number, got {self.sample_size}"
95
  size = int(size)
96
  self.sample_size = size
97
 
98
  @abstractmethod
99
+ def sample(
100
+ self, instances_pool: List[Dict[str, object]]
101
+ ) -> List[Dict[str, object]]:
102
  pass
103
 
104
 
105
  class RandomSampler(Sampler):
106
+ def sample(
107
+ self, instances_pool: List[Dict[str, object]]
108
+ ) -> List[Dict[str, object]]:
109
  instances_pool = list(instances_pool)
110
+ return get_random().sample(instances_pool, self.sample_size)
111
 
112
 
113
  class DiverseLabelsSampler(Sampler):
 
118
  self.labels = None
119
 
120
  def examplar_repr(self, examplar):
121
+ if "inputs" not in examplar:
122
+ raise ValueError(f"'inputs' field is missing from '{examplar}'.")
123
+ inputs = examplar["inputs"]
124
+ if self.choices not in inputs:
125
+ raise ValueError(f"{self.choices} field is missing from '{inputs}'.")
126
+ choices = inputs[self.choices]
127
+ if not isinstance(choices, list):
128
+ raise ValueError(
129
+ f"Unexpected input choices value '{choices}'. Expected a list."
130
+ )
131
+
132
+ if "outputs" not in examplar:
133
+ raise ValueError(f"'outputs' field is missing from '{examplar}'.")
134
  examplar_outputs = next(iter(examplar["outputs"].values()))
135
+ if not isinstance(examplar_outputs, list):
136
+ raise ValueError(
137
+ f"Unexpected examplar_outputs value '{examplar_outputs}'. Expected a list."
138
+ )
139
+
140
+ return str([choice for choice in choices if choice in examplar_outputs])
141
 
142
  def divide_by_repr(self, examplars_pool):
143
+ labels = {}
144
  for examplar in examplars_pool:
145
  label_repr = self.examplar_repr(examplar)
146
  if label_repr not in labels:
 
148
  labels[label_repr].append(examplar)
149
  return labels
150
 
151
+ def sample(
152
+ self, instances_pool: List[Dict[str, object]]
153
+ ) -> List[Dict[str, object]]:
154
  if self.labels is None:
155
  self.labels = self.divide_by_repr(instances_pool)
156
  all_labels = list(self.labels.keys())
157
+ get_random().shuffle(all_labels)
158
  from collections import Counter
159
 
160
  total_allocated = 0
 
171
 
172
  result = []
173
  for label, allocation in allocations.items():
174
+ sample = get_random().sample(self.labels[label], allocation)
175
  result.extend(sample)
176
 
177
+ get_random().shuffle(result)
178
  return result
179
 
180
 
181
+ class SpreadSplit(InstanceOperatorWithMultiStreamAccess):
182
  source_stream: str = None
183
  target_field: str = None
184
  sampler: Sampler = None
185
 
186
  def prepare(self):
 
 
187
  self.local_cache = None
188
+ self.sampler.prepare()
189
 
190
  def verify(self):
191
  assert self.source_stream is not None, "Source stream must be specified"
 
193
  assert self.sampler is not None, "Sampler must be specified"
194
  return super().verify()
195
 
196
+ def process(
197
+ self, instance: Dict[str, object], multi_stream: MultiStream
198
+ ) -> Dict[str, object]:
199
+ try:
200
+ if self.local_cache is None:
201
+ self.local_cache = list(multi_stream[self.source_stream])
202
+
203
+ source_stream = self.local_cache
204
+
205
+ sampled_instances = self.sampler.sample(source_stream)
206
+ instance[self.target_field] = sampled_instances
207
+ return instance
208
+ except Exception as e:
209
+ raise Exception(
210
+ f"Unable to fetch instances from '{self.source_stream}' to '{self.target_field}'"
211
+ ) from e