metric / operators.py
Elron's picture
Upload operators.py with huggingface_hub
d292ceb
raw
history blame
No virus
10.6 kB
from dataclasses import field
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
from .text_utils import nested_tuple_to_string
from .artifact import Artifact, fetch_artifact
from .operator import (
MultiStream,
MultiStreamOperator,
SingleStreamOperator,
SingleStreamReducer,
Stream,
StreamInitializerOperator,
StreamInstanceOperator,
PagedStreamOperator,
)
from .stream import MultiStream, Stream
from .utils import flatten_dict
import random
from .utils import dict_query
class FromIterables(StreamInitializerOperator):
    """
    Builds a MultiStream out of named iterables.

    Args:
        iterables (Dict[str, Iterable]): Mapping from a stream name to the iterable supplying that stream.
    """

    def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
        # Construction is delegated entirely to MultiStream.from_iterables.
        multi_stream = MultiStream.from_iterables(iterables)
        return multi_stream
class MapInstanceValues(StreamInstanceOperator):
    """
    Maps values in each instance of a stream based on the provided mappers.

    A value is looked up in its field's mapper by its string form (``str(value)``).
    Fields that have no mapper are copied to the result unchanged.

    Args:
        mappers (Dict[str, Dict[str, str]]): A dictionary where each key-value pair represents a field in the instance and a mapper for that field.
        strict (bool): If True, the operator will raise a KeyError if a value is not in its corresponding mapper. If False, unmapped values will be left unchanged. Defaults to True.
    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True

    def verify(self):
        """Assert that every mapper is a dict whose keys are all strings."""
        for key, mapper in self.mappers.items():
            assert isinstance(mapper, dict), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            # Only the keys are validated; mapped-to values may be of any type.
            for k in mapper:
                assert isinstance(k, str), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """
        Return a new dict with the mapped fields translated through their mappers.

        Raises:
            KeyError: In strict mode, when a mapped field's stringified value is missing from its mapper.
        """
        result = {}
        for key, value in instance.items():
            if key in self.mappers:
                mapper = self.mappers[key]
                # Stringify only when a mapper exists for this field; values of
                # unmapped fields never need (and may not support) conversion.
                str_value = str(value)
                # In strict mode the lookup is unconditional so a missing value raises KeyError.
                if self.strict or str_value in mapper:
                    value = mapper[str_value]
            result[key] = value
        return result
class FlattenInstances(StreamInstanceOperator):
    """
    Flattens every instance of a stream by handing it to ``flatten_dict``.

    Args:
        parent_key (str): Prefix passed to ``flatten_dict`` for the flattened keys. Defaults to an empty string.
        sep (str): Separator passed to ``flatten_dict`` for joining nested keys. Defaults to "_".
    """

    parent_key: str = ""
    sep: str = "_"

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        # Both settings are forwarded as keyword arguments, untouched.
        flattened = flatten_dict(
            instance,
            parent_key=self.parent_key,
            sep=self.sep,
        )
        return flattened
class AddFields(StreamInstanceOperator):
    """
    Writes a fixed set of fields into every instance of a stream.

    The instance is updated in place and the same object is returned, so an
    existing key with the same name is overwritten.

    Args:
        fields (Dict[str, object]): The field names and values to write into each instance.
    """

    fields: Dict[str, object]

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        extra_fields = self.fields
        instance.update(extra_fields)
        return instance
class MapNestedDictValuesByQueries(StreamInstanceOperator):
    """
    Sets fields on each instance to the results of running queries against that instance.

    Every query is evaluated with ``dict_query`` before any result is written
    back, so no query sees a value written by another query of the same call.

    Args:
        field_to_query (Dict[str, str]): Maps the name of the field to set to the query passed to ``dict_query``.
    """

    field_to_query: Dict[str, str]

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        resolved = {
            target: dict_query(instance, query)
            for target, query in self.field_to_query.items()
        }
        # The instance is modified in place and returned.
        instance.update(resolved)
        return instance
class ArtifactFetcherMixin:
    """
    Mixin that fetches artifacts by identifier and remembers them.

    Args:
        cache (Dict[str, Artifact]): Class-level dictionary holding the artifacts fetched so far, keyed by identifier.
    """

    cache: Dict[str, Artifact] = {}

    @classmethod
    def get_artifact(cls, artifact_identifier: str) -> Artifact:
        """Return the artifact for the identifier, calling ``fetch_artifact`` only on a cache miss."""
        store = cls.cache
        if artifact_identifier in store:
            return store[artifact_identifier]
        # fetch_artifact returns a pair; only its first element is kept.
        fetched, _ = fetch_artifact(artifact_identifier)
        store[artifact_identifier] = fetched
        return store[artifact_identifier]
class ApplyValueOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
    """
    Applies value operators to each instance in a stream based on specified fields.

    The operator names are read from the instance's ``operators_field``; each
    name is resolved with ``get_artifact`` and the resulting operator is called
    as ``operator(instance, value_field)``, in order, each receiving the
    instance returned by the previous one.

    Args:
        value_field (str): The field containing the value to be operated on.
        operators_field (str): The field containing the operators to be applied.
        default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.
    """

    value_field: str
    operators_field: str
    default_operators: List[str] = None

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        """
        Run the instance through its listed (or the default) value operators.

        Raises:
            AssertionError: If the instance has no operators and no defaults were configured.
        """
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            # The message names operators_field, the attribute declared on this
            # class, so a failing assertion reports itself rather than raising
            # AttributeError while formatting.
            assert (
                self.default_operators is not None
            ), f"No operators found in {self.operators_field} field and no default operators provided"
            operator_names = self.default_operators

        # A single operator may be given as a bare string.
        if isinstance(operator_names, str):
            operator_names = [operator_names]

        for name in operator_names:
            operator = self.get_artifact(name)
            instance = operator(instance, self.value_field)

        return instance
class FilterByValues(SingleStreamOperator):
    """
    Passes through only those instances whose fields equal the required values.

    Args:
        values (Dict[str, Any]): Field names mapped to the value an instance must hold in that field to be yielded.
    """

    values: Dict[str, Any]

    def process(self, stream: Stream, stream_name: str = None) -> Generator:
        for candidate in stream:
            for name, expected in self.values.items():
                # Stop checking at the first field that does not match.
                if not candidate[name] == expected:
                    break
            else:
                # Reached only when no field caused a break, i.e. all matched.
                yield candidate
class Unique(SingleStreamReducer):
    """
    Reduces a stream to unique instances based on specified fields.

    The result is the list of distinct value tuples, in the order in which each
    tuple was first encountered in the stream.

    Args:
        fields (List[str]): The fields that should be unique in each instance.
    """

    fields: List[str] = field(default_factory=list)

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        """Return the instance's values for ``fields`` as a tuple, converting list values to tuples."""
        result = []
        for field_name in fields:
            value = instance[field_name]
            # Lists are unhashable; convert so the tuple can serve as a dict/set key.
            # Only the top level is converted.
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    def process(self, stream: Stream) -> List[tuple]:
        """Return the distinct value tuples of the stream in first-seen order."""
        # A dict keeps insertion order, so the output is deterministic across
        # runs; a set would order string tuples by their randomized hashes.
        seen = dict.fromkeys(self.to_tuple(instance, self.fields) for instance in stream)
        return list(seen)
