Elron commited on
Commit
5bbb99c
1 Parent(s): 80ef08e

Upload standard.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. standard.py +79 -5
standard.py CHANGED
@@ -1,10 +1,11 @@
1
  from typing import List
2
 
3
  from .card import TaskCard
4
- from .dataclass import InternalField
5
  from .formats import ICLFormat
6
  from .instructions import Instruction
7
  from .operator import SourceSequntialOperator, StreamingOperator
 
8
  from .recipe import Recipe
9
  from .renderers import StandardRenderer
10
  from .schema import ToUnitxtGroup
@@ -12,14 +13,24 @@ from .splitters import Sampler, SeparateSplit, SpreadSplit
12
  from .templates import Template
13
 
14
 
15
- class StandardRecipe(Recipe, SourceSequntialOperator):
16
  card: TaskCard
17
  template: Template = None
18
  instruction: Instruction = None
19
  format: ICLFormat = ICLFormat()
20
 
 
 
 
 
 
 
 
 
 
 
21
  demos_pool_size: int = None
22
- num_demos: int = None
23
 
24
  demos_pool_name: str = "demos_pool"
25
  demos_taken_from: str = "train"
@@ -28,6 +39,18 @@ class StandardRecipe(Recipe, SourceSequntialOperator):
28
 
29
  steps: List[StreamingOperator] = InternalField(default_factory=list)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def prepare(self):
32
  self.steps = [
33
  self.card.loader,
@@ -47,7 +70,7 @@ class StandardRecipe(Recipe, SourceSequntialOperator):
47
  )
48
  )
49
 
50
- if self.num_demos is not None:
51
  sampler = self.card.sampler
52
 
53
  if self.sampler is not None:
@@ -63,6 +86,15 @@ class StandardRecipe(Recipe, SourceSequntialOperator):
63
  )
64
  )
65
 
 
 
 
 
 
 
 
 
 
66
  render = StandardRenderer(
67
  instruction=self.instruction,
68
  template=self.template,
@@ -83,7 +115,7 @@ class StandardRecipe(Recipe, SourceSequntialOperator):
83
  )
84
 
85
 
86
- class StandardRecipeWithIndexes(StandardRecipe):
87
  instruction_card_index: int = None
88
  template_card_index: int = None
89
 
@@ -101,3 +133,45 @@ class StandardRecipeWithIndexes(StandardRecipe):
101
  self.instruction = self.card.instructions[int(self.instruction_card_index)]
102
 
103
  super().prepare()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import List
2
 
3
  from .card import TaskCard
4
+ from .dataclass import InternalField, OptionalField
5
  from .formats import ICLFormat
6
  from .instructions import Instruction
7
  from .operator import SourceSequntialOperator, StreamingOperator
8
+ from .operators import StreamRefiner
9
  from .recipe import Recipe
10
  from .renderers import StandardRenderer
11
  from .schema import ToUnitxtGroup
 
13
  from .templates import Template
14
 
15
 
16
+ class BaseRecipe(Recipe, SourceSequntialOperator):
17
  card: TaskCard
18
  template: Template = None
19
  instruction: Instruction = None
20
  format: ICLFormat = ICLFormat()
21
 
22
+ max_train_instances: int = None
23
+ max_validation_instances: int = None
24
+ max_test_instances: int = None
25
+
26
+ train_refiner: StreamRefiner = OptionalField(default_factory=lambda: StreamRefiner(apply_to_streams=["train"]))
27
+ validation_refiner: StreamRefiner = OptionalField(
28
+ default_factory=lambda: StreamRefiner(apply_to_streams=["validation"])
29
+ )
30
+ test_refiner: StreamRefiner = OptionalField(default_factory=lambda: StreamRefiner(apply_to_streams=["test"]))
31
+
32
  demos_pool_size: int = None
33
+ num_demos: int = 0
34
 
35
  demos_pool_name: str = "demos_pool"
36
  demos_taken_from: str = "train"
 
39
 
40
  steps: List[StreamingOperator] = InternalField(default_factory=list)
