File size: 9,302 Bytes
970dac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cdc7d0
970dac4
 
4aee30b
7cdc7d0
970dac4
4aee30b
970dac4
 
 
 
 
 
 
 
 
 
4aee30b
970dac4
 
 
7cdc7d0
970dac4
 
7cdc7d0
 
970dac4
4aee30b
970dac4
 
 
 
 
 
 
 
4aee30b
970dac4
 
4aee30b
970dac4
 
 
 
4aee30b
970dac4
 
4aee30b
970dac4
4aee30b
970dac4
 
 
 
 
7cdc7d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""Dialog Serializers.

Dialog serializers take structured dialog data and turn it into
text that can be fed to the model.

A dialog is a list of turns. Each turn is a dict holding the user's
utterance and the system's reply (an empty string when there is no reply),
for example:

dialog = [
    {"user": "hello", "system": "hi"},
    {"user": "kkk", "system": ""},
    {"user": "kkk", "system": ""},
]
"""
from typing import Any, Dict, List, Optional

from .formats import SystemFormat
from .operators import InstanceFieldOperator


class SerializeDialog(InstanceFieldOperator):
    """Serializes dialog data for feeding into a model.

    This class takes structured dialog data (a list of ``{"user": ...,
    "system": ...}`` turns) and converts it into a single text string by
    rendering every turn with a per-turn template derived from a format's
    ``demo_format``.

    Attributes:
        format (SystemFormat): Format whose ``demo_format`` is used as the turn
            template. When ``None``, the format is read from
            ``instance["recipe_metadata"]["format"]``.
        last_response_to_field (Optional[str]): When set, the system response of
            the last turn is stored in this instance field and is left out of
            the serialized text (the last turn is cut right after ``{user}``).
        context_field (Optional[str]): Instance field whose value is prepended
            to the serialized dialog.
        context_separator (str): Text placed between the context and the dialog.
        slice_first_and_last_turns_format (bool): When true, everything before
            ``{user}`` is dropped from the first turn and everything after
            ``{system}`` is dropped from the last turn.

    The input/output field names (``field`` / ``to_field``) are not declared
    here; presumably they are inherited from ``InstanceFieldOperator``.
    """

    format: SystemFormat = None
    last_response_to_field: Optional[str] = None
    context_field: Optional[str] = None
    context_separator: str = " "
    slice_first_and_last_turns_format: bool = True

    def standardize_format(self, demo_format):
        """Rewrite a demo format into a turn template.

        ``{source}`` becomes ``{user}``, ``{target}`` becomes ``{system}`` and
        ``{target_prefix}`` is removed.
        """
        turn_format = demo_format.replace("{source}", "{user}")
        turn_format = turn_format.replace("{target}", "{system}")
        return turn_format.replace("{target_prefix}", "")

    def slice_first_turn(self, turn_format):
        """Return the template starting at ``{user}`` (leading text dropped).

        Raises:
            ValueError: If ``{user}`` does not occur in ``turn_format``.
        """
        return turn_format[turn_format.index("{user}") :]

    def slice_last_turn(self, turn_format):
        """Return the template up to and including ``{system}``.

        Raises:
            ValueError: If ``{system}`` does not occur in ``turn_format``.
        """
        return turn_format[: turn_format.index("{system}") + len("{system}")]

    def slice_last_response(self, turn_format):
        """Return the template up to and including ``{user}``.

        Raises:
            ValueError: If ``{user}`` does not occur in ``turn_format``.
        """
        return turn_format[: turn_format.index("{user}") + len("{user}")]

    def get_turn_format(self, turn_format, step, length):
        """Return the template to use for turn number ``step`` out of ``length``.

        The first and last turns may be trimmed, depending on
        ``slice_first_and_last_turns_format`` and ``last_response_to_field``.
        A single-turn dialog is both first and last, so both trims apply.
        """
        if step == 0 and self.slice_first_and_last_turns_format:
            turn_format = self.slice_first_turn(turn_format)
        if step == length - 1:
            if self.slice_first_and_last_turns_format:
                turn_format = self.slice_last_turn(turn_format)
            if self.last_response_to_field is not None:
                # The last response is exported separately, so it must not
                # appear in the serialized text.
                turn_format = self.slice_last_response(turn_format)
        return turn_format

    def get_general_turn_format(self, instance):
        """Return the untrimmed turn template for ``instance``.

        Uses ``self.format`` when set, otherwise the format found under
        ``instance["recipe_metadata"]["format"]``.
        """
        general_format = (
            instance["recipe_metadata"]["format"]
            if self.format is None
            else self.format
        )
        return self.standardize_format(general_format.demo_format)

    def process_instance_value(
        self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
    ):
        """Serialize ``structured_dialog`` into a single string.

        Args:
            structured_dialog: List of turns, each a dict that supplies the
                placeholders of the turn template (``user`` and ``system``).
            instance: The instance being processed. It is read for the context
                field and the recipe format, and written to when
                ``last_response_to_field`` is set.

        Returns:
            The optional context followed by all rendered turns.
        """
        context = (
            ""
            if self.context_field is None
            else instance[self.context_field] + self.context_separator
        )
        general_turn_format = self.get_general_turn_format(instance)
        num_turns = len(structured_dialog)
        rendered_turns = [
            self.get_turn_format(general_turn_format, i, num_turns).format(**turn)
            for i, turn in enumerate(structured_dialog)
        ]
        if self.last_response_to_field is not None:
            # An empty dialog has no last turn; store an empty response instead
            # of failing on an unbound loop variable.
            instance[self.last_response_to_field] = (
                structured_dialog[-1]["system"] if structured_dialog else ""
            )
        return context + "".join(rendered_turns)


class SerializeOpenAiFormatDialog(SerializeDialog):
    """Serializes dialog data in the OpenAI format for feeding into a model.

    The input is a list of ``{"role": ..., "content": ...}`` entries whose role
    is ``"user"`` or ``"assistant"``. The dialog is validated, consecutive
    entries with the same role are merged, the entries are paired into
    ``{"user": ..., "system": ...}`` turns, and the result is serialized by
    ``SerializeDialog.process_instance_value``.

    Attributes:
        is_last_turn_user_only (bool): Declared for configuration; it is not
            read by any method of this class.

    All other attributes are those of ``SerializeDialog``.
    """

    is_last_turn_user_only: bool = True

    @staticmethod
    def validate_openai_dialog_format(dialog: List[Dict[str, str]]) -> None:
        """Validates that the given dialog follows the correct OpenAI format.

        The function checks that:
        1. The dialog is a non-empty list of dictionaries.
        2. Each dictionary contains the keys 'role' and 'content'.
        3. The 'role' value is either 'user' or 'assistant'.
        4. Both 'role' and 'content' values are strings.
        5. The first 'role' is 'user'

        If the dialog does not conform to the expected format, a descriptive
        ValueError is raised indicating the issue.

        Args:
            dialog (List[Dict[str, str]]): The dialog to validate.

        Raises:
            ValueError: If the dialog does not meet the format requirements.
        """
        if not isinstance(dialog, list):
            raise ValueError("Dialog must be a list of dictionaries.")

        # Without this guard the first-entry check below would fail with a bare
        # IndexError instead of the documented ValueError.
        if not dialog:
            raise ValueError("Dialog must contain at least one entry.")

        for i, entry in enumerate(dialog):
            if not isinstance(entry, dict):
                raise ValueError(
                    f"Entry {i} is not a dictionary: {entry}. Each entry in the dialog must be a dictionary."
                )

            if "role" not in entry:
                raise ValueError(
                    f"Entry {i} is missing the 'role' key: {entry}. Each dictionary must have a 'role' key."
                )

            if "content" not in entry:
                raise ValueError(
                    f"Entry {i} is missing the 'content' key: {entry}. Each dictionary must have a 'content' key."
                )

            if not isinstance(entry["role"], str):
                raise ValueError(
                    f"Entry {i} has a non-string 'role': {entry['role']}. The 'role' value must be a string."
                )

            if not isinstance(entry["content"], str):
                raise ValueError(
                    f"Entry {i} has a non-string 'content': {entry['content']}. The 'content' value must be a string."
                )

            if entry["role"] not in {"user", "assistant"}:
                raise ValueError(
                    f"Entry {i} has an invalid role: {entry['role']}. Allowed roles are 'user' and 'assistant'."
                )

        first_entry = dialog[0]
        if first_entry["role"] != "user":
            raise ValueError(
                f"First entry role is expected to be 'user' It is  {first_entry['role']}."
            )

    @staticmethod
    def merge_dialog_entries(dialog: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """Merges consecutive dialog entries with the same role.

        Merged contents are joined with a single space. The input list and its
        dictionaries are left untouched: every output entry is a copy.

        Args:
            dialog (List[Dict[str, str]]): The input dialog list where each dictionary has a 'role' and 'content'.

        Returns:
            List[Dict[str, str]]: A new list where consecutive entries with the same role are merged.
        """
        merged_dialog: List[Dict[str, str]] = []

        for entry in dialog:
            if merged_dialog and entry["role"] == merged_dialog[-1]["role"]:
                merged_dialog[-1]["content"] += " " + entry["content"]
            else:
                # Copy so that the in-place concatenation above never alters
                # the caller's dictionaries.
                merged_dialog.append(dict(entry))

        return merged_dialog

    def transform_dialog_to_standard_format(
        self, dialog: List[Dict[str, str]]
    ) -> List[Dict[str, str]]:
        """Transforms a dialog from OpenAI format to a simplified format.

        Each dictionary of the result contains 'user' and 'system' keys with
        their respective contents. Consecutive entries with the same role are
        merged first, so after validation the merged entries alternate between
        user and assistant, starting with user. A trailing user entry without
        a reply becomes a turn whose 'system' value is an empty string.

        Args:
            dialog (List[Dict[str, str]]): The input dialog in OpenAI format.

        Returns:
            List[Dict[str, str]]: The transformed dialog.

        Raises:
            ValueError: If the dialog is not in a valid OpenAI format.
        """
        self.validate_openai_dialog_format(dialog)
        merged_dialog = self.merge_dialog_entries(dialog)

        # Even positions hold user entries, odd positions the replies; zip
        # stops at the shorter slice, leaving an unanswered last user entry
        # for the branch below.
        result = [
            {"user": user_entry["content"], "system": system_entry["content"]}
            for user_entry, system_entry in zip(
                merged_dialog[::2], merged_dialog[1::2]
            )
        ]
        if len(merged_dialog) % 2 != 0:
            result.append({"user": merged_dialog[-1]["content"], "system": ""})

        return result

    def process_instance_value(
        self, structured_dialog: List[Dict[str, str]], instance: Dict[str, Any]
    ):
        """Convert an OpenAI-format dialog to turns and serialize it.

        Args:
            structured_dialog: The dialog in OpenAI format.
            instance: The instance being processed; passed through to
                ``SerializeDialog.process_instance_value``.

        Returns:
            The serialized dialog text.

        Raises:
            ValueError: If the dialog is not in a valid OpenAI format.
        """
        standard_format_dialog = self.transform_dialog_to_standard_format(
            structured_dialog
        )
        return super().process_instance_value(standard_format_dialog, instance)