class SplitByValue(MultiStreamOperator):
    """
    Splits each stream of a MultiStream into one stream per unique combination of field values.

    Each output stream is named ``<stream name>_<nested_tuple_to_string(values)>``.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
    """

    fields: List[str] = field(default_factory=list)

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Unique value combinations, computed for all streams and indexed by stream name.
        uniques = Unique(fields=self.fields)(multi_stream)

        split_streams = {}
        for stream_name, stream in multi_stream.items():
            for unique_values in uniques[stream_name]:
                # Pair every field with its value in this combination.
                filtering_values = dict(zip(self.fields, unique_values))
                matching = FilterByValues(values=filtering_values)._process_single_stream(stream)
                suffix = nested_tuple_to_string(unique_values)
                split_streams["_".join((stream_name, suffix))] = matching

        return MultiStream(split_streams)
class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
    """
    Runs a stream through the stream operators named in its first instance.

    Only the instance returned by ``stream.peak()`` is consulted for the list
    of operator names.

    Args:
        field (str): The field containing the operators to be applied.
        reversed (bool): Whether to apply the operators in reverse order.
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: str = None) -> Generator:
        head = stream.peak()
        names = head.get(self.field, [])

        # A single operator may be given as a bare string.
        if isinstance(names, str):
            names = [names]

        # `reversed` here is the builtin; the class attribute is reached via self.
        names = list(reversed(names)) if self.reversed else names

        for name in names:
            op = self.get_artifact(name)
            assert isinstance(op, SingleStreamOperator), f"Operator {name} must be a SingleStreamOperator"
            # Each operator consumes the stream produced by the one before it.
            stream = op.process(stream)

        yield from stream
class AddFieldNamePrefix(StreamInstanceOperator):
    """
    Renames every field of an instance by prepending the prefix configured for its stream.

    Args:
        prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
    """

    prefix_dict: Dict[str, str]

    def prepare(self):
        # No extra setup; defers to the parent implementation.
        return super().prepare()

    def process(self, instance: Dict[str, Any], stream_name: str = None) -> Dict[str, Any]:
        renamed = {}
        for name, value in instance.items():
            # The prefix is looked up per field, so an empty instance never touches prefix_dict.
            renamed[self.prefix_dict[stream_name] + name] = value
        return renamed
class MergeStreams(MultiStreamOperator):
    """
    Combines all streams of a MultiStream into one stream.

    Streams are drained one after another, in the order ``multi_stream.items()`` yields them.

    Args:
        new_stream_name (str): The name of the new stream resulting from the merge.
        add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
        origin_stream_name_field_name (str): The field name for the origin stream name.
    """

    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream):
        """Yield every instance of every stream, tagging it with its origin when enabled."""
        for origin, source in multi_stream.items():
            for item in source:
                if self.add_origin_stream_name:
                    # The instance is tagged in place before being yielded.
                    item[self.origin_stream_name_field_name] = origin
                yield item

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # The merge generator is wrapped in a Stream and receives multi_stream via gen_kwargs.
        merged = Stream(self.merge, gen_kwargs={"multi_stream": multi_stream})
        return MultiStream({self.new_stream_name: merged})
class Shuffle(PagedStreamOperator):
    """
    Randomly reorders the instances within each page of a stream.

    Args:
        page_size (int): The size of each page in the stream. Not declared in this class;
            presumably inherited from PagedStreamOperator, with a documented default of 1000 (confirm against the base class).
    """

    def process(self, page: List[Dict], stream_name: str = None) -> Generator:
        # The page list is shuffled in place with the module-level random generator.
        random.shuffle(page)
        yield from page