Elron committed on
Commit
59be457
1 Parent(s): 88f4dd2

Upload folder using huggingface_hub

Files changed (19)
  1. _version.py +0 -1
  2. app.py +0 -3
  3. common.py +0 -104
  4. dataset.py +3 -7
  5. hf_utils.py +2 -2
  6. inference.py +124 -2
  7. load.py +0 -15
  8. metric.py +1 -6
  9. metrics.py +4 -1
  10. operators.py +60 -0
  11. processors.py +8 -0
  12. renderers.py +0 -132
  13. serializers.py +0 -130
  14. standard.py +6 -3
  15. task.py +68 -18
  16. templates.py +1 -1
  17. tests.py +0 -17
  18. type_utils.py +31 -0
  19. version.py +1 -1
_version.py DELETED
@@ -1 +0,0 @@
- def get_current_version(): return '1.0.31'

app.py DELETED
@@ -1,3 +0,0 @@
- from unitxt.ui import launch
-
- launch()

common.py DELETED
@@ -1,104 +0,0 @@
- from typing import Union
-
- from .card import TaskCard
- from .collections import ItemPicker, RandomPicker
- from .dataclass import OptionalField
- from .operator import SourceOperator
- from .recipe import Recipe, SequentialRecipe
- from .schema import ToUnitxtGroup
- from .splitters import RandomSampler, Sampler, SeparateSplit, SliceSplit, SpreadSplit
- from .stream import MultiStream
- from .templates import RenderTemplatedICL
-
-
- class CommonRecipe(Recipe, SourceOperator):
-     card: TaskCard
-     demos_pool_name: str = "demos_pool"
-     demos_taken_from: str = "train"
-     demos_pool_size: int = None
-     demos_field: str = "demos"
-     num_demos: int = None
-     sampler: Sampler = None
-     instruction_item: Union[str, int] = None
-     template_item: Union[str, int] = None
-     system_prompt: str = None
-
-     def verify(self):
-         super().verify()
-
-     def prepare(self):
-         steps = [
-             self.card.loader,
-         ]
-
-         if self.card.preprocess_steps is not None:
-             steps.extend(self.card.preprocess_steps)
-
-         steps.append(self.card.task)
-
-         if self.demos_pool_size is not None:
-             steps.append(
-                 SeparateSplit(
-                     from_split=self.demos_taken_from,
-                     to_split_names=[self.demos_pool_name, self.demos_taken_from],
-                     to_split_sizes=[int(self.demos_pool_size)],
-                 )
-             )
-
-         if self.num_demos is not None:
-             sampler = self.card.sampler
-
-             if self.sampler is not None:
-                 sampler = self.sampler
-
-             sampler.set_size(self.num_demos)
-
-             steps.append(
-                 SpreadSplit(
-                     source_stream=self.demos_pool_name,
-                     target_field=self.demos_field,
-                     sampler=sampler,
-                 )
-             )
-
-         if self.card.instructions is not None:
-             if not self.instruction_item is None:
-                 picker = ItemPicker(int(self.instruction_item))
-             else:
-                 picker = RandomPicker()
-             instruction = picker(self.card.instructions)
-         else:
-             instruction = None
-
-         if self.card.templates is not None:
-             if self.template_item is None:
-                 picker = RandomPicker()
-             else:
-                 picker = ItemPicker(self.template_item)
-             template = picker(self.card.templates)
-         else:
-             template = None
-
-         render = RenderTemplatedICL(
-             instruction=instruction,
-             template=template,
-             demos_field=self.demos_field,
-             system_prompt=self.system_prompt,
-         )
-
-         steps.append(render)
-
-         postprocessors = render.get_postprocessors()
-
-         steps.append(
-             ToUnitxtGroup(
-                 group="unitxt",
-                 metrics=self.card.task.metrics,
-                 postprocessors=postprocessors,
-             )
-         )
-
-         self.recipe = SequentialRecipe(steps)
-
-     def process(self) -> MultiStream:
-         return self.recipe()

dataset.py CHANGED
@@ -10,7 +10,6 @@ from .catalog import __file__ as _
  from .collections import __file__ as _
  from .collections_operators import __file__ as _
  from .dataclass import __file__ as _
- from .dataset_utils import __file__ as _
  from .dataset_utils import get_dataset_artifact
  from .deprecation_utils import __file__ as _
  from .dialog_operators import __file__ as _
@@ -20,13 +19,11 @@ from .file_utils import __file__ as _
  from .formats import __file__ as _
  from .fusion import __file__ as _
  from .generator_utils import __file__ as _
- from .hf_utils import __file__ as _
  from .hf_utils import verify_versions_compatibility
  from .inference import __file__ as _
  from .instructions import __file__ as _
  from .llm_as_judge import __file__ as _
  from .loaders import __file__ as _
- from .logging_utils import __file__ as _
  from .logging_utils import get_logger
  from .metric import __file__ as _
  from .metric_utils import __file__ as _
@@ -40,7 +37,6 @@ from .random_utils import __file__ as _
  from .recipe import __file__ as _
  from .register import __file__ as _
  from .schema import __file__ as _
- from .settings_utils import __file__ as _
  from .settings_utils import get_constants
  from .span_lableing_operators import __file__ as _
  from .split_utils import __file__ as _
@@ -54,7 +50,6 @@ from .task import __file__ as _
  from .templates import __file__ as _
  from .text_utils import __file__ as _
  from .type_utils import __file__ as _
- from .utils import __file__ as _
  from .utils import is_package_installed
  from .validate import __file__ as _
  from .version import __file__ as _
@@ -75,8 +70,9 @@ class Dataset(datasets.GeneratorBasedBuilder):
          if is_package_installed("unitxt"):
              verify_versions_compatibility("dataset", self.VERSION)
 
-             from unitxt.dataset_utils import \
-                 get_dataset_artifact as get_dataset_artifact_installed
+             from unitxt.dataset_utils import (
+                 get_dataset_artifact as get_dataset_artifact_installed,
+             )
 
              logger.info("Loading with installed unitxt library...")
              dataset = get_dataset_artifact_installed(self.config.name)

hf_utils.py CHANGED
@@ -24,9 +24,9 @@ class UnitxtVersionsConflictError(ValueError):
      def __init__(self, error_in: str, hf_unitxt_version, installed_unitxt_version):
          assert hf_unitxt_version != installed_unitxt_version
          if compare_versions(hf_unitxt_version, installed_unitxt_version) == 1:
-             msg = f"Located installed unitxt version {installed_unitxt_version} that is older then unitxt {error_in} version {hf_unitxt_version}. Please update unitxt package or uninstall it to avoid conflicts."
+             msg = f"Located installed unitxt version {installed_unitxt_version} that is older than unitxt {error_in} version {hf_unitxt_version}. Please update unitxt package or uninstall it to avoid conflicts."
          if compare_versions(hf_unitxt_version, installed_unitxt_version) == -1:
-             msg = f"Located installed unitxt version {installed_unitxt_version} that is newer then unitxt {error_in} version {hf_unitxt_version}. Please force-reload the {error_in} or downgrade unitxt to {error_in} version or uninstall unitxt to avoid conflicts."
+             msg = f"Located installed unitxt version {installed_unitxt_version} that is newer than unitxt {error_in} version {hf_unitxt_version}. Please force-reload the {error_in} or downgrade unitxt to {error_in} version or uninstall unitxt to avoid conflicts."
          super().__init__(msg)

inference.py CHANGED
@@ -1,6 +1,11 @@
  import abc
+ import os
+ from dataclasses import dataclass
+ from typing import List, Optional, Union
 
  from .artifact import Artifact
+ from .operator import PackageRequirementsMixin
+ from .settings_utils import get_settings
 
 
  class InferenceEngine(abc.ABC, Artifact):
@@ -11,12 +16,21 @@ class InferenceEngine(abc.ABC, Artifact):
          """Perform inference on the input dataset."""
          pass
 
+     @staticmethod
+     def _assert_allow_passing_data_to_remote_api(remote_api_label: str):
+         assert get_settings().allow_passing_data_to_remote_api, (
+             f"LlmAsJudge metric cannot run send data to remote APIs ({remote_api_label}) when"
+             f" unitxt.settings.allow_passing_data_to_remote_api=False."
+             f" Set UNITXT_ALLOW_PASSING_DATA_TO_REMOTE_API environment variable, if you want to allow this. "
+         )
 
 
- class HFPipelineBasedInferenceEngine(Artifact):
-     """Abstract base class for inference."""
-
+ class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
      model_name: str
      max_new_tokens: int
+     _requirement = {
+         "transformers": "Install huggingface package using 'pip install --upgrade transformers"
+     }
 
      def prepare(self):
          from transformers import pipeline
@@ -31,3 +45,111 @@ class HFPipelineBasedInferenceEngine(Artifact):
                  max_new_tokens=self.max_new_tokens,
              )
          ]
+
+
+ @dataclass()
+ class IbmGenAiInferenceEngineParams:
+     decoding_method: str = None
+     max_new_tokens: Optional[int] = None
+     min_new_tokens: Optional[int] = None
+     random_seed: Optional[int] = None
+     repetition_penalty: Optional[float] = None
+     stop_sequences: Optional[List[str]] = None
+     temperature: Optional[float] = None
+     top_k: Optional[int] = None
+     top_p: Optional[float] = None
+     typical_p: Optional[float] = None
+
+
+ class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
+     label: str = "ibm_genai"
+     model_name: str
+     parameters: IbmGenAiInferenceEngineParams = IbmGenAiInferenceEngineParams()
+     _requirement = {
+         "genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
+     }
+
+     def prepare(self):
+         from genai import Client, Credentials
+
+         api_key_env_var_name = "GENAI_KEY"
+         api_key = os.environ.get(api_key_env_var_name)
+         assert api_key is not None, (
+             f"Error while trying to run IbmGenAiInferenceEngine."
+             f" Please set the environment param '{api_key_env_var_name}'."
+         )
+         api_endpoint = os.environ.get("GENAI_KEY")
+         credentials = Credentials(api_key=api_key, api_endpoint=api_endpoint)
+         self.client = Client(credentials=credentials)
+
+         self._assert_allow_passing_data_to_remote_api(self.label)
+
+     def infer(self, dataset):
+         from genai.schema import TextGenerationParameters
+
+         genai_params = TextGenerationParameters(**self.parameters.__dict__)
+         return list(
+             self.client.text.generation.create(
+                 model_id=self.model_name,
+                 inputs=[instance["source"] for instance in dataset],
+                 parameters=genai_params,
+             )
+         )
+
+
+ @dataclass
+ class OpenAiInferenceEngineParams:
+     frequency_penalty: Optional[float] = None
+     presence_penalty: Optional[float] = None
+     max_tokens: Optional[int] = None
+     seed: Optional[int] = None
+     stop: Union[Optional[str], List[str]] = None
+     temperature: Optional[float] = None
+     top_p: Optional[float] = None
+
+
+ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
+     label: str = "openai"
+     model_name: str
+     parameters: OpenAiInferenceEngineParams = OpenAiInferenceEngineParams()
+     _requirement = {
+         "openai": "Install openai package using 'pip install --upgrade openai"
+     }
+
+     def prepare(self):
+         from openai import OpenAI
+
+         api_key_env_var_name = "OPENAI_API_KEY"
+         api_key = os.environ.get(api_key_env_var_name)
+         assert api_key is not None, (
+             f"Error while trying to run OpenAiInferenceEngine."
+             f" Please set the environment param '{api_key_env_var_name}'."
+         )
+
+         self.client = OpenAI(api_key=api_key)
+         self._assert_allow_passing_data_to_remote_api(self.label)
+
+     def infer(self, dataset):
+         return [
+             self.client.chat.completions.create(
+                 messages=[
+                     # {
+                     #     "role": "system",
+                     #     "content": self.system_prompt,
+                     # },
+                     {
+                         "role": "user",
+                         "content": instance["source"],
+                     }
+                 ],
+                 model=self.model_name,
+                 frequency_penalty=self.parameters.frequency_penalty,
+                 presence_penalty=self.parameters.presence_penalty,
+                 max_tokens=self.parameters.max_tokens,
+                 seed=self.parameters.seed,
+                 stop=self.parameters.stop,
+                 temperature=self.parameters.temperature,
+                 top_p=self.parameters.top_p,
+             )
+             for instance in dataset
+         ]
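
For context, a minimal sketch of how the new engines are meant to be driven, assuming OPENAI_API_KEY is set, unitxt settings permit remote APIs, and a dataset of instances with a "source" field (the model name and instance are illustrative, not part of the commit):

```python
# Sketch only: OpenAiInferenceEngineParams and OpenAiInferenceEngine are the
# classes added in this diff; prepare() is assumed to run via the Artifact
# lifecycle before infer() is called.
engine = OpenAiInferenceEngine(
    model_name="gpt-3.5-turbo",  # illustrative model id
    parameters=OpenAiInferenceEngineParams(max_tokens=256, temperature=0.0),
)
dataset = [{"source": "Translate to French: Hello"}]  # illustrative instance
completions = engine.infer(dataset)
```
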
load.py DELETED
@@ -1,15 +0,0 @@
- from typing import Union
-
- from datasets import DatasetDict
-
- from .artifact import fetch_artifact
- from .operator import StreamSource
-
-
- def load_dataset(source: Union[StreamSource, str]) -> DatasetDict:
-     assert isinstance(
-         source, (StreamSource, str)
-     ), "source must be a StreamSource or a string"
-     if isinstance(source, str):
-         source, _ = fetch_artifact(source)
-     return source().to_dataset()

metric.py CHANGED
@@ -19,16 +19,13 @@ from .file_utils import __file__ as _
  from .formats import __file__ as _
  from .fusion import __file__ as _
  from .generator_utils import __file__ as _
- from .hf_utils import __file__ as _
  from .hf_utils import verify_versions_compatibility
  from .inference import __file__ as _
  from .instructions import __file__ as _
  from .llm_as_judge import __file__ as _
  from .loaders import __file__ as _
  from .logging_utils import __file__ as _
- from .metric_utils import UNITXT_METRIC_SCHEMA
- from .metric_utils import __file__ as _
- from .metric_utils import _compute
+ from .metric_utils import UNITXT_METRIC_SCHEMA, _compute
  from .metrics import __file__ as _
  from .normalizers import __file__ as _
  from .operator import __file__ as _
@@ -39,7 +36,6 @@ from .random_utils import __file__ as _
  from .recipe import __file__ as _
  from .register import __file__ as _
  from .schema import __file__ as _
- from .settings_utils import __file__ as _
  from .settings_utils import get_constants
  from .span_lableing_operators import __file__ as _
  from .split_utils import __file__ as _
@@ -53,7 +49,6 @@ from .task import __file__ as _
  from .templates import __file__ as _
  from .text_utils import __file__ as _
  from .type_utils import __file__ as _
- from .utils import __file__ as _
  from .utils import is_package_installed
  from .validate import __file__ as _
  from .version import __file__ as _
metrics.py CHANGED
@@ -2255,6 +2255,8 @@ class Perplexity(BulkInstanceMetric):
              self.model_class().from_pretrained(self.model_name).to(self.device)
          )
          self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+         if self.tokenizer.pad_token_id is None:
+             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
          self.single_token_mode = single_token_mode
 
      def compute_lm(
@@ -3348,7 +3350,8 @@ class BinaryMaxF1(F1Binary):
 
          best_thr = -1
          best_f1 = -1
-         for thr in set(float_predictions):
+         thrs = {round(fp, 3) for fp in float_predictions}
+         for thr in thrs:
              new_predictions = [
                  "1" if float_prediction >= thr else "0"
                  for float_prediction in float_predictions
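
The BinaryMaxF1 change collapses near-duplicate candidate thresholds by rounding to three decimals before scanning for the best F1. A standalone sketch of the effect:

```python
# Four unique floats, but two are indistinguishable at 3-decimal precision,
# so rounding leaves only three thresholds to scan.
float_predictions = [0.50001, 0.50002, 0.9, 0.1]
thrs = {round(fp, 3) for fp in float_predictions}
assert thrs == {0.5, 0.9, 0.1}
```
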
operators.py CHANGED
@@ -1704,6 +1704,66 @@ class Shuffle(PagedStreamOperator):
          yield from page
 
 
+ class FeatureGroupedShuffle(Shuffle):
+     """Class for shuffling an input dataset by instance 'blocks', not on the individual instance level.
+
+     Example is if the dataset consists of questions with paraphrases of it, and each question falls into a topic.
+     All paraphrases have the same ID value as the original.
+     In this case, we may want to shuffle on grouping_features = ['question ID'],
+     to keep the paraphrases and original question together.
+     We may also want to group by both 'question ID' and 'topic', if the question IDs are repeated between topics.
+     In this case, grouping_features = ['question ID', 'topic']
+
+     Args:
+         grouping_features (list of strings): list of feature names to use to define the groups.
+             a group is defined by each unique observed combination of data values for features in grouping_features
+         shuffle_within_group (bool): whether to further shuffle the instances within each group block, keeping the block order
+
+     Args (of superclass):
+         page_size (int): The size of each page in the stream. Defaults to 1000.
+             Note: shuffle_by_grouping_features determines the unique groups (unique combinations of values of grouping_features)
+             separately by page (determined by page_size). If a block of instances in the same group are split
+             into separate pages (either by a page break falling in the group, or the dataset was not sorted by
+             grouping_features), these instances will be shuffled separately and thus the grouping may be
+             broken up by pages. If the user wants to ensure the shuffle does the grouping and shuffling
+             across all pages, set the page_size to be larger than the dataset size.
+             See outputs_2features_bigpage and outputs_2features_smallpage in test_grouped_shuffle.
+     """
+
+     grouping_features: List[str] = None
+     shuffle_within_group: bool = False
+
+     def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
+         if self.grouping_features is None:
+             super().process(page, stream_name)
+         else:
+             yield from self.shuffle_by_grouping_features(page)
+
+     def shuffle_by_grouping_features(self, page):
+         import itertools
+         from collections import defaultdict
+
+         groups_to_instances = defaultdict(list)
+         for item in page:
+             groups_to_instances[
+                 tuple(item[ff] for ff in self.grouping_features)
+             ].append(item)
+         # now extract the groups (i.e., lists of dicts with order preserved)
+         page_blocks = list(groups_to_instances.values())
+         # and now shuffle the blocks
+         self.random_generator.shuffle(page_blocks)
+         if self.shuffle_within_group:
+             blocks = []
+             # reshuffle the instances within each block, but keep the blocks in order
+             for block in page_blocks:
+                 self.random_generator.shuffle(block)
+                 blocks.append(block)
+             page_blocks = blocks
+
+         # now flatten the list so it consists of individual dicts, but in (randomized) block order
+         return list(itertools.chain(*page_blocks))
+
+
  class EncodeLabels(StreamInstanceOperator):
      """Encode each value encountered in any field in 'fields' into the integers 0,1,...
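
To see what the grouped shuffle buys, here is a self-contained sketch that mirrors the logic of shuffle_by_grouping_features (the instances and field names are illustrative):

```python
import itertools
import random
from collections import defaultdict

page = [
    {"question ID": 1, "text": "original q1"},
    {"question ID": 1, "text": "paraphrase of q1"},
    {"question ID": 2, "text": "original q2"},
]

# Group instances by their 'question ID', shuffle the blocks, then flatten.
groups = defaultdict(list)
for item in page:
    groups[(item["question ID"],)].append(item)
blocks = list(groups.values())
random.Random(42).shuffle(blocks)
shuffled = list(itertools.chain(*blocks))
# However the blocks land, the paraphrase stays adjacent to its original.
```
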
 
processors.py CHANGED
@@ -46,6 +46,14 @@ class RegexParser(FieldOperator):
          return re.findall(self.regex, text)
 
 
+ class ExtractWithRegex(RegexParser):
+     def process_value(self, text: Any) -> Any:
+         matches = super().process_value(text)
+         if matches:
+             return matches[0]
+         return ""
+
+
  class LoadJson(FieldOperator):
      def process_value(self, text: Any) -> Any:
          try:
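
ExtractWithRegex narrows RegexParser from "all matches" to "first match, or empty string". A standalone illustration of the behavior it adds:

```python
import re

def extract_with_regex(regex: str, text: str) -> str:
    # Same logic as the new operator: findall, then keep only the first hit.
    matches = re.findall(regex, text)
    return matches[0] if matches else ""

assert extract_with_regex(r"\d+", "score: 42 out of 100") == "42"
assert extract_with_regex(r"\d+", "no digits here") == ""
```
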
renderers.py DELETED
@@ -1,132 +0,0 @@
- from abc import ABC
- from typing import Any, Dict, List, Optional
-
- from .dataclass import InternalField
- from .formats import Format, ICLFormat
- from .instructions import Instruction
- from .operator import Operator, SequentialOperator, StreamInstanceOperator
- from .random_utils import get_random
- from .templates import Template
-
-
- class Renderer(ABC):
-     pass
-     # @abstractmethod
-     # def get_postprocessors(self) -> List[str]:
-     #     pass
-
-
- class RenderTemplate(Renderer, StreamInstanceOperator):
-     template: Template
-     random_reference: bool = False
-     skip_rendered_instance: bool = True
-
-     def process(
-         self, instance: Dict[str, Any], stream_name: Optional[str] = None
-     ) -> Dict[str, Any]:
-         if self.skip_rendered_instance:
-             if (
-                 "inputs" not in instance
-                 and "outputs" not in instance
-                 and "source" in instance
-                 and "target" in instance
-                 and "references" in instance
-             ):
-                 return instance
-
-         inputs = instance["inputs"]
-         outputs = instance["outputs"]
-
-         source = self.template.process_inputs(inputs)
-         targets = self.template.process_outputs(outputs)
-
-         if self.template.is_multi_reference:
-             assert isinstance(targets, list), f"{targets} must be a list"
-             references = targets
-             if self.random_reference:
-                 target = get_random().choice(references)
-             else:
-                 if len(references) == 0:
-                     raise ValueError("No references found")
-                 target = references[0]
-         else:
-             references = [targets]
-             target = targets
-
-         instance.update(
-             {
-                 "source": source,
-                 "target": target,
-                 "references": references,
-             }
-         )
-
-         return instance
-
-
- class RenderDemonstrations(RenderTemplate):
-     demos_field: str
-
-     def process(
-         self, instance: Dict[str, Any], stream_name: Optional[str] = None
-     ) -> Dict[str, Any]:
-         demos = instance.get(self.demos_field, [])
-
-         processed_demos = []
-         for demo_instance in demos:
-             demo_instance = super().process(demo_instance)
-             processed_demos.append(demo_instance)
-
-         instance[self.demos_field] = processed_demos
-
-         return instance
-
-
- class RenderInstruction(Renderer, StreamInstanceOperator):
-     instruction: Instruction
-
-     def process(
-         self, instance: Dict[str, Any], stream_name: Optional[str] = None
-     ) -> Dict[str, Any]:
-         if self.instruction is not None:
-             instance["instruction"] = self.instruction()
-         else:
-             instance["instruction"] = ""
-         return instance
-
-
- class RenderFormat(Renderer, StreamInstanceOperator):
-     format: Format
-     demos_field: str = None
-
-     def process(
-         self, instance: Dict[str, Any], stream_name: Optional[str] = None
-     ) -> Dict[str, Any]:
-         demos_instances = instance.pop(self.demos_field, None)
-         if demos_instances is not None:
-             instance["source"] = self.format.format(
-                 instance, demos_instances=demos_instances
-             )
-         else:
-             instance["source"] = self.format.format(instance)
-         return instance
-
-
- class StandardRenderer(Renderer, SequentialOperator):
-     template: Template
-     instruction: Instruction = None
-     demos_field: str = None
-     format: ICLFormat = None
-
-     steps: List[Operator] = InternalField(default_factory=list)
-
-     def prepare(self):
-         self.steps = [
-             RenderTemplate(template=self.template),
-             RenderDemonstrations(template=self.template, demos_field=self.demos_field),
-             RenderInstruction(instruction=self.instruction),
-             RenderFormat(format=self.format, demos_field=self.demos_field),
-         ]
-
-     def get_postprocessors(self):
-         return self.template.get_postprocessors()

serializers.py DELETED
@@ -1,130 +0,0 @@
- from abc import ABC, abstractmethod
- from copy import deepcopy
- from typing import (
-     Any,
-     Dict,
-     List,
- )
-
- from .operators import FieldOperator
-
- """
- TableSerializer converts a given table into a flat sequence with special symbols.
- Input table format must be:
- {"header": ["col1", "col2"], "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]}
- Output format varies depending on the chosen serializer. Abstract class at the top defines structure of a typical table serializer that any concrete implementation should follow.
- """
-
-
- class TableSerializer(ABC, FieldOperator):
-     # main method to serialize a table
-     @abstractmethod
-     def serialize_table(self, table_content: Dict) -> str:
-         pass
-
-     # method to process table header
-     @abstractmethod
-     def process_header(self, header: List):
-         pass
-
-     # method to process a table row
-     @abstractmethod
-     def process_row(self, row: List, row_index: int):
-         pass
-
-
- # Concrete classes implementing table serializers follow..
- """
- Indexed Row Major Table Serializer.
- Commonly used row major serialization format.
- Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
- """
-
-
- class IndexedRowMajorTableSerializer(TableSerializer):
-     def process_value(self, table: Any) -> Any:
-         table_input = deepcopy(table)
-         return self.serialize_table(table_content=table_input)
-
-     # main method that processes a table
-     # table_content must be in the prescribed input format
-     def serialize_table(self, table_content: Dict) -> str:
-         # Extract headers and rows from the dictionary
-         header = table_content.get("header", [])
-         rows = table_content.get("rows", [])
-
-         assert header and rows, "Incorrect input table format"
-
-         # Process table header first
-         serialized_tbl_str = self.process_header(header) + " "
-
-         # Process rows sequentially starting from row 1
-         for i, row in enumerate(rows, start=1):
-             serialized_tbl_str += self.process_row(row, row_index=i) + " "
-
-         # return serialized table as a string
-         return serialized_tbl_str.strip()
-
-     # serialize header into a string containing the list of column names separated by '|' symbol
-     def process_header(self, header: List):
-         return "col : " + " | ".join(header)
-
-     # serialize a table row into a string containing the list of cell values separated by '|'
-     def process_row(self, row: List, row_index: int):
-         serialized_row_str = ""
-         row_cell_values = [
-             str(value) if isinstance(value, (int, float)) else value for value in row
-         ]
-
-         serialized_row_str += " | ".join(row_cell_values)
-
-         return f"row {row_index} : {serialized_row_str}"
-
-
- """
- Markdown Table Serializer.
- Markdown table format is used in GitHub code primarily.
- Format:
- |col1|col2|col3|
- |---|---|---|
- |A|4|1|
- |I|2|1|
- ...
- """
-
-
- class MarkdownTableSerializer(TableSerializer):
-     def process_value(self, table: Any) -> Any:
-         table_input = deepcopy(table)
-         return self.serialize_table(table_content=table_input)
-
-     # main method that serializes a table.
-     # table_content must be in the prescribed input format.
-     def serialize_table(self, table_content: Dict) -> str:
-         # Extract headers and rows from the dictionary
-         header = table_content.get("header", [])
-         rows = table_content.get("rows", [])
-
-         assert header and rows, "Incorrect input table format"
-
-         # Process table header first
-         serialized_tbl_str = self.process_header(header)
-
-         # Process rows sequentially starting from row 1
-         for i, row in enumerate(rows, start=1):
-             serialized_tbl_str += self.process_row(row, row_index=i)
-
-         # return serialized table as a string
-         return serialized_tbl_str.strip()
-
-     # serialize header into a string containing the list of column names
-     def process_header(self, header: List):
-         header_str = "|{}|\n".format("|".join(header))
-         header_str += "|{}|\n".format("|".join(["---"] * len(header)))
-         return header_str
-
-     # serialize a table row into a string containing the list of cell values
-     def process_row(self, row: List, row_index: int):
-         row_str = ""
-         row_str += "|{}|\n".format("|".join(str(cell) for cell in row))
-         return row_str

standard.py CHANGED
@@ -187,6 +187,11 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
          return list(multi_stream["__inference__"])
 
      def prepare(self):
+         # To avoid the Python's mutable default list trap, we set the default value to None
+         # and then set it to an empty list if it is None.
+         if self.card.preprocess_steps is None:
+             self.card.preprocess_steps = []
+
          self.set_pipelines()
 
          loader = self.card.loader
@@ -220,7 +225,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
          self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
          self.processing.steps.append(self.augmentor)
 
-         if self.demos_pool_size is not None:
+         if self.num_demos > 0:
              self.processing.steps.append(
                  CreateDemosPool(
                      from_split=self.demos_taken_from,
@@ -229,8 +234,6 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
                      remove_targets_from_source_split=self.demos_removed_from_data,
                  )
              )
-
-         if self.num_demos > 0:
              if self.sampler is None:
                  if self.card.sampler is None:
                      raise ValueError(
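
The comment added in prepare() refers to the classic Python pitfall where a mutable default argument is shared across calls; a minimal illustration of why the None-then-assign pattern is used:

```python
def bad(steps=[]):  # one shared list, created once at definition time
    steps.append("x")
    return steps

bad()
assert bad() == ["x", "x"]  # state leaks between calls

def good(steps=None):  # the pattern the commit adopts
    steps = [] if steps is None else steps
    steps.append("x")
    return steps

assert good() == ["x"]  # fresh list on every call
```
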
task.py CHANGED
@@ -1,6 +1,9 @@
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Union
 
+ from .artifact import fetch_artifact
+ from .logging_utils import get_logger
  from .operator import StreamInstanceOperator
+ from .type_utils import isoftype, parse_type_string, verify_required_schema
 
 
  class Tasker:
@@ -10,41 +13,88 @@ class Tasker:
  class FormTask(Tasker, StreamInstanceOperator):
      """FormTask packs the different instance fields into dictionaries by their roles in the task.
 
+     Attributes:
+         inputs (Union[Dict[str, str], List[str]]):
+             Dictionary with string names of instance input fields and types of respective values.
+             In case a list is passed, each type will be assumed to be Any.
+         outputs (Union[Dict[str, str], List[str]]):
+             Dictionary with string names of instance output fields and types of respective values.
+             In case a list is passed, each type will be assumed to be Any.
+         metrics (List[str]): List of names of metrics to be used in the task.
+         prediction_type (Optional[str]):
+             Need to be consistent with all used metrics. Defaults to None, which means that it will
+             be set to Any.
+
      The output instance contains three fields:
          "inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
          "outputs" -- for the fields listed in Arg "outputs".
          "metrics" -- to contain the value of Arg 'metrics'
-
      """
 
-     inputs: List[str]
-     outputs: List[str]
+     inputs: Union[Dict[str, str], List[str]]
+     outputs: Union[Dict[str, str], List[str]]
      metrics: List[str]
+     prediction_type: Optional[str] = None
      augmentable_inputs: List[str] = []
 
      def verify(self):
+         for io_type in ["inputs", "outputs"]:
+             data = self.inputs if io_type == "inputs" else self.outputs
+             if not isoftype(data, Dict[str, str]):
+                 get_logger().warning(
+                     f"'{io_type}' field of Task should be a dictionary of field names and their types. "
+                     f"For example, {{'text': 'str', 'classes': 'List[str]'}}. Instead only '{data}' was "
+                     f"passed. All types will be assumed to be 'Any'. In future version of unitxt this "
+                     f"will raise an exception."
+                 )
+                 data = {key: "Any" for key in data}
+                 if io_type == "inputs":
+                     self.inputs = data
+                 else:
+                     self.outputs = data
+
+         if not self.prediction_type:
+             get_logger().warning(
+                 "'prediction_type' was not set in Task. It is used to check the output of "
+                 "template post processors is compatible with the expected input of the metrics. "
+                 "Setting `prediction_type` to 'Any' (no checking is done). In future version "
+                 "of unitxt this will raise an exception."
+             )
+             self.prediction_type = "Any"
+
+         self.check_metrics_type()
+
          for augmentable_input in self.augmentable_inputs:
             assert (
                 augmentable_input in self.inputs
             ), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
 
+     def check_metrics_type(self) -> None:
+         prediction_type = parse_type_string(self.prediction_type)
+         for metric_name in self.metrics:
+             metric = fetch_artifact(metric_name)[0]
+             metric_prediction_type = metric.get_prediction_type()
+
+             if (
+                 prediction_type == metric_prediction_type
+                 or prediction_type == Any
+                 or metric_prediction_type == Any
+             ):
+                 continue
+
+             raise ValueError(
+                 f"The task's prediction type ({prediction_type}) and '{metric_name}' "
+                 f"metric's prediction type ({metric_prediction_type}) are different."
+             )
+
      def process(
          self, instance: Dict[str, Any], stream_name: Optional[str] = None
      ) -> Dict[str, Any]:
-         try:
-             inputs = {key: instance[key] for key in self.inputs}
-         except KeyError as e:
-             raise KeyError(
-                 f"Unexpected FormTask input column names ({[key for key in self.inputs if key not in instance]})."
-                 f"The available input names: {list(instance.keys())}"
-             ) from e
-         try:
-             outputs = {key: instance[key] for key in self.outputs}
-         except KeyError as e:
-             raise KeyError(
-                 f"Unexpected FormTask output column names: {[key for key in self.outputs if key not in instance]}"
-                 f" \n available names:{list(instance.keys())}\n given output names:{self.outputs}"
-             ) from e
+         verify_required_schema(self.inputs, instance)
+         verify_required_schema(self.outputs, instance)
+
+         inputs = {key: instance[key] for key in self.inputs.keys()}
+         outputs = {key: instance[key] for key in self.outputs.keys()}
 
          return {
              "inputs": inputs,
templates.py CHANGED
@@ -137,7 +137,7 @@ class InputOutputTemplate(Template):
          return target, references
 
 
- class InputOutputReferenceTemplate(InputOutputTemplate):
+ class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
      reference: str
 
      def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
tests.py DELETED
@@ -1,17 +0,0 @@
- test_cases = [
-     {
-         "predictions": [0, 0],
-         "references": [1, 1],
-         "result": {"metric_score": 0}
-     },
-     {
-         "predictions": [1, 1],
-         "references": [1, 1],
-         "result": {"metric_score": 1}
-     },
-     {
-         "predictions": [1, 0],
-         "references": [1, 1],
-         "result": {"metric_score": 0.5}
-     }
- ]

type_utils.py CHANGED
@@ -841,3 +841,34 @@ def to_float_or_default(v, failure_default=0):
          if failure_default is None:
              raise e
          return failure_default
+
+
+ def verify_required_schema(
+     required_schema_dict: typing.Dict[str, str],
+     input_dict: typing.Dict[str, typing.Any],
+ ) -> None:
+     """Verifies if passed input_dict has all required fields, and they are of proper types according to required_schema_dict.
+
+     Parameters:
+         required_schema_dict (Dict[str, str]):
+             Schema where a key is name of a field and a value is a string
+             representing a type of its value.
+         input_dict (Dict[str, Any]):
+             Dict with input fields and their respective values.
+     """
+     for field_name, data_type_string in required_schema_dict.items():
+         try:
+             value = input_dict[field_name]
+         except KeyError as e:
+             raise KeyError(
+                 f"Unexpected field name: '{field_name}'. "
+                 f"The available names: {list(input_dict.keys())}."
+             ) from e
+
+         data_type = parse_type_string(data_type_string)
+
+         if not isoftype(value, data_type):
+             raise ValueError(
+                 f"Passed value '{value}' of field '{field_name}' is not "
+                 f"of required type: ({data_type_string})."
+             )
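
A sketch of the failure modes the new helper covers, assuming it is imported from unitxt's type_utils module:

```python
from unitxt.type_utils import verify_required_schema

schema = {"text": "str", "classes": "List[str]"}
verify_required_schema(schema, {"text": "hi", "classes": ["a", "b"]})  # passes silently

try:
    verify_required_schema(schema, {"text": "hi"})  # missing 'classes'
except KeyError as e:
    print(e)

try:
    verify_required_schema(schema, {"text": 1, "classes": []})  # 'text' is not a str
except ValueError as e:
    print(e)
```
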
version.py CHANGED
@@ -1 +1 @@
- version = "1.7.7"
+ version = "1.7.8"