Elron commited on
Commit
df3c5b3
1 Parent(s): 9e30502

Upload struct_data_operators.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. struct_data_operators.py +364 -0
struct_data_operators.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This section describes unitxt operators for tabular data.
2
+
3
+ These operators are specialized in handling tabular data.
4
+ Input table format is assumed as:
5
+ {
6
+ "header": ["col1", "col2"],
7
+ "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
8
+ }
9
+
10
+ ------------------------
11
+ """
12
+ import random
13
+ from abc import ABC, abstractmethod
14
+ from copy import deepcopy
15
+ from typing import (
16
+ Any,
17
+ Dict,
18
+ List,
19
+ Optional,
20
+ )
21
+
22
+ from .dict_utils import dict_get
23
+ from .operators import FieldOperator, StreamInstanceOperator
24
+
25
+
26
+ class SerializeTable(ABC, FieldOperator):
27
+ """TableSerializer converts a given table into a flat sequence with special symbols.
28
+
29
+ Output format varies depending on the chosen serializer. This abstract class defines structure of a typical table serializer that any concrete implementation should follow.
30
+ """
31
+
32
+ # main method to serialize a table
33
+ @abstractmethod
34
+ def serialize_table(self, table_content: Dict) -> str:
35
+ pass
36
+
37
+ # method to process table header
38
+ @abstractmethod
39
+ def process_header(self, header: List):
40
+ pass
41
+
42
+ # method to process a table row
43
+ @abstractmethod
44
+ def process_row(self, row: List, row_index: int):
45
+ pass
46
+
47
+
48
+ # Concrete classes implementing table serializers
49
+ class SerializeTableAsIndexedRowMajor(SerializeTable):
50
+ """Indexed Row Major Table Serializer.
51
+
52
+ Commonly used row major serialization format.
53
+ Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
54
+ """
55
+
56
+ def process_value(self, table: Any) -> Any:
57
+ table_input = deepcopy(table)
58
+ return self.serialize_table(table_content=table_input)
59
+
60
+ # main method that processes a table
61
+ # table_content must be in the presribed input format
62
+ def serialize_table(self, table_content: Dict) -> str:
63
+ # Extract headers and rows from the dictionary
64
+ header = table_content.get("header", [])
65
+ rows = table_content.get("rows", [])
66
+
67
+ assert header and rows, "Incorrect input table format"
68
+
69
+ # Process table header first
70
+ serialized_tbl_str = self.process_header(header) + " "
71
+
72
+ # Process rows sequentially starting from row 1
73
+ for i, row in enumerate(rows, start=1):
74
+ serialized_tbl_str += self.process_row(row, row_index=i) + " "
75
+
76
+ # return serialized table as a string
77
+ return serialized_tbl_str.strip()
78
+
79
+ # serialize header into a string containing the list of column names separated by '|' symbol
80
+ def process_header(self, header: List):
81
+ return "col : " + " | ".join(header)
82
+
83
+ # serialize a table row into a string containing the list of cell values separated by '|'
84
+ def process_row(self, row: List, row_index: int):
85
+ serialized_row_str = ""
86
+ row_cell_values = [
87
+ str(value) if isinstance(value, (int, float)) else value for value in row
88
+ ]
89
+
90
+ serialized_row_str += " | ".join(row_cell_values)
91
+
92
+ return f"row {row_index} : {serialized_row_str}"
93
+
94
+
95
+ class SerializeTableAsMarkdown(SerializeTable):
96
+ """Markdown Table Serializer.
97
+
98
+ Markdown table format is used in GitHub code primarily.
99
+ Format:
100
+ |col1|col2|col3|
101
+ |---|---|---|
102
+ |A|4|1|
103
+ |I|2|1|
104
+ ...
105
+ """
106
+
107
+ def process_value(self, table: Any) -> Any:
108
+ table_input = deepcopy(table)
109
+ return self.serialize_table(table_content=table_input)
110
+
111
+ # main method that serializes a table.
112
+ # table_content must be in the presribed input format.
113
+ def serialize_table(self, table_content: Dict) -> str:
114
+ # Extract headers and rows from the dictionary
115
+ header = table_content.get("header", [])
116
+ rows = table_content.get("rows", [])
117
+
118
+ assert header and rows, "Incorrect input table format"
119
+
120
+ # Process table header first
121
+ serialized_tbl_str = self.process_header(header)
122
+
123
+ # Process rows sequentially starting from row 1
124
+ for i, row in enumerate(rows, start=1):
125
+ serialized_tbl_str += self.process_row(row, row_index=i)
126
+
127
+ # return serialized table as a string
128
+ return serialized_tbl_str.strip()
129
+
130
+ # serialize header into a string containing the list of column names
131
+ def process_header(self, header: List):
132
+ header_str = "|{}|\n".format("|".join(header))
133
+ header_str += "|{}|\n".format("|".join(["---"] * len(header)))
134
+ return header_str
135
+
136
+ # serialize a table row into a string containing the list of cell values
137
+ def process_row(self, row: List, row_index: int):
138
+ row_str = ""
139
+ row_str += "|{}|\n".format("|".join(str(cell) for cell in row))
140
+ return row_str
141
+
142
+
143
+ # truncate cell value to maximum allowed length
144
+ def truncate_cell(cell_value, max_len):
145
+ if cell_value is None:
146
+ return None
147
+
148
+ if isinstance(cell_value, int) or isinstance(cell_value, float):
149
+ return None
150
+
151
+ if cell_value.strip() == "":
152
+ return None
153
+
154
+ if len(cell_value) > max_len:
155
+ return cell_value[:max_len]
156
+
157
+ return None
158
+
159
+
160
+ class TruncateTableCells(StreamInstanceOperator):
161
+ """Limit the maximum length of cell values in a table to reduce the overall length.
162
+
163
+ Args:
164
+ max_length (int) - maximum allowed length of cell values
165
+ For tasks that produce a cell value as answer, truncating a cell value should be replicated
166
+ with truncating the corresponding answer as well. This has been addressed in the implementation.
167
+
168
+ """
169
+
170
+ max_length: int = 15
171
+ table: str = None
172
+ text_output: Optional[str] = None
173
+ use_query: bool = False
174
+
175
+ def process(
176
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
177
+ ) -> Dict[str, Any]:
178
+ table = dict_get(instance, self.table, use_dpath=self.use_query)
179
+
180
+ answers = []
181
+ if self.text_output is not None:
182
+ answers = dict_get(instance, self.text_output, use_dpath=self.use_query)
183
+
184
+ self.truncate_table(table_content=table, answers=answers)
185
+
186
+ return instance
187
+
188
+ # truncate table cells
189
+ def truncate_table(self, table_content: Dict, answers: Optional[List]):
190
+ cell_mapping = {}
191
+
192
+ # One row at a time
193
+ for row in table_content.get("rows", []):
194
+ for i, cell in enumerate(row):
195
+ truncated_cell = truncate_cell(cell, self.max_length)
196
+ if truncated_cell is not None:
197
+ cell_mapping[cell] = truncated_cell
198
+ row[i] = truncated_cell
199
+
200
+ # Update values in answer list to truncated values
201
+ if answers is not None:
202
+ for i, case in enumerate(answers):
203
+ answers[i] = cell_mapping.get(case, case)
204
+
205
+
206
+ class TruncateTableRows(FieldOperator):
207
+ """Limits table rows to specified limit by removing excess rows via random selection.
208
+
209
+ Args:
210
+ rows_to_keep (int) - number of rows to keep.
211
+ """
212
+
213
+ rows_to_keep: int = 10
214
+
215
+ def process_value(self, table: Any) -> Any:
216
+ return self.truncate_table_rows(table_content=table)
217
+
218
+ def truncate_table_rows(self, table_content: Dict):
219
+ # Get rows from table
220
+ rows = table_content.get("rows", [])
221
+
222
+ num_rows = len(rows)
223
+
224
+ # if number of rows are anyway lesser, return.
225
+ if num_rows <= self.rows_to_keep:
226
+ return table_content
227
+
228
+ # calculate number of rows to delete, delete them
229
+ rows_to_delete = num_rows - self.rows_to_keep
230
+
231
+ # Randomly select rows to be deleted
232
+ deleted_rows_indices = random.sample(range(len(rows)), rows_to_delete)
233
+
234
+ remaining_rows = [
235
+ row for i, row in enumerate(rows) if i not in deleted_rows_indices
236
+ ]
237
+ table_content["rows"] = remaining_rows
238
+
239
+ return table_content
240
+
241
+
242
+ class SerializeTableRowAsText(StreamInstanceOperator):
243
+ """Serializes a table row as text.
244
+
245
+ Args:
246
+ fields (str) - list of fields to be included in serialization.
247
+ to_field (str) - serialized text field name.
248
+ max_cell_length (int) - limits cell length to be considered, optional.
249
+ """
250
+
251
+ fields: str
252
+ to_field: str
253
+ max_cell_length: Optional[int] = None
254
+
255
+ def process(
256
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
257
+ ) -> Dict[str, Any]:
258
+ linearized_str = ""
259
+ for field in self.fields:
260
+ value = dict_get(instance, field, use_dpath=False)
261
+ if self.max_cell_length is not None:
262
+ truncated_value = truncate_cell(value, self.max_cell_length)
263
+ if truncated_value is not None:
264
+ value = truncated_value
265
+
266
+ linearized_str = linearized_str + field + " is " + str(value) + ", "
267
+
268
+ instance[self.to_field] = linearized_str
269
+ return instance
270
+
271
+
272
+ class SerializeTableRowAsList(StreamInstanceOperator):
273
+ """Serializes a table row as list.
274
+
275
+ Args:
276
+ fields (str) - list of fields to be included in serialization.
277
+ to_field (str) - serialized text field name.
278
+ max_cell_length (int) - limits cell length to be considered, optional.
279
+ """
280
+
281
+ fields: str
282
+ to_field: str
283
+ max_cell_length: Optional[int] = None
284
+
285
+ def process(
286
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
287
+ ) -> Dict[str, Any]:
288
+ linearized_str = ""
289
+ for field in self.fields:
290
+ value = dict_get(instance, field, use_dpath=False)
291
+ if self.max_cell_length is not None:
292
+ truncated_value = truncate_cell(value, self.max_cell_length)
293
+ if truncated_value is not None:
294
+ value = truncated_value
295
+
296
+ linearized_str = linearized_str + field + ": " + str(value) + ", "
297
+
298
+ instance[self.to_field] = linearized_str
299
+ return instance
300
+
301
+
302
+ class SerializeTriples(FieldOperator):
303
+ """Serializes triples into a flat sequence.
304
+
305
+ Sample input in expected format:
306
+ [[ "First Clearing", "LOCATION", "On NYS 52 1 Mi. Youngsville" ], [ "On NYS 52 1 Mi. Youngsville", "CITY_OR_TOWN", "Callicoon, New York"]]
307
+
308
+ Sample output:
309
+ First Clearing : LOCATION : On NYS 52 1 Mi. Youngsville | On NYS 52 1 Mi. Youngsville : CITY_OR_TOWN : Callicoon, New York
310
+
311
+ """
312
+
313
+ def process_value(self, tripleset: Any) -> Any:
314
+ return self.serialize_triples(tripleset)
315
+
316
+ def serialize_triples(self, tripleset) -> str:
317
+ return " | ".join(
318
+ f"{subj} : {rel.lower()} : {obj}" for subj, rel, obj in tripleset
319
+ )
320
+
321
+
322
+ class SerializeKeyValPairs(FieldOperator):
323
+ """Serializes key, value pairs into a flat sequence.
324
+
325
+ Sample input in expected format: {"name": "Alex", "age": 31, "sex": "M"}
326
+ Sample output: name is Alex, age is 31, sex is M
327
+ """
328
+
329
+ def process_value(self, kvpairs: Any) -> Any:
330
+ return self.serialize_kvpairs(kvpairs)
331
+
332
+ def serialize_kvpairs(self, kvpairs) -> str:
333
+ serialized_str = ""
334
+ for key, value in kvpairs.items():
335
+ serialized_str += f"{key} is {value}, "
336
+
337
+ # Remove the trailing comma and space then return
338
+ return serialized_str[:-2]
339
+
340
+
341
+ class ListToKeyValPairs(StreamInstanceOperator):
342
+ """Maps list of keys and values into key:value pairs.
343
+
344
+ Sample input in expected format: {"keys": ["name", "age", "sex"], "values": ["Alex", 31, "M"]}
345
+ Sample output: {"name": "Alex", "age": 31, "sex": "M"}
346
+ """
347
+
348
+ fields: List[str]
349
+ to_field: str
350
+ use_query: bool = False
351
+
352
+ def process(
353
+ self, instance: Dict[str, Any], stream_name: Optional[str] = None
354
+ ) -> Dict[str, Any]:
355
+ keylist = dict_get(instance, self.fields[0], use_dpath=self.use_query)
356
+ valuelist = dict_get(instance, self.fields[1], use_dpath=self.use_query)
357
+
358
+ output_dict = {}
359
+ for key, value in zip(keylist, valuelist):
360
+ output_dict[key] = value
361
+
362
+ instance[self.to_field] = output_dict
363
+
364
+ return instance