Elron committed on
Commit
8d5bd0c
1 Parent(s): b74242e

Upload task.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. task.py +57 -2
task.py CHANGED
@@ -13,10 +13,65 @@ class FormTask(Tasker, StreamInstanceOperator):
13
  metrics: List[str]
14
 
15
  def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
16
- inputs = {key: instance[key] for key in self.inputs}
17
- outputs = {key: instance[key] for key in self.outputs}
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  return {
19
  "inputs": inputs,
20
  "outputs": outputs,
21
  "metrics": self.metrics,
22
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  metrics: List[str]
14
 
15
  def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
16
+ try:
17
+ inputs = {key: instance[key] for key in self.inputs}
18
+ except KeyError as e:
19
+ raise KeyError(
20
+ f"Unexpected input column names: {list(key for key in self.inputs if key not in instance)}"
21
+ f"\n available names:{list(instance.keys())}\n given input names:{self.inputs}"
22
+ )
23
+ try:
24
+ outputs = {key: instance[key] for key in self.outputs}
25
+ except KeyError as e:
26
+ raise KeyError(
27
+ f"Unexpected output column names: {list(key for key in self.inputs if key not in instance)}"
28
+ f" \n available names:{list(instance.keys())}\n given output names:{self.outputs}"
29
+ )
30
+
31
  return {
32
  "inputs": inputs,
33
  "outputs": outputs,
34
  "metrics": self.metrics,
35
  }
36
+
37
+
38
class MultipleChoiceTask(FormTask):
    """Form task that renders a list of choices as an enumerated block.

    The input field named by ``choices_field`` is replaced with a single
    string in which every choice is prefixed by a letter from ``alphabet``.
    The first output field is replaced with the letter of the matching
    choice, optionally followed by the choice text.
    """

    # Name of the input field holding the list of choices.
    choices_field: str = "choices"
    # String placed between consecutive rendered choices.
    choices_separator: str = "\n"
    # String placed between a choice's letter and its text.
    enumeration_suffix: str = ". "
    # Whether the rendered target includes the choice text after its letter.
    use_text_in_target: bool = False
    # Labels assigned to choices by position; its length caps the number of choices.
    alphabet: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def process_single_choice(self, choice: str, index: int, use_text: bool = True) -> str:
        """Render one choice as its letter, optionally followed by its text.

        Args:
            choice: Text of the choice.
            index: Position of the choice; selects the letter from ``alphabet``.
            use_text: If true, append ``enumeration_suffix`` and ``choice``
                after the letter.

        Returns:
            The rendered choice string.

        Raises:
            ValueError: If ``index`` is beyond the end of ``alphabet``.
        """
        try:
            processed_choice = f"{self.alphabet[index]}"
        except IndexError as e:
            raise ValueError(
                f"Too many choices, the length of alphabet '{self.alphabet}': {len(self.alphabet)} is the limit"
            ) from e
        if use_text:
            processed_choice += f"{self.enumeration_suffix}{choice}"
        return processed_choice

    def process_choices(self, choices: List[str]) -> str:
        """Render all choices with their text and join them with ``choices_separator``."""
        return self.choices_separator.join(
            self.process_single_choice(choice, index) for index, choice in enumerate(choices)
        )

    def process_target(self, choices, target_index):
        """Render the choice at ``target_index``, with text only if ``use_text_in_target`` is set."""
        return self.process_single_choice(choices[target_index], target_index, use_text=self.use_text_in_target)

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """Run the parent ``process`` and rewrite the choices and target fields.

        Args:
            instance: Instance passed through to the parent ``process``.
            stream_name: Stream name passed through to the parent ``process``.

        Returns:
            The parent's result dict with the choices input replaced by the
            rendered choices string and the first output replaced by the
            rendered target.

        Raises:
            ValueError: If the target value is not one of the choices, or if
                there are more choices than letters in ``alphabet``.
        """
        result = super().process(instance, stream_name)
        # Only the first output field is treated as the target.
        target_key, target_value = next(iter(result["outputs"].items()))
        choices = result["inputs"][self.choices_field]
        try:
            target_index_in_choices = choices.index(target_value)
        except ValueError as e:
            # Same exception type as before, but say which value was missing.
            raise ValueError(
                f"Target value {target_value!r} of output field '{target_key}' "
                f"is not one of the choices: {choices}"
            ) from e

        processed_choices = self.process_choices(choices)
        processed_target = self.process_target(choices, target_index_in_choices)

        result["inputs"][self.choices_field] = processed_choices
        result["outputs"][target_key] = processed_target

        return result