Elron commited on
Commit
b0144ee
1 Parent(s): 59f2fa9

Upload common.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. common.py +10 -5
common.py CHANGED
@@ -2,10 +2,11 @@ from typing import Union
2
 
3
  from .card import TaskCard
4
  from .collections import ItemPicker, RandomPicker
 
5
  from .operator import SourceOperator
6
  from .recipe import Recipe, SequentialRecipe
7
  from .schema import ToUnitxtGroup
8
- from .splitters import RandomSampler, SeparateSplit, SliceSplit, SpreadSplit
9
  from .stream import MultiStream
10
  from .templates import RenderTemplatedICL
11
 
@@ -17,12 +18,12 @@ class CommonRecipe(Recipe, SourceOperator):
17
  demos_pool_size: int = None
18
  demos_field: str = "demos"
19
  num_demos: int = None
20
- sampler_type: str = "random"
21
  instruction_item: Union[str, int] = None
22
  template_item: Union[str, int] = None
23
 
24
  def verify(self):
25
- assert self.sampler_type in ["random"], f"Uknown sampler type {self.sampler_type}"
26
 
27
  def prepare(self):
28
  steps = [
@@ -44,8 +45,12 @@ class CommonRecipe(Recipe, SourceOperator):
44
  )
45
 
46
  if self.num_demos is not None:
47
- if self.sampler_type == "random":
48
- sampler = RandomSampler(sample_size=int(self.num_demos))
 
 
 
 
49
 
50
  steps.append(
51
  SpreadSplit(
 
2
 
3
  from .card import TaskCard
4
  from .collections import ItemPicker, RandomPicker
5
+ from .dataclass import OptionalField
6
  from .operator import SourceOperator
7
  from .recipe import Recipe, SequentialRecipe
8
  from .schema import ToUnitxtGroup
9
+ from .splitters import RandomSampler, Sampler, SeparateSplit, SliceSplit, SpreadSplit
10
  from .stream import MultiStream
11
  from .templates import RenderTemplatedICL
12
 
 
18
  demos_pool_size: int = None
19
  demos_field: str = "demos"
20
  num_demos: int = None
21
+ sampler: Sampler = None
22
  instruction_item: Union[str, int] = None
23
  template_item: Union[str, int] = None
24
 
25
  def verify(self):
26
+ super().verify()
27
 
28
  def prepare(self):
29
  steps = [
 
45
  )
46
 
47
  if self.num_demos is not None:
48
+ sampler = self.card.sampler
49
+
50
+ if self.sampler is not None:
51
+ sampler = self.sampler
52
+
53
+ sampler.set_size(self.num_demos)
54
 
55
  steps.append(
56
  SpreadSplit(