File size: 2,953 Bytes
b5acc40
 
 
 
 
 
 
 
 
 
 
 
 
87b5f6e
b5acc40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87b5f6e
 
b5acc40
 
 
 
 
 
 
 
87b5f6e
b5acc40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368a37d
b5acc40
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from .stream import MultiStream
from .operator import SourceOperator
from .card import TaskCard
from .splitters import SliceSplit, SpreadSplit, RandomSampler
from .recipe import SequentialRecipe, Recipe
from .collections import ItemPicker, RandomPicker
from .templates import RenderTemplatedICL
from .schema import ToUnitxtGroup

from typing import Union


class CommonRecipe(Recipe, SourceOperator):
    """Recipe that assembles a sequential pipeline from a ``TaskCard``.

    ``prepare`` chains, in order: the card's loader, its optional preprocess
    steps, its task, an optional ``SliceSplit`` that carves a demos pool out
    of ``train``, an optional ``SpreadSplit`` fed by a sampler, a
    ``RenderTemplatedICL`` step and a final ``ToUnitxtGroup`` step.
    ``process`` runs the resulting ``SequentialRecipe``.
    """

    # Card supplying loader, preprocess_steps, task, instructions, templates.
    card: TaskCard
    # Name of the stream that receives the first ``demos_pool_size`` train items.
    demos_pool_name: str = "demos_pool"
    # Number of leading ``train`` items moved to the demos pool; None = no split.
    demos_pool_size: int = None
    # Target field passed to SpreadSplit and to the renderer.
    demos_field: str = "demos"
    # Sample size handed to the sampler; None = no SpreadSplit step is added.
    num_demos: int = None
    # Sampler kind; only "random" is supported.
    sampler_type: str = "random"
    # Explicit instruction to pick from the card; None = use RandomPicker.
    instruction_item: Union[str, int] = None
    # Explicit template to pick from the card; None = use RandomPicker.
    template_item: Union[str, int] = None

    def verify(self):
        """Validate the configuration.

        Raises:
            ValueError: if ``sampler_type`` is not a supported sampler kind.
        """
        # The original bare membership expression discarded its result, so
        # invalid values were never reported.
        if self.sampler_type not in ("random",):
            raise ValueError(
                f"Unsupported sampler_type {self.sampler_type!r}; "
                "supported values: 'random'"
            )

    def _build_sampler(self):
        """Return the sampler selected by ``sampler_type`` for ``num_demos``."""
        if self.sampler_type == "random":
            return RandomSampler(sample_size=int(self.num_demos))
        # Fail with a clear message instead of an unbound-local error later.
        raise ValueError(
            f"Unsupported sampler_type {self.sampler_type!r}; "
            "supported values: 'random'"
        )

    def _pick(self, options, item):
        """Pick one entry of ``options``: ``item`` if given, else via RandomPicker.

        Returns None when ``options`` is None.
        """
        if options is None:
            return None
        # ``is not None`` (rather than truthiness) keeps index 0 a valid choice.
        picker = ItemPicker(item) if item is not None else RandomPicker()
        return picker(options)

    def prepare(self):
        """Build the step list and store it as ``self.recipe``.

        Raises:
            ValueError: if demos are requested (``num_demos`` is set) with an
                unsupported ``sampler_type``.
        """
        steps = [self.card.loader]

        if self.card.preprocess_steps is not None:
            steps.extend(self.card.preprocess_steps)

        steps.append(self.card.task)

        if self.demos_pool_size is not None:
            pool_size = int(self.demos_pool_size)
            # The head of ``train`` becomes the demos pool; the rest stays train.
            steps.append(
                SliceSplit(
                    slices={
                        self.demos_pool_name: f"train[:{pool_size}]",
                        "train": f"train[{pool_size}:]",
                        "validation": "validation",
                        "test": "test",
                    }
                )
            )

        if self.num_demos is not None:
            steps.append(
                SpreadSplit(
                    source_stream=self.demos_pool_name,
                    target_field=self.demos_field,
                    sampler=self._build_sampler(),
                )
            )

        # An explicitly requested item wins; otherwise one is picked at random.
        instruction = self._pick(self.card.instructions, self.instruction_item)
        template = self._pick(self.card.templates, self.template_item)

        render = RenderTemplatedICL(
            instruction=instruction,
            template=template,
            demos_field=self.demos_field,
        )
        steps.append(render)

        steps.append(
            ToUnitxtGroup(
                group="unitxt",
                metrics=self.card.task.metrics,
                postprocessors=render.get_postprocessors(),
            )
        )

        self.recipe = SequentialRecipe(steps)

    def process(self) -> MultiStream:
        """Run the prepared recipe and return its output ``MultiStream``."""
        return self.recipe()