Upload folder using huggingface_hub
Browse files- _version.py +0 -1
- app.py +0 -3
- common.py +0 -104
- dataset.py +3 -7
- hf_utils.py +2 -2
- inference.py +124 -2
- load.py +0 -15
- metric.py +1 -6
- metrics.py +4 -1
- operators.py +60 -0
- processors.py +8 -0
- renderers.py +0 -132
- serializers.py +0 -130
- standard.py +6 -3
- task.py +68 -18
- templates.py +1 -1
- tests.py +0 -17
- type_utils.py +31 -0
- version.py +1 -1
_version.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
def get_current_version(): return '1.0.31'
|
|
|
|
app.py
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
from unitxt.ui import launch
|
2 |
-
|
3 |
-
launch()
|
|
|
|
|
|
|
|
common.py
DELETED
@@ -1,104 +0,0 @@
|
|
1 |
-
from typing import Union
|
2 |
-
|
3 |
-
from .card import TaskCard
|
4 |
-
from .collections import ItemPicker, RandomPicker
|
5 |
-
from .dataclass import OptionalField
|
6 |
-
from .operator import SourceOperator
|
7 |
-
from .recipe import Recipe, SequentialRecipe
|
8 |
-
from .schema import ToUnitxtGroup
|
9 |
-
from .splitters import RandomSampler, Sampler, SeparateSplit, SliceSplit, SpreadSplit
|
10 |
-
from .stream import MultiStream
|
11 |
-
from .templates import RenderTemplatedICL
|
12 |
-
|
13 |
-
|
14 |
-
class CommonRecipe(Recipe, SourceOperator):
|
15 |
-
card: TaskCard
|
16 |
-
demos_pool_name: str = "demos_pool"
|
17 |
-
demos_taken_from: str = "train"
|
18 |
-
demos_pool_size: int = None
|
19 |
-
demos_field: str = "demos"
|
20 |
-
num_demos: int = None
|
21 |
-
sampler: Sampler = None
|
22 |
-
instruction_item: Union[str, int] = None
|
23 |
-
template_item: Union[str, int] = None
|
24 |
-
system_prompt: str = None
|
25 |
-
|
26 |
-
def verify(self):
|
27 |
-
super().verify()
|
28 |
-
|
29 |
-
def prepare(self):
|
30 |
-
steps = [
|
31 |
-
self.card.loader,
|
32 |
-
]
|
33 |
-
|
34 |
-
if self.card.preprocess_steps is not None:
|
35 |
-
steps.extend(self.card.preprocess_steps)
|
36 |
-
|
37 |
-
steps.append(self.card.task)
|
38 |
-
|
39 |
-
if self.demos_pool_size is not None:
|
40 |
-
steps.append(
|
41 |
-
SeparateSplit(
|
42 |
-
from_split=self.demos_taken_from,
|
43 |
-
to_split_names=[self.demos_pool_name, self.demos_taken_from],
|
44 |
-
to_split_sizes=[int(self.demos_pool_size)],
|
45 |
-
)
|
46 |
-
)
|
47 |
-
|
48 |
-
if self.num_demos is not None:
|
49 |
-
sampler = self.card.sampler
|
50 |
-
|
51 |
-
if self.sampler is not None:
|
52 |
-
sampler = self.sampler
|
53 |
-
|
54 |
-
sampler.set_size(self.num_demos)
|
55 |
-
|
56 |
-
steps.append(
|
57 |
-
SpreadSplit(
|
58 |
-
source_stream=self.demos_pool_name,
|
59 |
-
target_field=self.demos_field,
|
60 |
-
sampler=sampler,
|
61 |
-
)
|
62 |
-
)
|
63 |
-
|
64 |
-
if self.card.instructions is not None:
|
65 |
-
if not self.instruction_item is None:
|
66 |
-
picker = ItemPicker(int(self.instruction_item))
|
67 |
-
else:
|
68 |
-
picker = RandomPicker()
|
69 |
-
instruction = picker(self.card.instructions)
|
70 |
-
else:
|
71 |
-
instruction = None
|
72 |
-
|
73 |
-
if self.card.templates is not None:
|
74 |
-
if self.template_item is None:
|
75 |
-
picker = RandomPicker()
|
76 |
-
else:
|
77 |
-
picker = ItemPicker(self.template_item)
|
78 |
-
template = picker(self.card.templates)
|
79 |
-
else:
|
80 |
-
template = None
|
81 |
-
|
82 |
-
render = RenderTemplatedICL(
|
83 |
-
instruction=instruction,
|
84 |
-
template=template,
|
85 |
-
demos_field=self.demos_field,
|
86 |
-
system_prompt=self.system_prompt,
|
87 |
-
)
|
88 |
-
|
89 |
-
steps.append(render)
|
90 |
-
|
91 |
-
postprocessors = render.get_postprocessors()
|
92 |
-
|
93 |
-
steps.append(
|
94 |
-
ToUnitxtGroup(
|
95 |
-
group="unitxt",
|
96 |
-
metrics=self.card.task.metrics,
|
97 |
-
postprocessors=postprocessors,
|
98 |
-
)
|
99 |
-
)
|
100 |
-
|
101 |
-
self.recipe = SequentialRecipe(steps)
|
102 |
-
|
103 |
-
def process(self) -> MultiStream:
|
104 |
-
return self.recipe()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset.py
CHANGED
@@ -10,7 +10,6 @@ from .catalog import __file__ as _
|
|
10 |
from .collections import __file__ as _
|
11 |
from .collections_operators import __file__ as _
|
12 |
from .dataclass import __file__ as _
|
13 |
-
from .dataset_utils import __file__ as _
|
14 |
from .dataset_utils import get_dataset_artifact
|
15 |
from .deprecation_utils import __file__ as _
|
16 |
from .dialog_operators import __file__ as _
|
@@ -20,13 +19,11 @@ from .file_utils import __file__ as _
|
|
20 |
from .formats import __file__ as _
|
21 |
from .fusion import __file__ as _
|
22 |
from .generator_utils import __file__ as _
|
23 |
-
from .hf_utils import __file__ as _
|
24 |
from .hf_utils import verify_versions_compatibility
|
25 |
from .inference import __file__ as _
|
26 |
from .instructions import __file__ as _
|
27 |
from .llm_as_judge import __file__ as _
|
28 |
from .loaders import __file__ as _
|
29 |
-
from .logging_utils import __file__ as _
|
30 |
from .logging_utils import get_logger
|
31 |
from .metric import __file__ as _
|
32 |
from .metric_utils import __file__ as _
|
@@ -40,7 +37,6 @@ from .random_utils import __file__ as _
|
|
40 |
from .recipe import __file__ as _
|
41 |
from .register import __file__ as _
|
42 |
from .schema import __file__ as _
|
43 |
-
from .settings_utils import __file__ as _
|
44 |
from .settings_utils import get_constants
|
45 |
from .span_lableing_operators import __file__ as _
|
46 |
from .split_utils import __file__ as _
|
@@ -54,7 +50,6 @@ from .task import __file__ as _
|
|
54 |
from .templates import __file__ as _
|
55 |
from .text_utils import __file__ as _
|
56 |
from .type_utils import __file__ as _
|
57 |
-
from .utils import __file__ as _
|
58 |
from .utils import is_package_installed
|
59 |
from .validate import __file__ as _
|
60 |
from .version import __file__ as _
|
@@ -75,8 +70,9 @@ class Dataset(datasets.GeneratorBasedBuilder):
|
|
75 |
if is_package_installed("unitxt"):
|
76 |
verify_versions_compatibility("dataset", self.VERSION)
|
77 |
|
78 |
-
from unitxt.dataset_utils import
|
79 |
-
get_dataset_artifact as get_dataset_artifact_installed
|
|
|
80 |
|
81 |
logger.info("Loading with installed unitxt library...")
|
82 |
dataset = get_dataset_artifact_installed(self.config.name)
|
|
|
10 |
from .collections import __file__ as _
|
11 |
from .collections_operators import __file__ as _
|
12 |
from .dataclass import __file__ as _
|
|
|
13 |
from .dataset_utils import get_dataset_artifact
|
14 |
from .deprecation_utils import __file__ as _
|
15 |
from .dialog_operators import __file__ as _
|
|
|
19 |
from .formats import __file__ as _
|
20 |
from .fusion import __file__ as _
|
21 |
from .generator_utils import __file__ as _
|
|
|
22 |
from .hf_utils import verify_versions_compatibility
|
23 |
from .inference import __file__ as _
|
24 |
from .instructions import __file__ as _
|
25 |
from .llm_as_judge import __file__ as _
|
26 |
from .loaders import __file__ as _
|
|
|
27 |
from .logging_utils import get_logger
|
28 |
from .metric import __file__ as _
|
29 |
from .metric_utils import __file__ as _
|
|
|
37 |
from .recipe import __file__ as _
|
38 |
from .register import __file__ as _
|
39 |
from .schema import __file__ as _
|
|
|
40 |
from .settings_utils import get_constants
|
41 |
from .span_lableing_operators import __file__ as _
|
42 |
from .split_utils import __file__ as _
|
|
|
50 |
from .templates import __file__ as _
|
51 |
from .text_utils import __file__ as _
|
52 |
from .type_utils import __file__ as _
|
|
|
53 |
from .utils import is_package_installed
|
54 |
from .validate import __file__ as _
|
55 |
from .version import __file__ as _
|
|
|
70 |
if is_package_installed("unitxt"):
|
71 |
verify_versions_compatibility("dataset", self.VERSION)
|
72 |
|
73 |
+
from unitxt.dataset_utils import (
|
74 |
+
get_dataset_artifact as get_dataset_artifact_installed,
|
75 |
+
)
|
76 |
|
77 |
logger.info("Loading with installed unitxt library...")
|
78 |
dataset = get_dataset_artifact_installed(self.config.name)
|
hf_utils.py
CHANGED
@@ -24,9 +24,9 @@ class UnitxtVersionsConflictError(ValueError):
|
|
24 |
def __init__(self, error_in: str, hf_unitxt_version, installed_unitxt_version):
|
25 |
assert hf_unitxt_version != installed_unitxt_version
|
26 |
if compare_versions(hf_unitxt_version, installed_unitxt_version) == 1:
|
27 |
-
msg = f"Located installed unitxt version {installed_unitxt_version} that is older
|
28 |
if compare_versions(hf_unitxt_version, installed_unitxt_version) == -1:
|
29 |
-
msg = f"Located installed unitxt version {installed_unitxt_version} that is newer
|
30 |
super().__init__(msg)
|
31 |
|
32 |
|
|
|
24 |
def __init__(self, error_in: str, hf_unitxt_version, installed_unitxt_version):
|
25 |
assert hf_unitxt_version != installed_unitxt_version
|
26 |
if compare_versions(hf_unitxt_version, installed_unitxt_version) == 1:
|
27 |
+
msg = f"Located installed unitxt version {installed_unitxt_version} that is older than unitxt {error_in} version {hf_unitxt_version}. Please update unitxt package or uninstall it to avoid conflicts."
|
28 |
if compare_versions(hf_unitxt_version, installed_unitxt_version) == -1:
|
29 |
+
msg = f"Located installed unitxt version {installed_unitxt_version} that is newer than unitxt {error_in} version {hf_unitxt_version}. Please force-reload the {error_in} or downgrade unitxt to {error_in} version or uninstall unitxt to avoid conflicts."
|
30 |
super().__init__(msg)
|
31 |
|
32 |
|
inference.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1 |
import abc
|
|
|
|
|
|
|
2 |
|
3 |
from .artifact import Artifact
|
|
|
|
|
4 |
|
5 |
|
6 |
class InferenceEngine(abc.ABC, Artifact):
|
@@ -11,12 +16,21 @@ class InferenceEngine(abc.ABC, Artifact):
|
|
11 |
"""Perform inference on the input dataset."""
|
12 |
pass
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
class HFPipelineBasedInferenceEngine(Artifact):
|
16 |
-
"""Abstract base class for inference."""
|
17 |
|
|
|
18 |
model_name: str
|
19 |
max_new_tokens: int
|
|
|
|
|
|
|
20 |
|
21 |
def prepare(self):
|
22 |
from transformers import pipeline
|
@@ -31,3 +45,111 @@ class HFPipelineBasedInferenceEngine(Artifact):
|
|
31 |
max_new_tokens=self.max_new_tokens,
|
32 |
)
|
33 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import abc
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Union
|
5 |
|
6 |
from .artifact import Artifact
|
7 |
+
from .operator import PackageRequirementsMixin
|
8 |
+
from .settings_utils import get_settings
|
9 |
|
10 |
|
11 |
class InferenceEngine(abc.ABC, Artifact):
|
|
|
16 |
"""Perform inference on the input dataset."""
|
17 |
pass
|
18 |
|
19 |
+
@staticmethod
|
20 |
+
def _assert_allow_passing_data_to_remote_api(remote_api_label: str):
|
21 |
+
assert get_settings().allow_passing_data_to_remote_api, (
|
22 |
+
f"LlmAsJudge metric cannot run send data to remote APIs ({remote_api_label}) when"
|
23 |
+
f" unitxt.settings.allow_passing_data_to_remote_api=False."
|
24 |
+
f" Set UNITXT_ALLOW_PASSING_DATA_TO_REMOTE_API environment variable, if you want to allow this. "
|
25 |
+
)
|
26 |
|
|
|
|
|
27 |
|
28 |
+
class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
29 |
model_name: str
|
30 |
max_new_tokens: int
|
31 |
+
_requirement = {
|
32 |
+
"transformers": "Install huggingface package using 'pip install --upgrade transformers"
|
33 |
+
}
|
34 |
|
35 |
def prepare(self):
|
36 |
from transformers import pipeline
|
|
|
45 |
max_new_tokens=self.max_new_tokens,
|
46 |
)
|
47 |
]
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass()
|
51 |
+
class IbmGenAiInferenceEngineParams:
|
52 |
+
decoding_method: str = None
|
53 |
+
max_new_tokens: Optional[int] = None
|
54 |
+
min_new_tokens: Optional[int] = None
|
55 |
+
random_seed: Optional[int] = None
|
56 |
+
repetition_penalty: Optional[float] = None
|
57 |
+
stop_sequences: Optional[List[str]] = None
|
58 |
+
temperature: Optional[float] = None
|
59 |
+
top_k: Optional[int] = None
|
60 |
+
top_p: Optional[float] = None
|
61 |
+
typical_p: Optional[float] = None
|
62 |
+
|
63 |
+
|
64 |
+
class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
65 |
+
label: str = "ibm_genai"
|
66 |
+
model_name: str
|
67 |
+
parameters: IbmGenAiInferenceEngineParams = IbmGenAiInferenceEngineParams()
|
68 |
+
_requirement = {
|
69 |
+
"genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
|
70 |
+
}
|
71 |
+
|
72 |
+
def prepare(self):
|
73 |
+
from genai import Client, Credentials
|
74 |
+
|
75 |
+
api_key_env_var_name = "GENAI_KEY"
|
76 |
+
api_key = os.environ.get(api_key_env_var_name)
|
77 |
+
assert api_key is not None, (
|
78 |
+
f"Error while trying to run IbmGenAiInferenceEngine."
|
79 |
+
f" Please set the environment param '{api_key_env_var_name}'."
|
80 |
+
)
|
81 |
+
api_endpoint = os.environ.get("GENAI_KEY")
|
82 |
+
credentials = Credentials(api_key=api_key, api_endpoint=api_endpoint)
|
83 |
+
self.client = Client(credentials=credentials)
|
84 |
+
|
85 |
+
self._assert_allow_passing_data_to_remote_api(self.label)
|
86 |
+
|
87 |
+
def infer(self, dataset):
|
88 |
+
from genai.schema import TextGenerationParameters
|
89 |
+
|
90 |
+
genai_params = TextGenerationParameters(**self.parameters.__dict__)
|
91 |
+
return list(
|
92 |
+
self.client.text.generation.create(
|
93 |
+
model_id=self.model_name,
|
94 |
+
inputs=[instance["source"] for instance in dataset],
|
95 |
+
parameters=genai_params,
|
96 |
+
)
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
+
@dataclass
|
101 |
+
class OpenAiInferenceEngineParams:
|
102 |
+
frequency_penalty: Optional[float] = None
|
103 |
+
presence_penalty: Optional[float] = None
|
104 |
+
max_tokens: Optional[int] = None
|
105 |
+
seed: Optional[int] = None
|
106 |
+
stop: Union[Optional[str], List[str]] = None
|
107 |
+
temperature: Optional[float] = None
|
108 |
+
top_p: Optional[float] = None
|
109 |
+
|
110 |
+
|
111 |
+
class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
|
112 |
+
label: str = "openai"
|
113 |
+
model_name: str
|
114 |
+
parameters: OpenAiInferenceEngineParams = OpenAiInferenceEngineParams()
|
115 |
+
_requirement = {
|
116 |
+
"openai": "Install openai package using 'pip install --upgrade openai"
|
117 |
+
}
|
118 |
+
|
119 |
+
def prepare(self):
|
120 |
+
from openai import OpenAI
|
121 |
+
|
122 |
+
api_key_env_var_name = "OPENAI_API_KEY"
|
123 |
+
api_key = os.environ.get(api_key_env_var_name)
|
124 |
+
assert api_key is not None, (
|
125 |
+
f"Error while trying to run OpenAiInferenceEngine."
|
126 |
+
f" Please set the environment param '{api_key_env_var_name}'."
|
127 |
+
)
|
128 |
+
|
129 |
+
self.client = OpenAI(api_key=api_key)
|
130 |
+
self._assert_allow_passing_data_to_remote_api(self.label)
|
131 |
+
|
132 |
+
def infer(self, dataset):
|
133 |
+
return [
|
134 |
+
self.client.chat.completions.create(
|
135 |
+
messages=[
|
136 |
+
# {
|
137 |
+
# "role": "system",
|
138 |
+
# "content": self.system_prompt,
|
139 |
+
# },
|
140 |
+
{
|
141 |
+
"role": "user",
|
142 |
+
"content": instance["source"],
|
143 |
+
}
|
144 |
+
],
|
145 |
+
model=self.model_name,
|
146 |
+
frequency_penalty=self.parameters.frequency_penalty,
|
147 |
+
presence_penalty=self.parameters.presence_penalty,
|
148 |
+
max_tokens=self.parameters.max_tokens,
|
149 |
+
seed=self.parameters.seed,
|
150 |
+
stop=self.parameters.stop,
|
151 |
+
temperature=self.parameters.temperature,
|
152 |
+
top_p=self.parameters.top_p,
|
153 |
+
)
|
154 |
+
for instance in dataset
|
155 |
+
]
|
load.py
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
from typing import Union
|
2 |
-
|
3 |
-
from datasets import DatasetDict
|
4 |
-
|
5 |
-
from .artifact import fetch_artifact
|
6 |
-
from .operator import StreamSource
|
7 |
-
|
8 |
-
|
9 |
-
def load_dataset(source: Union[StreamSource, str]) -> DatasetDict:
|
10 |
-
assert isinstance(
|
11 |
-
source, (StreamSource, str)
|
12 |
-
), "source must be a StreamSource or a string"
|
13 |
-
if isinstance(source, str):
|
14 |
-
source, _ = fetch_artifact(source)
|
15 |
-
return source().to_dataset()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metric.py
CHANGED
@@ -19,16 +19,13 @@ from .file_utils import __file__ as _
|
|
19 |
from .formats import __file__ as _
|
20 |
from .fusion import __file__ as _
|
21 |
from .generator_utils import __file__ as _
|
22 |
-
from .hf_utils import __file__ as _
|
23 |
from .hf_utils import verify_versions_compatibility
|
24 |
from .inference import __file__ as _
|
25 |
from .instructions import __file__ as _
|
26 |
from .llm_as_judge import __file__ as _
|
27 |
from .loaders import __file__ as _
|
28 |
from .logging_utils import __file__ as _
|
29 |
-
from .metric_utils import UNITXT_METRIC_SCHEMA
|
30 |
-
from .metric_utils import __file__ as _
|
31 |
-
from .metric_utils import _compute
|
32 |
from .metrics import __file__ as _
|
33 |
from .normalizers import __file__ as _
|
34 |
from .operator import __file__ as _
|
@@ -39,7 +36,6 @@ from .random_utils import __file__ as _
|
|
39 |
from .recipe import __file__ as _
|
40 |
from .register import __file__ as _
|
41 |
from .schema import __file__ as _
|
42 |
-
from .settings_utils import __file__ as _
|
43 |
from .settings_utils import get_constants
|
44 |
from .span_lableing_operators import __file__ as _
|
45 |
from .split_utils import __file__ as _
|
@@ -53,7 +49,6 @@ from .task import __file__ as _
|
|
53 |
from .templates import __file__ as _
|
54 |
from .text_utils import __file__ as _
|
55 |
from .type_utils import __file__ as _
|
56 |
-
from .utils import __file__ as _
|
57 |
from .utils import is_package_installed
|
58 |
from .validate import __file__ as _
|
59 |
from .version import __file__ as _
|
|
|
19 |
from .formats import __file__ as _
|
20 |
from .fusion import __file__ as _
|
21 |
from .generator_utils import __file__ as _
|
|
|
22 |
from .hf_utils import verify_versions_compatibility
|
23 |
from .inference import __file__ as _
|
24 |
from .instructions import __file__ as _
|
25 |
from .llm_as_judge import __file__ as _
|
26 |
from .loaders import __file__ as _
|
27 |
from .logging_utils import __file__ as _
|
28 |
+
from .metric_utils import UNITXT_METRIC_SCHEMA, _compute
|
|
|
|
|
29 |
from .metrics import __file__ as _
|
30 |
from .normalizers import __file__ as _
|
31 |
from .operator import __file__ as _
|
|
|
36 |
from .recipe import __file__ as _
|
37 |
from .register import __file__ as _
|
38 |
from .schema import __file__ as _
|
|
|
39 |
from .settings_utils import get_constants
|
40 |
from .span_lableing_operators import __file__ as _
|
41 |
from .split_utils import __file__ as _
|
|
|
49 |
from .templates import __file__ as _
|
50 |
from .text_utils import __file__ as _
|
51 |
from .type_utils import __file__ as _
|
|
|
52 |
from .utils import is_package_installed
|
53 |
from .validate import __file__ as _
|
54 |
from .version import __file__ as _
|
metrics.py
CHANGED
@@ -2255,6 +2255,8 @@ class Perplexity(BulkInstanceMetric):
|
|
2255 |
self.model_class().from_pretrained(self.model_name).to(self.device)
|
2256 |
)
|
2257 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
|
|
|
|
2258 |
self.single_token_mode = single_token_mode
|
2259 |
|
2260 |
def compute_lm(
|
@@ -3348,7 +3350,8 @@ class BinaryMaxF1(F1Binary):
|
|
3348 |
|
3349 |
best_thr = -1
|
3350 |
best_f1 = -1
|
3351 |
-
for
|
|
|
3352 |
new_predictions = [
|
3353 |
"1" if float_prediction >= thr else "0"
|
3354 |
for float_prediction in float_predictions
|
|
|
2255 |
self.model_class().from_pretrained(self.model_name).to(self.device)
|
2256 |
)
|
2257 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
2258 |
+
if self.tokenizer.pad_token_id is None:
|
2259 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
2260 |
self.single_token_mode = single_token_mode
|
2261 |
|
2262 |
def compute_lm(
|
|
|
3350 |
|
3351 |
best_thr = -1
|
3352 |
best_f1 = -1
|
3353 |
+
thrs = {round(fp, 3) for fp in float_predictions}
|
3354 |
+
for thr in thrs:
|
3355 |
new_predictions = [
|
3356 |
"1" if float_prediction >= thr else "0"
|
3357 |
for float_prediction in float_predictions
|
operators.py
CHANGED
@@ -1704,6 +1704,66 @@ class Shuffle(PagedStreamOperator):
|
|
1704 |
yield from page
|
1705 |
|
1706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1707 |
class EncodeLabels(StreamInstanceOperator):
|
1708 |
"""Encode each value encountered in any field in 'fields' into the integers 0,1,...
|
1709 |
|
|
|
1704 |
yield from page
|
1705 |
|
1706 |
|
1707 |
+
class FeatureGroupedShuffle(Shuffle):
|
1708 |
+
"""Class for shuffling an input dataset by instance 'blocks', not on the individual instance level.
|
1709 |
+
|
1710 |
+
Example is if the dataset consists of questions with paraphrases of it, and each question falls into a topic.
|
1711 |
+
All paraphrases have the same ID value as the original.
|
1712 |
+
In this case, we may want to shuffle on grouping_features = ['question ID'],
|
1713 |
+
to keep the paraphrases and original question together.
|
1714 |
+
We may also want to group by both 'question ID' and 'topic', if the question IDs are repeated between topics.
|
1715 |
+
In this case, grouping_features = ['question ID', 'topic']
|
1716 |
+
|
1717 |
+
Args:
|
1718 |
+
grouping_features (list of strings): list of feature names to use to define the groups.
|
1719 |
+
a group is defined by each unique observed combination of data values for features in grouping_features
|
1720 |
+
shuffle_within_group (bool): whether to further shuffle the instances within each group block, keeping the block order
|
1721 |
+
|
1722 |
+
Args (of superclass):
|
1723 |
+
page_size (int): The size of each page in the stream. Defaults to 1000.
|
1724 |
+
Note: shuffle_by_grouping_features determines the unique groups (unique combinations of values of grouping_features)
|
1725 |
+
separately by page (determined by page_size). If a block of instances in the same group are split
|
1726 |
+
into separate pages (either by a page break falling in the group, or the dataset was not sorted by
|
1727 |
+
grouping_features), these instances will be shuffled separately and thus the grouping may be
|
1728 |
+
broken up by pages. If the user wants to ensure the shuffle does the grouping and shuffling
|
1729 |
+
across all pages, set the page_size to be larger than the dataset size.
|
1730 |
+
See outputs_2features_bigpage and outputs_2features_smallpage in test_grouped_shuffle.
|
1731 |
+
"""
|
1732 |
+
|
1733 |
+
grouping_features: List[str] = None
|
1734 |
+
shuffle_within_group: bool = False
|
1735 |
+
|
1736 |
+
def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
|
1737 |
+
if self.grouping_features is None:
|
1738 |
+
super().process(page, stream_name)
|
1739 |
+
else:
|
1740 |
+
yield from self.shuffle_by_grouping_features(page)
|
1741 |
+
|
1742 |
+
def shuffle_by_grouping_features(self, page):
|
1743 |
+
import itertools
|
1744 |
+
from collections import defaultdict
|
1745 |
+
|
1746 |
+
groups_to_instances = defaultdict(list)
|
1747 |
+
for item in page:
|
1748 |
+
groups_to_instances[
|
1749 |
+
tuple(item[ff] for ff in self.grouping_features)
|
1750 |
+
].append(item)
|
1751 |
+
# now extract the groups (i.e., lists of dicts with order preserved)
|
1752 |
+
page_blocks = list(groups_to_instances.values())
|
1753 |
+
# and now shuffle the blocks
|
1754 |
+
self.random_generator.shuffle(page_blocks)
|
1755 |
+
if self.shuffle_within_group:
|
1756 |
+
blocks = []
|
1757 |
+
# reshuffle the instances within each block, but keep the blocks in order
|
1758 |
+
for block in page_blocks:
|
1759 |
+
self.random_generator.shuffle(block)
|
1760 |
+
blocks.append(block)
|
1761 |
+
page_blocks = blocks
|
1762 |
+
|
1763 |
+
# now flatten the list so it consists of individual dicts, but in (randomized) block order
|
1764 |
+
return list(itertools.chain(*page_blocks))
|
1765 |
+
|
1766 |
+
|
1767 |
class EncodeLabels(StreamInstanceOperator):
|
1768 |
"""Encode each value encountered in any field in 'fields' into the integers 0,1,...
|
1769 |
|
processors.py
CHANGED
@@ -46,6 +46,14 @@ class RegexParser(FieldOperator):
|
|
46 |
return re.findall(self.regex, text)
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
class LoadJson(FieldOperator):
|
50 |
def process_value(self, text: Any) -> Any:
|
51 |
try:
|
|
|
46 |
return re.findall(self.regex, text)
|
47 |
|
48 |
|
49 |
+
class ExtractWithRegex(RegexParser):
|
50 |
+
def process_value(self, text: Any) -> Any:
|
51 |
+
matches = super().process_value(text)
|
52 |
+
if matches:
|
53 |
+
return matches[0]
|
54 |
+
return ""
|
55 |
+
|
56 |
+
|
57 |
class LoadJson(FieldOperator):
|
58 |
def process_value(self, text: Any) -> Any:
|
59 |
try:
|
renderers.py
DELETED
@@ -1,132 +0,0 @@
|
|
1 |
-
from abc import ABC
|
2 |
-
from typing import Any, Dict, List, Optional
|
3 |
-
|
4 |
-
from .dataclass import InternalField
|
5 |
-
from .formats import Format, ICLFormat
|
6 |
-
from .instructions import Instruction
|
7 |
-
from .operator import Operator, SequentialOperator, StreamInstanceOperator
|
8 |
-
from .random_utils import get_random
|
9 |
-
from .templates import Template
|
10 |
-
|
11 |
-
|
12 |
-
class Renderer(ABC):
|
13 |
-
pass
|
14 |
-
# @abstractmethod
|
15 |
-
# def get_postprocessors(self) -> List[str]:
|
16 |
-
# pass
|
17 |
-
|
18 |
-
|
19 |
-
class RenderTemplate(Renderer, StreamInstanceOperator):
|
20 |
-
template: Template
|
21 |
-
random_reference: bool = False
|
22 |
-
skip_rendered_instance: bool = True
|
23 |
-
|
24 |
-
def process(
|
25 |
-
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
26 |
-
) -> Dict[str, Any]:
|
27 |
-
if self.skip_rendered_instance:
|
28 |
-
if (
|
29 |
-
"inputs" not in instance
|
30 |
-
and "outputs" not in instance
|
31 |
-
and "source" in instance
|
32 |
-
and "target" in instance
|
33 |
-
and "references" in instance
|
34 |
-
):
|
35 |
-
return instance
|
36 |
-
|
37 |
-
inputs = instance["inputs"]
|
38 |
-
outputs = instance["outputs"]
|
39 |
-
|
40 |
-
source = self.template.process_inputs(inputs)
|
41 |
-
targets = self.template.process_outputs(outputs)
|
42 |
-
|
43 |
-
if self.template.is_multi_reference:
|
44 |
-
assert isinstance(targets, list), f"{targets} must be a list"
|
45 |
-
references = targets
|
46 |
-
if self.random_reference:
|
47 |
-
target = get_random().choice(references)
|
48 |
-
else:
|
49 |
-
if len(references) == 0:
|
50 |
-
raise ValueError("No references found")
|
51 |
-
target = references[0]
|
52 |
-
else:
|
53 |
-
references = [targets]
|
54 |
-
target = targets
|
55 |
-
|
56 |
-
instance.update(
|
57 |
-
{
|
58 |
-
"source": source,
|
59 |
-
"target": target,
|
60 |
-
"references": references,
|
61 |
-
}
|
62 |
-
)
|
63 |
-
|
64 |
-
return instance
|
65 |
-
|
66 |
-
|
67 |
-
class RenderDemonstrations(RenderTemplate):
|
68 |
-
demos_field: str
|
69 |
-
|
70 |
-
def process(
|
71 |
-
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
72 |
-
) -> Dict[str, Any]:
|
73 |
-
demos = instance.get(self.demos_field, [])
|
74 |
-
|
75 |
-
processed_demos = []
|
76 |
-
for demo_instance in demos:
|
77 |
-
demo_instance = super().process(demo_instance)
|
78 |
-
processed_demos.append(demo_instance)
|
79 |
-
|
80 |
-
instance[self.demos_field] = processed_demos
|
81 |
-
|
82 |
-
return instance
|
83 |
-
|
84 |
-
|
85 |
-
class RenderInstruction(Renderer, StreamInstanceOperator):
|
86 |
-
instruction: Instruction
|
87 |
-
|
88 |
-
def process(
|
89 |
-
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
90 |
-
) -> Dict[str, Any]:
|
91 |
-
if self.instruction is not None:
|
92 |
-
instance["instruction"] = self.instruction()
|
93 |
-
else:
|
94 |
-
instance["instruction"] = ""
|
95 |
-
return instance
|
96 |
-
|
97 |
-
|
98 |
-
class RenderFormat(Renderer, StreamInstanceOperator):
|
99 |
-
format: Format
|
100 |
-
demos_field: str = None
|
101 |
-
|
102 |
-
def process(
|
103 |
-
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
104 |
-
) -> Dict[str, Any]:
|
105 |
-
demos_instances = instance.pop(self.demos_field, None)
|
106 |
-
if demos_instances is not None:
|
107 |
-
instance["source"] = self.format.format(
|
108 |
-
instance, demos_instances=demos_instances
|
109 |
-
)
|
110 |
-
else:
|
111 |
-
instance["source"] = self.format.format(instance)
|
112 |
-
return instance
|
113 |
-
|
114 |
-
|
115 |
-
class StandardRenderer(Renderer, SequentialOperator):
|
116 |
-
template: Template
|
117 |
-
instruction: Instruction = None
|
118 |
-
demos_field: str = None
|
119 |
-
format: ICLFormat = None
|
120 |
-
|
121 |
-
steps: List[Operator] = InternalField(default_factory=list)
|
122 |
-
|
123 |
-
def prepare(self):
|
124 |
-
self.steps = [
|
125 |
-
RenderTemplate(template=self.template),
|
126 |
-
RenderDemonstrations(template=self.template, demos_field=self.demos_field),
|
127 |
-
RenderInstruction(instruction=self.instruction),
|
128 |
-
RenderFormat(format=self.format, demos_field=self.demos_field),
|
129 |
-
]
|
130 |
-
|
131 |
-
def get_postprocessors(self):
|
132 |
-
return self.template.get_postprocessors()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
serializers.py
DELETED
@@ -1,130 +0,0 @@
|
|
1 |
-
from abc import ABC, abstractmethod
|
2 |
-
from copy import deepcopy
|
3 |
-
from typing import (
|
4 |
-
Any,
|
5 |
-
Dict,
|
6 |
-
List,
|
7 |
-
)
|
8 |
-
|
9 |
-
from .operators import FieldOperator
|
10 |
-
|
11 |
-
"""
|
12 |
-
TableSerializer converts a given table into a flat sequence with special symbols.
|
13 |
-
Input table format must be:
|
14 |
-
{"header": ["col1", "col2"], "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]}
|
15 |
-
Output format varies depending on the chosen serializer. Abstract class at the top defines structure of a typical table serializer that any concrete implementation should follow.
|
16 |
-
"""
|
17 |
-
|
18 |
-
|
19 |
-
class TableSerializer(ABC, FieldOperator):
|
20 |
-
# main method to serialize a table
|
21 |
-
@abstractmethod
|
22 |
-
def serialize_table(self, table_content: Dict) -> str:
|
23 |
-
pass
|
24 |
-
|
25 |
-
# method to process table header
|
26 |
-
@abstractmethod
|
27 |
-
def process_header(self, header: List):
|
28 |
-
pass
|
29 |
-
|
30 |
-
# method to process a table row
|
31 |
-
@abstractmethod
|
32 |
-
def process_row(self, row: List, row_index: int):
|
33 |
-
pass
|
34 |
-
|
35 |
-
|
36 |
-
# Concrete classes implementing table serializers follow..
|
37 |
-
"""
|
38 |
-
Indexed Row Major Table Serializer.
|
39 |
-
Commonly used row major serialization format.
|
40 |
-
Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
|
41 |
-
"""
|
42 |
-
|
43 |
-
|
44 |
-
class IndexedRowMajorTableSerializer(TableSerializer):
|
45 |
-
def process_value(self, table: Any) -> Any:
|
46 |
-
table_input = deepcopy(table)
|
47 |
-
return self.serialize_table(table_content=table_input)
|
48 |
-
|
49 |
-
# main method that processes a table
|
50 |
-
# table_content must be in the presribed input format
|
51 |
-
def serialize_table(self, table_content: Dict) -> str:
|
52 |
-
# Extract headers and rows from the dictionary
|
53 |
-
header = table_content.get("header", [])
|
54 |
-
rows = table_content.get("rows", [])
|
55 |
-
|
56 |
-
assert header and rows, "Incorrect input table format"
|
57 |
-
|
58 |
-
# Process table header first
|
59 |
-
serialized_tbl_str = self.process_header(header) + " "
|
60 |
-
|
61 |
-
# Process rows sequentially starting from row 1
|
62 |
-
for i, row in enumerate(rows, start=1):
|
63 |
-
serialized_tbl_str += self.process_row(row, row_index=i) + " "
|
64 |
-
|
65 |
-
# return serialized table as a string
|
66 |
-
return serialized_tbl_str.strip()
|
67 |
-
|
68 |
-
# serialize header into a string containing the list of column names separated by '|' symbol
|
69 |
-
def process_header(self, header: List):
|
70 |
-
return "col : " + " | ".join(header)
|
71 |
-
|
72 |
-
# serialize a table row into a string containing the list of cell values separated by '|'
|
73 |
-
def process_row(self, row: List, row_index: int):
|
74 |
-
serialized_row_str = ""
|
75 |
-
row_cell_values = [
|
76 |
-
str(value) if isinstance(value, (int, float)) else value for value in row
|
77 |
-
]
|
78 |
-
|
79 |
-
serialized_row_str += " | ".join(row_cell_values)
|
80 |
-
|
81 |
-
return f"row {row_index} : {serialized_row_str}"
|
82 |
-
|
83 |
-
|
84 |
-
"""
|
85 |
-
Markdown Table Serializer.
|
86 |
-
Markdown table format is used in GitHub code primarily.
|
87 |
-
Format:
|
88 |
-
|col1|col2|col3|
|
89 |
-
|---|---|---|
|
90 |
-
|A|4|1|
|
91 |
-
|I|2|1|
|
92 |
-
...
|
93 |
-
"""
|
94 |
-
|
95 |
-
|
96 |
-
class MarkdownTableSerializer(TableSerializer):
|
97 |
-
def process_value(self, table: Any) -> Any:
|
98 |
-
table_input = deepcopy(table)
|
99 |
-
return self.serialize_table(table_content=table_input)
|
100 |
-
|
101 |
-
# main method that serializes a table.
|
102 |
-
# table_content must be in the presribed input format.
|
103 |
-
def serialize_table(self, table_content: Dict) -> str:
|
104 |
-
# Extract headers and rows from the dictionary
|
105 |
-
header = table_content.get("header", [])
|
106 |
-
rows = table_content.get("rows", [])
|
107 |
-
|
108 |
-
assert header and rows, "Incorrect input table format"
|
109 |
-
|
110 |
-
# Process table header first
|
111 |
-
serialized_tbl_str = self.process_header(header)
|
112 |
-
|
113 |
-
# Process rows sequentially starting from row 1
|
114 |
-
for i, row in enumerate(rows, start=1):
|
115 |
-
serialized_tbl_str += self.process_row(row, row_index=i)
|
116 |
-
|
117 |
-
# return serialized table as a string
|
118 |
-
return serialized_tbl_str.strip()
|
119 |
-
|
120 |
-
# serialize header into a string containing the list of column names
|
121 |
-
def process_header(self, header: List):
|
122 |
-
header_str = "|{}|\n".format("|".join(header))
|
123 |
-
header_str += "|{}|\n".format("|".join(["---"] * len(header)))
|
124 |
-
return header_str
|
125 |
-
|
126 |
-
# serialize a table row into a string containing the list of cell values
|
127 |
-
def process_row(self, row: List, row_index: int):
|
128 |
-
row_str = ""
|
129 |
-
row_str += "|{}|\n".format("|".join(str(cell) for cell in row))
|
130 |
-
return row_str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
standard.py
CHANGED
@@ -187,6 +187,11 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
187 |
return list(multi_stream["__inference__"])
|
188 |
|
189 |
def prepare(self):
|
|
|
|
|
|
|
|
|
|
|
190 |
self.set_pipelines()
|
191 |
|
192 |
loader = self.card.loader
|
@@ -220,7 +225,7 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
220 |
self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
|
221 |
self.processing.steps.append(self.augmentor)
|
222 |
|
223 |
-
if self.
|
224 |
self.processing.steps.append(
|
225 |
CreateDemosPool(
|
226 |
from_split=self.demos_taken_from,
|
@@ -229,8 +234,6 @@ class BaseRecipe(Recipe, SourceSequentialOperator):
|
|
229 |
remove_targets_from_source_split=self.demos_removed_from_data,
|
230 |
)
|
231 |
)
|
232 |
-
|
233 |
-
if self.num_demos > 0:
|
234 |
if self.sampler is None:
|
235 |
if self.card.sampler is None:
|
236 |
raise ValueError(
|
|
|
187 |
return list(multi_stream["__inference__"])
|
188 |
|
189 |
def prepare(self):
|
190 |
+
# To avoid the Python's mutable default list trap, we set the default value to None
|
191 |
+
# and then set it to an empty list if it is None.
|
192 |
+
if self.card.preprocess_steps is None:
|
193 |
+
self.card.preprocess_steps = []
|
194 |
+
|
195 |
self.set_pipelines()
|
196 |
|
197 |
loader = self.card.loader
|
|
|
225 |
self.augmentor.set_task_input_fields(self.card.task.augmentable_inputs)
|
226 |
self.processing.steps.append(self.augmentor)
|
227 |
|
228 |
+
if self.num_demos > 0:
|
229 |
self.processing.steps.append(
|
230 |
CreateDemosPool(
|
231 |
from_split=self.demos_taken_from,
|
|
|
234 |
remove_targets_from_source_split=self.demos_removed_from_data,
|
235 |
)
|
236 |
)
|
|
|
|
|
237 |
if self.sampler is None:
|
238 |
if self.card.sampler is None:
|
239 |
raise ValueError(
|
task.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
-
from typing import Any, Dict, List, Optional
|
2 |
|
|
|
|
|
3 |
from .operator import StreamInstanceOperator
|
|
|
4 |
|
5 |
|
6 |
class Tasker:
|
@@ -10,41 +13,88 @@ class Tasker:
|
|
10 |
class FormTask(Tasker, StreamInstanceOperator):
|
11 |
"""FormTask packs the different instance fields into dictionaries by their roles in the task.
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
The output instance contains three fields:
|
14 |
"inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
|
15 |
"outputs" -- for the fields listed in Arg "outputs".
|
16 |
"metrics" -- to contain the value of Arg 'metrics'
|
17 |
-
|
18 |
"""
|
19 |
|
20 |
-
inputs: List[str]
|
21 |
-
outputs: List[str]
|
22 |
metrics: List[str]
|
|
|
23 |
augmentable_inputs: List[str] = []
|
24 |
|
25 |
def verify(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
for augmentable_input in self.augmentable_inputs:
|
27 |
assert (
|
28 |
augmentable_input in self.inputs
|
29 |
), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def process(
|
32 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
33 |
) -> Dict[str, Any]:
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
f"The available input names: {list(instance.keys())}"
|
40 |
-
) from e
|
41 |
-
try:
|
42 |
-
outputs = {key: instance[key] for key in self.outputs}
|
43 |
-
except KeyError as e:
|
44 |
-
raise KeyError(
|
45 |
-
f"Unexpected FormTask output column names: {[key for key in self.outputs if key not in instance]}"
|
46 |
-
f" \n available names:{list(instance.keys())}\n given output names:{self.outputs}"
|
47 |
-
) from e
|
48 |
|
49 |
return {
|
50 |
"inputs": inputs,
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Union
|
2 |
|
3 |
+
from .artifact import fetch_artifact
|
4 |
+
from .logging_utils import get_logger
|
5 |
from .operator import StreamInstanceOperator
|
6 |
+
from .type_utils import isoftype, parse_type_string, verify_required_schema
|
7 |
|
8 |
|
9 |
class Tasker:
|
|
|
13 |
class FormTask(Tasker, StreamInstanceOperator):
|
14 |
"""FormTask packs the different instance fields into dictionaries by their roles in the task.
|
15 |
|
16 |
+
Attributes:
|
17 |
+
inputs (Union[Dict[str, str], List[str]]):
|
18 |
+
Dictionary with string names of instance input fields and types of respective values.
|
19 |
+
In case a list is passed, each type will be assumed to be Any.
|
20 |
+
outputs (Union[Dict[str, str], List[str]]):
|
21 |
+
Dictionary with string names of instance output fields and types of respective values.
|
22 |
+
In case a list is passed, each type will be assumed to be Any.
|
23 |
+
metrics (List[str]): List of names of metrics to be used in the task.
|
24 |
+
prediction_type (Optional[str]):
|
25 |
+
Need to be consistent with all used metrics. Defaults to None, which means that it will
|
26 |
+
be set to Any.
|
27 |
+
|
28 |
The output instance contains three fields:
|
29 |
"inputs" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'inputs'.
|
30 |
"outputs" -- for the fields listed in Arg "outputs".
|
31 |
"metrics" -- to contain the value of Arg 'metrics'
|
|
|
32 |
"""
|
33 |
|
34 |
+
inputs: Union[Dict[str, str], List[str]]
|
35 |
+
outputs: Union[Dict[str, str], List[str]]
|
36 |
metrics: List[str]
|
37 |
+
prediction_type: Optional[str] = None
|
38 |
augmentable_inputs: List[str] = []
|
39 |
|
40 |
def verify(self):
|
41 |
+
for io_type in ["inputs", "outputs"]:
|
42 |
+
data = self.inputs if io_type == "inputs" else self.outputs
|
43 |
+
if not isoftype(data, Dict[str, str]):
|
44 |
+
get_logger().warning(
|
45 |
+
f"'{io_type}' field of Task should be a dictionary of field names and their types. "
|
46 |
+
f"For example, {{'text': 'str', 'classes': 'List[str]'}}. Instead only '{data}' was "
|
47 |
+
f"passed. All types will be assumed to be 'Any'. In future version of unitxt this "
|
48 |
+
f"will raise an exception."
|
49 |
+
)
|
50 |
+
data = {key: "Any" for key in data}
|
51 |
+
if io_type == "inputs":
|
52 |
+
self.inputs = data
|
53 |
+
else:
|
54 |
+
self.outputs = data
|
55 |
+
|
56 |
+
if not self.prediction_type:
|
57 |
+
get_logger().warning(
|
58 |
+
"'prediction_type' was not set in Task. It is used to check the output of "
|
59 |
+
"template post processors is compatible with the expected input of the metrics. "
|
60 |
+
"Setting `prediction_type` to 'Any' (no checking is done). In future version "
|
61 |
+
"of unitxt this will raise an exception."
|
62 |
+
)
|
63 |
+
self.prediction_type = "Any"
|
64 |
+
|
65 |
+
self.check_metrics_type()
|
66 |
+
|
67 |
for augmentable_input in self.augmentable_inputs:
|
68 |
assert (
|
69 |
augmentable_input in self.inputs
|
70 |
), f"augmentable_input {augmentable_input} is not part of {self.inputs}"
|
71 |
|
72 |
+
def check_metrics_type(self) -> None:
|
73 |
+
prediction_type = parse_type_string(self.prediction_type)
|
74 |
+
for metric_name in self.metrics:
|
75 |
+
metric = fetch_artifact(metric_name)[0]
|
76 |
+
metric_prediction_type = metric.get_prediction_type()
|
77 |
+
|
78 |
+
if (
|
79 |
+
prediction_type == metric_prediction_type
|
80 |
+
or prediction_type == Any
|
81 |
+
or metric_prediction_type == Any
|
82 |
+
):
|
83 |
+
continue
|
84 |
+
|
85 |
+
raise ValueError(
|
86 |
+
f"The task's prediction type ({prediction_type}) and '{metric_name}' "
|
87 |
+
f"metric's prediction type ({metric_prediction_type}) are different."
|
88 |
+
)
|
89 |
+
|
90 |
def process(
|
91 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
92 |
) -> Dict[str, Any]:
|
93 |
+
verify_required_schema(self.inputs, instance)
|
94 |
+
verify_required_schema(self.outputs, instance)
|
95 |
+
|
96 |
+
inputs = {key: instance[key] for key in self.inputs.keys()}
|
97 |
+
outputs = {key: instance[key] for key in self.outputs.keys()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
return {
|
100 |
"inputs": inputs,
|
templates.py
CHANGED
@@ -137,7 +137,7 @@ class InputOutputTemplate(Template):
|
|
137 |
return target, references
|
138 |
|
139 |
|
140 |
-
class
|
141 |
reference: str
|
142 |
|
143 |
def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
|
|
|
137 |
return target, references
|
138 |
|
139 |
|
140 |
+
class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
|
141 |
reference: str
|
142 |
|
143 |
def outputs_to_target_and_references(self, outputs: Dict[str, object]) -> str:
|
tests.py
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
test_cases = [
|
2 |
-
{
|
3 |
-
"predictions": [0, 0],
|
4 |
-
"references": [1, 1],
|
5 |
-
"result": {"metric_score": 0}
|
6 |
-
},
|
7 |
-
{
|
8 |
-
"predictions": [1, 1],
|
9 |
-
"references": [1, 1],
|
10 |
-
"result": {"metric_score": 1}
|
11 |
-
},
|
12 |
-
{
|
13 |
-
"predictions": [1, 0],
|
14 |
-
"references": [1, 1],
|
15 |
-
"result": {"metric_score": 0.5}
|
16 |
-
}
|
17 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type_utils.py
CHANGED
@@ -841,3 +841,34 @@ def to_float_or_default(v, failure_default=0):
|
|
841 |
if failure_default is None:
|
842 |
raise e
|
843 |
return failure_default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
841 |
if failure_default is None:
|
842 |
raise e
|
843 |
return failure_default
|
844 |
+
|
845 |
+
|
846 |
+
def verify_required_schema(
|
847 |
+
required_schema_dict: typing.Dict[str, str],
|
848 |
+
input_dict: typing.Dict[str, typing.Any],
|
849 |
+
) -> None:
|
850 |
+
"""Verifies if passed input_dict has all required fields, and they are of proper types according to required_schema_dict.
|
851 |
+
|
852 |
+
Parameters:
|
853 |
+
required_schema_dict (Dict[str, str]):
|
854 |
+
Schema where a key is name of a field and a value is a string
|
855 |
+
representing a type of its value.
|
856 |
+
input_dict (Dict[str, Any]):
|
857 |
+
Dict with input fields and their respective values.
|
858 |
+
"""
|
859 |
+
for field_name, data_type_string in required_schema_dict.items():
|
860 |
+
try:
|
861 |
+
value = input_dict[field_name]
|
862 |
+
except KeyError as e:
|
863 |
+
raise KeyError(
|
864 |
+
f"Unexpected field name: '{field_name}'. "
|
865 |
+
f"The available names: {list(input_dict.keys())}."
|
866 |
+
) from e
|
867 |
+
|
868 |
+
data_type = parse_type_string(data_type_string)
|
869 |
+
|
870 |
+
if not isoftype(value, data_type):
|
871 |
+
raise ValueError(
|
872 |
+
f"Passed value '{value}' of field '{field_name}' is not "
|
873 |
+
f"of required type: ({data_type_string})."
|
874 |
+
)
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.7.
|
|
|
1 |
+
version = "1.7.8"
|