File size: 3,682 Bytes
00a2077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f634e4
00a2077
 
1f634e4
 
 
 
 
 
 
 
 
 
00a2077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f634e4
 
00a2077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from .dataclass import InternalField
from .formats import Format, ICLFormat
from .instructions import Instruction
from .operator import Operator, SequntialOperator, StreamInstanceOperator
from .random_utils import random
from .templates import Template


class Renderer(ABC):
    """Common base class for the rendering operators defined in this module.

    The class currently declares no abstract methods, so it can be
    instantiated and subclassed without implementing anything.
    """

    # The abstract declaration below is disabled, so subclasses are not
    # required to implement ``get_postprocessors``.
    # @abstractmethod
    # def get_postprocessors(self) -> List[str]:
    #     pass


class RenderTemplate(Renderer, StreamInstanceOperator):
    """Render an instance's ``inputs``/``outputs`` through a template.

    The ``inputs`` and ``outputs`` entries are removed from the instance and
    replaced by ``source``, ``target`` and ``references`` entries computed by
    ``template``.
    """

    template: Template
    # For multi-reference templates: choose the target at random among the
    # references instead of always taking the first one.
    random_reference: bool = False
    # When True, instances that already look rendered are returned untouched.
    skip_rendered_instance: bool = True

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Render a single instance in place and return it.

        Args:
            instance: The instance dict. Unless it is skipped (see below) it
                must contain ``inputs`` and ``outputs``.
            stream_name: Name of the stream being processed; not used here.

        Returns:
            The same ``instance`` object, with ``inputs``/``outputs`` removed
            and ``source``, ``target`` and ``references`` set.

        Raises:
            KeyError: If ``inputs`` or ``outputs`` is missing from an instance
                that is not skipped.
            ValueError: If the template is multi-reference and produced no
                references.
        """
        if self.skip_rendered_instance:
            # An instance without the raw fields but with all rendered fields
            # is treated as already rendered and passed through unchanged.
            if (
                "inputs" not in instance
                and "outputs" not in instance
                and "source" in instance
                and "target" in instance
                and "references" in instance
            ):
                return instance

        inputs = instance.pop("inputs")
        outputs = instance.pop("outputs")

        source = self.template.process_inputs(inputs)
        targets = self.template.process_outputs(outputs)

        if self.template.is_multi_reference:
            references = targets
            # Validate before choosing a target so that the random and the
            # deterministic paths fail identically on an empty list (the
            # random path would otherwise fail inside ``random.choice``).
            if len(references) == 0:
                raise ValueError("No references found")
            if self.random_reference:
                target = random.choice(references)
            else:
                target = references[0]
        else:
            # Single-reference template: the one target is also the only
            # reference.
            references = [targets]
            target = targets

        instance.update(
            {
                "source": source,
                "target": target,
                "references": references,
            }
        )

        return instance


class RenderDemonstrations(RenderTemplate):
    """Apply the template rendering to every demonstration held by an instance."""

    # Key under which the instance stores its list of demonstration instances.
    demos_field: str

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Render each demo found under ``demos_field`` and store the results back.

        A missing ``demos_field`` entry is treated as an empty list, so the
        field is always present (possibly empty) on the returned instance.
        """
        # Bind the parent implementation once up front: zero-argument super()
        # cannot be relied upon inside a comprehension on all Python versions.
        render_demo = super().process
        raw_demos = instance.get(self.demos_field, [])
        instance[self.demos_field] = [render_demo(demo) for demo in raw_demos]
        return instance


class RenderInstruction(Renderer, StreamInstanceOperator):
    """Attach the instruction text to each instance under the ``instruction`` key."""

    # Callable producing the instruction value; may be None when the enclosing
    # renderer was configured without an instruction.
    instruction: Instruction

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Set ``instance["instruction"]`` and return the instance.

        Args:
            instance: The instance dict to annotate (modified in place).
            stream_name: Name of the stream being processed; not used here.

        Returns:
            The same ``instance`` object with the ``instruction`` key set to
            the result of calling ``self.instruction``, or to an empty string
            when no instruction is configured.
        """
        if self.instruction is None:
            # StandardRenderer defaults ``instruction`` to None; calling it
            # would raise TypeError, so fall back to an empty instruction.
            # NOTE(review): assumes the format step treats "" as "no
            # instruction" -- confirm against Format.format.
            instance["instruction"] = ""
        else:
            instance["instruction"] = self.instruction()
        return instance


class RenderFormat(Renderer, StreamInstanceOperator):
    """Produce the final ``source`` of an instance by applying a format."""

    format: Format
    # Key holding the demonstration instances; the entry is removed from the
    # instance before formatting.
    demos_field: str = None

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Overwrite ``instance["source"]`` with the formatted source and return the instance.

        The demonstrations stored under ``demos_field`` are popped from the
        instance; they are forwarded to the format as ``demos_instances`` only
        when the popped value is not None.
        """
        popped_demos = instance.pop(self.demos_field, None)
        # Only pass the keyword when demos exist, so the format's own default
        # applies otherwise.
        extra_kwargs = {} if popped_demos is None else {"demos_instances": popped_demos}
        instance["source"] = self.format.format(instance, **extra_kwargs)
        return instance


class StandardRenderer(Renderer, SequntialOperator):
    """Sequential renderer chaining the template, demonstrations, instruction and format steps."""

    template: Template
    instruction: Instruction = None
    demos_field: str = None
    format: ICLFormat = None

    # Populated by prepare(); not meant to be supplied by the caller.
    steps: List[Operator] = InternalField(default_factory=list)

    def prepare(self):
        """Build the ordered list of rendering steps from the configured fields."""
        # The same template renders both the instance itself and its demos.
        render_instance = RenderTemplate(template=self.template)
        render_demos = RenderDemonstrations(template=self.template, demos_field=self.demos_field)
        add_instruction = RenderInstruction(instruction=self.instruction)
        apply_format = RenderFormat(format=self.format, demos_field=self.demos_field)

        self.steps = [render_instance, render_demos, add_instruction, apply_format]

    def get_postprocessors(self):
        """Return whatever post-processors the template reports."""
        postprocessors = self.template.get_postprocessors()
        return postprocessors