import csv
import io
from abc import abstractmethod
from typing import Any, Dict, List, Union

from .dataclass import AbstractField, Field
from .operators import InstanceFieldOperator
from .type_utils import isoftype, to_type_string
from .types import Dialog, Image, Number, Table


class Serializer(InstanceFieldOperator):
    """Base operator that turns an instance field value into a string.

    Subclasses implement :meth:`serialize`; :meth:`process_instance_value`
    simply delegates to it.
    """

    def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
        return self.serialize(value, instance)

    @abstractmethod
    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        """Return the string form of ``value``.

        ``instance`` is the enclosing instance dict; some serializers read or
        write it (see ``ImageSerializer``).
        """
        pass


class DefaultSerializer(Serializer):
    """Serializer that falls back to ``str(value)`` for any value."""

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        return str(value)


class SingleTypeSerializer(InstanceFieldOperator):
    """Serializer dedicated to a single type, given by ``serialized_type``.

    :meth:`process_instance_value` rejects values that are not of
    ``serialized_type`` and otherwise delegates to ``self.serialize``, which
    concrete subclasses must define.
    """

    # Type this serializer accepts; checked with ``isoftype`` before serializing.
    serialized_type: object = AbstractField()

    def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
        if not isoftype(value, self.serialized_type):
            raise ValueError(
                f"SingleTypeSerializer for type {self.serialized_type} should get this type. got {to_type_string(value)}"
            )
        return self.serialize(value, instance)


class DefaultListSerializer(Serializer):
    """Join list items with ", "; any non-list value is passed through ``str``."""

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        if isinstance(value, list):
            return ", ".join(str(item) for item in value)
        return str(value)


class ListSerializer(SingleTypeSerializer):
    """Serialize a list as its items' ``str`` forms joined by ", "."""

    serialized_type = list

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        return ", ".join(str(item) for item in value)


class DialogSerializer(SingleTypeSerializer):
    """Serialize a dialog as one "role: content" line per turn."""

    serialized_type = Dialog

    def serialize(self, value: Dialog, instance: Dict[str, Any]) -> str:
        # Each turn is read as a mapping with "role" and "content" keys.
        return "\n".join(f"{turn['role']}: {turn['content']}" for turn in value)


class NumberSerializer(SingleTypeSerializer):
    """Serialize ints verbatim and floats with one decimal place.

    Raises:
        ValueError: if the value is neither ``int`` nor ``float``.
    """

    serialized_type = Number

    def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
        if isinstance(value, int):
            return str(value)
        if isinstance(value, float):
            return f"{value:.1f}"
        raise ValueError("Unsupported type for NumberSerializer")


class NumberQuantizingSerializer(NumberSerializer):
    """Serialize a number rounded to the nearest multiple of ``quantum``.

    With an ``int`` quantum the result is rendered as an integer (quantum=5
    turns 12 into "10"). With a ``float`` quantum it is rendered as a float
    (quantum=0.1 turns 0.37 into "0.4").

    Raises:
        ValueError: if the value is not a ``Number``.
    """

    serialized_type = Number
    # Step size to round to; its Python type selects int vs float output.
    quantum: Union[float, int] = 0.1

    def serialize(self, value: Number, instance: Dict[str, Any]) -> str:
        if not isoftype(value, Number):
            raise ValueError("Unsupported type for NumberSerializer")
        steps = round(value / self.quantum)
        if isinstance(self.quantum, int):
            # round() without ndigits returns an int, so this product is an
            # exact integer. Float division followed by int() truncation could
            # land on the multiple below after a tiny rounding error.
            return str(steps * self.quantum)
        # Dividing by the reciprocal rather than multiplying by quantum yields
        # cleaner floats: 3 / 10.0 == 0.3, whereas 3 * 0.1 == 0.30000000000000004.
        return str(steps / (1 / self.quantum))


class TableSerializer(SingleTypeSerializer):
    """Serialize a table as CSV text: the header row followed by the data rows."""

    serialized_type = Table

    def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
        output = io.StringIO()
        # "\n" terminators instead of the csv default "\r\n".
        writer = csv.writer(output, lineterminator="\n")
        writer.writerow(value["header"])
        writer.writerows(value["rows"])
        # strip() drops the trailing newline (and any outer whitespace).
        return output.getvalue().strip()


class ImageSerializer(SingleTypeSerializer):
    """Move an image into the instance's media list and emit a placeholder.

    ``value["image"]`` is appended to ``instance["media"]["images"]``, and
    the containers are created if missing. ``value["image"]`` is then
    overwritten in place with a placeholder tag that references that list
    index, and the placeholder is returned.
    """

    serialized_type = Image

    def serialize(self, value: Image, instance: Dict[str, Any]) -> str:
        images = instance.setdefault("media", {}).setdefault("images", [])
        idx = len(images)
        images.append(value["image"])
        # The placeholder must carry idx so the text can be mapped back to
        # the stored image; an empty string here would lose the reference.
        value["image"] = f'<img src="media/images/{idx}">'
        return value["image"]


class MultiTypeSerializer(Serializer):
    """Dispatch each value to the first serializer whose type matches.

    ``serializers`` are tried in order. Values matching none of them fall
    back to ``str(value)``.
    """

    serializers: List[SingleTypeSerializer] = Field(
        default_factory=lambda: [
            ImageSerializer(),
            TableSerializer(),
            DialogSerializer(),
        ]
    )

    def verify(self):
        super().verify()
        self._verify_serializers(self.serializers)

    def _verify_serializers(self, serializers):
        """Raise ValueError unless ``serializers`` is a List[SingleTypeSerializer]."""
        if not isoftype(serializers, List[SingleTypeSerializer]):
            raise ValueError(
                "MultiTypeSerializer requires the list of serializers to be List[SingleTypeSerializer]."
            )

    def add_serializers(self, serializers: List[SingleTypeSerializer]):
        """Prepend ``serializers`` so they take precedence over existing ones."""
        self._verify_serializers(serializers)
        self.serializers = serializers + self.serializers

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        for serializer in self.serializers:
            if isoftype(value, serializer.serialized_type):
                # Type already checked here, so call serialize directly.
                return serializer.serialize(value, instance)
        return str(value)