File size: 2,365 Bytes
c6e9c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import List

from .card import TaskCard
from .dataclass import InternalField
from .formats import ICLFormat
from .instructions import Instruction
from .operator import SourceSequntialOperator, StreamingOperator
from .recipe import Recipe
from .renderers import StandardRenderer
from .schema import ToUnitxtGroup
from .splitters import Sampler, SeparateSplit, SpreadSplit
from .templates import Template


class StandardRecipe(Recipe, SourceSequntialOperator):
    """Recipe that assembles the standard processing pipeline from a task card.

    ``prepare`` fills ``steps`` with, in order:

    1. the card's loader;
    2. the card's preprocessing steps, if any;
    3. the card's task;
    4. optionally, a ``SeparateSplit`` that carves a demonstrations pool
       (``demos_pool_name``) out of ``demos_taken_from``;
    5. optionally, a ``SpreadSplit`` that attaches ``num_demos`` sampled
       demonstrations to each instance under ``demos_field``;
    6. a ``StandardRenderer`` built from the instruction, template and format;
    7. a ``ToUnitxtGroup`` step carrying the task metrics and the renderer's
       postprocessors.
    """

    card: TaskCard
    template: Template
    instruction: Instruction = None
    format: ICLFormat = None

    # Size passed to SeparateSplit for the demos pool split taken from
    # ``demos_taken_from``; None skips the split entirely.
    demos_pool_size: int = None
    # Number of demonstrations the sampler is sized to; None means no
    # SpreadSplit step is added by this recipe.
    num_demos: int = None

    demos_pool_name: str = "demos_pool"
    demos_taken_from: str = "train"
    demos_field: str = "demos"
    # When set, used instead of ``card.sampler`` for drawing demonstrations.
    sampler: Sampler = None

    steps: List[StreamingOperator] = InternalField(default_factory=list)

    def prepare(self):
        """Build ``self.steps``, the ordered operators this recipe runs.

        Raises:
            ValueError: if ``num_demos`` is set but no sampler is available
                from either ``self.sampler`` or ``self.card.sampler``.
        """
        self.steps = [self.card.loader]

        if self.card.preprocess_steps is not None:
            self.steps.extend(self.card.preprocess_steps)

        self.steps.append(self.card.task)

        if self.demos_pool_size is not None:
            # Split ``demos_taken_from`` into the demos pool and a split that
            # keeps the original name, so downstream steps still find it.
            self.steps.append(
                SeparateSplit(
                    from_split=self.demos_taken_from,
                    to_split_names=[self.demos_pool_name, self.demos_taken_from],
                    to_split_sizes=[int(self.demos_pool_size)],
                )
            )

        if self.num_demos is not None:
            # A sampler configured on the recipe takes precedence over the card's.
            sampler = self.sampler if self.sampler is not None else self.card.sampler
            if sampler is None:
                # Fail with an actionable message instead of an AttributeError
                # on ``None.set_size``.
                raise ValueError(
                    "num_demos is set but no sampler is available: provide one "
                    "via the recipe's 'sampler' field or the card's 'sampler'."
                )

            # NOTE(review): set_size presumably mutates the sampler in place; when
            # it is the card's sampler, other users of the same card see the new
            # size. Confirm whether a per-recipe copy is wanted.
            sampler.set_size(self.num_demos)

            # NOTE(review): reads the ``demos_pool_name`` stream, which exists only
            # if demos_pool_size was set above or the loader produces it.
            self.steps.append(
                SpreadSplit(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=sampler,
                )
            )

        render = StandardRenderer(
            instruction=self.instruction,
            template=self.template,
            format=self.format,
            demos_field=self.demos_field,
        )

        self.steps.append(render)

        # Postprocessors come from the renderer (i.e. from the chosen template,
        # instruction and format) and travel with the metrics in the final group.
        postprocessors = render.get_postprocessors()

        self.steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=postprocessors,
            )
        )