from typing import List

from .card import TaskCard
from .dataclass import InternalField
from .formats import ICLFormat
from .instructions import Instruction
from .operator import SourceSequntialOperator, StreamingOperator
from .recipe import Recipe
from .renderers import StandardRenderer
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .templates import Template


class StandardRecipe(Recipe, SourceSequntialOperator):
    """Recipe that assembles an ordered list of steps from a task card.

    ``prepare`` fills ``steps`` with, in order: the card's loader, the card's
    optional preprocess steps, the card's task, an optional ``SeparateSplit``
    (when ``demos_pool_size`` is set), an optional ``SpreadSplit`` (when
    ``num_demos`` is set), a ``StandardRenderer`` and a final
    ``ToUnitxtGroup``.
    """

    # Card supplying the loader, preprocess steps, task and default sampler.
    card: TaskCard
    # Template handed to the StandardRenderer.
    template: Template
    # Optional instruction handed to the StandardRenderer.
    instruction: Instruction = None
    # Optional format handed to the StandardRenderer.
    format: ICLFormat = None
    # When set, a SeparateSplit step is added with this value (cast to int)
    # as its only explicit split size.
    demos_pool_size: int = None
    # When set, the sampler is sized to this value and a SpreadSplit step is
    # added.
    num_demos: int = None
    # Split name used as the first SeparateSplit target and as the
    # SpreadSplit source stream.
    demos_pool_name: str = "demos_pool"
    # Split name used as the SeparateSplit source and as its second target.
    demos_taken_from: str = "train"
    # Field name used as the SpreadSplit target and passed to the renderer.
    demos_field: str = "demos"
    # Optional sampler; when not None it is used instead of the card's one.
    sampler: Sampler = None

    # Populated by prepare(); not meant to be supplied by the caller.
    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def prepare(self):
        """Build ``self.steps`` from the card and the recipe's settings.

        Raises:
            ValueError: if ``num_demos`` is set but neither the recipe nor
                the card provides a sampler.
        """
        self.steps = [
            self.card.loader,
        ]

        if self.card.preprocess_steps is not None:
            self.steps.extend(self.card.preprocess_steps)

        self.steps.append(self.card.task)

        if self.demos_pool_size is not None:
            # Only the first target split gets an explicit size; the second
            # target reuses the source split's name.
            self.steps.append(
                SeparateSplit(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                )
            )

        if self.num_demos is not None:
            # The recipe-level sampler, when given, overrides the card's.
            sampler = self.card.sampler
            if self.sampler is not None:
                sampler = self.sampler

            # Fail with a clear message instead of an AttributeError on None.
            if sampler is None:
                raise ValueError(
                    f"num_demos is set to {self.num_demos} but neither the "
                    "recipe nor the card provides a sampler."
                )

            # NOTE: set_size is called on the chosen sampler object itself,
            # which may be the card's own sampler instance.
            sampler.set_size(self.num_demos)

            self.steps.append(
                SpreadSplit(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=sampler,
                )
            )

        render = StandardRenderer(
            instruction=self.instruction,
            template=self.template,
            format=self.format,
            demos_field=self.demos_field,
        )
        self.steps.append(render)

        # The final step receives the renderer's postprocessors together
        # with the task's metrics.
        postprocessors = render.get_postprocessors()
        self.steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=postprocessors,
            )
        )