41
 
42
+ def verify(self):
43
+ super().verify()
44
+ if self.num_demos > 0:
45
+ if self.demos_pool_size is None or self.demos_pool_size < 1:
46
+ raise ValueError(
47
+ "When using demonstrations both num_demos and demos_pool_size should be assigned with postive integers."
48
+ )
49
+ if self.demos_pool_size < self.num_demos:
50
+ raise ValueError(
51
+ f"demos_pool_size must be bigger than num_demos={self.num_demos}, Got demos_pool_size={self.demos_pool_size}"
52
+ )
53
+
54
  def prepare(self):
55
  self.steps = [
56
  self.card.loader,
 
70
  )
71
  )
72
 
73
+ if self.num_demos > 0:
74
  sampler = self.card.sampler
75
 
76
  if self.sampler is not None:
 
86
  )
87
  )
88
 
89
+ self.train_refiner.max_instances = self.max_train_instances
90
+ self.steps.append(self.train_refiner)
91
+
92
+ self.validation_refiner.max_instances = self.max_validation_instances
93
+ self.steps.append(self.validation_refiner)
94
+
95
+ self.test_refiner.max_instances = self.max_test_instances
96
+ self.steps.append(self.test_refiner)
97
+
98
  render = StandardRenderer(
99
  instruction=self.instruction,
100
  template=self.template,
 
115
  )
116
 
117
 
118
+ class StandardRecipeWithIndexes(BaseRecipe):
119
  instruction_card_index: int = None
120
  template_card_index: int = None
121
 
 
133
  self.instruction = self.card.instructions[int(self.instruction_card_index)]
134
 
135
  super().prepare()
136
+
137
+
138
+ class StandardRecipe(StandardRecipeWithIndexes):
139
+ """
140
+ This class represents a standard recipe for data processing and preperation.
141
+ This class can be used to prepare a recipe
142
+ with all necessary steps, refiners and renderers included. It allows to set various
143
+ parameters and steps in a sequential manner for preparing the recipe.
144
+
145
+ Attributes:
146
+ card (TaskCard): TaskCard object associated with the recipe.
147
+ template (Template, optional): Template object to be used for the recipe.
148
+ instruction (Instruction, optional): Instruction object to be used for the recipe.
149
+ format (ICLFormat, optional): ICLFormat object to be used for the recipe.
150
+ train_refiner (StreamRefiner, optional): Train refiner to be used in the recipe.
151
+ max_train_instances (int, optional): Maximum training instances for the refiner.
152
+ validation_refiner (StreamRefiner, optional): Validation refiner to be used in the recipe.
153
+ max_validation_instances (int, optional): Maximum validation instances for the refiner.
154
+ test_refiner (StreamRefiner, optional): Test refiner to be used in the recipe.
155
+ max_test_instances (int, optional): Maximum test instances for the refiner.
156
+ demos_pool_size (int, optional): Size of the demos pool.
157
+ num_demos (int, optional): Number of demos to be used.
158
+ demos_pool_name (str, optional): Name of the demos pool. Default is "demos_pool".
159
+ demos_taken_from (str, optional): Specifies from where the demos are taken. Default is "train".
160
+ demos_field (str, optional): Field name for demos. Default is "demos".
161
+ sampler (Sampler, optional): Sampler object to be used in the recipe.
162
+ steps (List[StreamingOperator], optional): List of StreamingOperator objects to be used in the recipe.
163
+ instruction_card_index (int, optional): Index of instruction card to be used
164
+ for preparing the recipe.
165
+ template_card_index (int, optional): Index of template card to be used for
166
+ preparing the recipe.
167
+
168
+ Methods:
169
+ prepare(): This overridden method is used for preparing the recipe
170
+ by arranging all the steps, refiners, and renderers in a sequential manner.
171
+
172
+ Raises:
173
+ AssertionError: If both template and template_card_index, or instruction and instruction_card_index
174
+ are specified at the same time.
175
+ """
176
+
177
+ pass