File size: 3,325 Bytes
00a2077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from .dataclass import InternalField
from .formats import Format, ICLFormat
from .instructions import Instruction
from .operator import Operator, SequntialOperator, StreamInstanceOperator
from .random_utils import random
from .templates import Template


class Renderer(ABC):
    """Marker base class shared by the rendering operators in this module.

    It currently declares no abstract methods, so it imposes no interface
    on subclasses and can itself be instantiated.
    """

    # NOTE: an abstract ``get_postprocessors(self) -> List[str]`` hook was
    # sketched here but is intentionally left disabled.


class RenderTemplate(Renderer, StreamInstanceOperator):
    """Turn an instance's raw ``inputs``/``outputs`` into rendered text fields.

    The ``inputs`` and ``outputs`` entries are removed from the instance,
    passed through the template, and replaced by ``source``, ``target`` and
    ``references`` entries.
    """

    # Template whose ``process_inputs`` / ``process_outputs`` do the rendering.
    template: Template
    # For multi-reference templates: pick the target at random among the
    # references instead of always taking the first one.
    random_reference: bool = False

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Render one instance in place and return it.

        Args:
            instance: Mapping that must contain ``"inputs"`` and ``"outputs"``;
                both keys are popped from it.
            stream_name: Unused by this operator.

        Returns:
            The same ``instance`` object, updated with ``"source"``,
            ``"target"`` and ``"references"``.

        Raises:
            KeyError: If ``"inputs"`` or ``"outputs"`` is missing.
            ValueError: If the template is multi-reference and produced no
                references.
        """
        inputs = instance.pop("inputs")
        outputs = instance.pop("outputs")

        source = self.template.process_inputs(inputs)
        targets = self.template.process_outputs(outputs)

        if self.template.is_multi_reference:
            references = targets
            # Validate up front so both selection modes fail the same way.
            # Previously only the non-random path raised ValueError; the random
            # path fell through to ``random.choice`` on an empty sequence
            # (presumably an IndexError, as with the stdlib API — confirm
            # against ``random_utils``).
            if len(references) == 0:
                raise ValueError("No references found")
            target = random.choice(references) if self.random_reference else references[0]
        else:
            # Single-reference templates return one target; wrap it so that
            # ``references`` is always a list.
            references = [targets]
            target = targets

        instance.update(
            {
                "source": source,
                "target": target,
                "references": references,
            }
        )

        return instance


class RenderDemonstrations(RenderTemplate):
    """Apply the template rendering to every demonstration stored on an instance."""

    # Key under which the list of demonstration instances is stored.
    demos_field: str

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Render each demo found under ``demos_field`` and store the results back.

        A missing ``demos_field`` entry is treated as an empty list, so the
        key is always (re)written on the instance. ``stream_name`` is unused.
        The same ``instance`` object is returned.
        """
        # Bind the parent implementation first: zero-argument ``super()`` is
        # not reliable inside a comprehension scope on older Python versions.
        render_demo = super().process

        raw_demos = instance.get(self.demos_field, [])
        instance[self.demos_field] = [render_demo(demo) for demo in raw_demos]
        return instance


class RenderInstruction(Renderer, StreamInstanceOperator):
    """Attach the instruction text to each instance under ``"instruction"``."""

    # Callable producing the instruction; may be None when no instruction is
    # configured (StandardRenderer in this module forwards its own
    # ``instruction`` field, whose default is None).
    instruction: Instruction

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Set ``instance["instruction"]`` and return the same instance.

        Args:
            instance: Mapping to update in place.
            stream_name: Unused by this operator.

        Returns:
            The same ``instance`` object with ``"instruction"`` set to the
            result of calling ``self.instruction``, or to ``""`` when no
            instruction is configured.
        """
        if self.instruction is None:
            # Calling None would raise TypeError and break the default
            # StandardRenderer pipeline. NOTE(review): an empty string keeps
            # the key present for later steps — confirm the format step
            # treats "" as "no instruction".
            instance["instruction"] = ""
        else:
            instance["instruction"] = self.instruction()
        return instance


class RenderFormat(Renderer, StreamInstanceOperator):
    """Rewrite ``instance["source"]`` using the configured format."""

    # Format object whose ``format`` method produces the final source.
    # NOTE(review): it is dereferenced unconditionally below, so a None
    # format would fail — confirm callers always supply one.
    format: Format
    # Optional key holding demonstration instances; it is removed from the
    # instance and handed to the format when present.
    demos_field: str = None

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Format one instance in place and return it.

        The demos entry (if any) is popped from the instance before
        formatting. ``stream_name`` is unused.
        """
        demos = instance.pop(self.demos_field, None)

        if demos is None:
            rendered_source = self.format.format(instance)
        else:
            rendered_source = self.format.format(instance, demos_instances=demos)

        instance["source"] = rendered_source
        return instance


class StandardRenderer(Renderer, SequntialOperator):
    """Sequential pipeline chaining the four rendering operators of this module.

    The steps run in this order: template rendering of the instance,
    template rendering of its demonstrations, instruction attachment,
    and final formatting.
    """

    template: Template
    instruction: Instruction = None
    demos_field: str = None
    format: ICLFormat = None

    # Populated by ``prepare``; not meant to be supplied by the user.
    steps: List[Operator] = InternalField(default_factory=list)

    def prepare(self):
        """Build the ordered list of rendering steps from this object's fields."""
        template_step = RenderTemplate(template=self.template)
        demos_step = RenderDemonstrations(
            template=self.template,
            demos_field=self.demos_field,
        )
        instruction_step = RenderInstruction(instruction=self.instruction)
        format_step = RenderFormat(
            format=self.format,
            demos_field=self.demos_field,
        )

        self.steps = [template_step, demos_step, instruction_step, format_step]

    def get_postprocessors(self):
        """Return whatever the template's ``get_postprocessors`` returns."""
        postprocessors = self.template.get_postprocessors()
        return postprocessors