File size: 1,388 Bytes
7aa5a5e
3c36ff5
 
 
 
 
6502654
3c36ff5
 
a4795aa
3c36ff5
 
 
 
a4795aa
3c36ff5
a4795aa
 
3c36ff5
 
 
 
 
 
 
 
 
 
 
 
 
7aa5a5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from functools import lru_cache
from typing import Any, Dict, List, Union

from datasets import DatasetDict

from .artifact import fetch_artifact
from .dataset_utils import get_dataset_artifact
from .logging_utils import get_logger
from .metric_utils import _compute
from .operator import SourceOperator

# Module-level logger obtained from the project's logging utilities.
logger = get_logger()


def load(source: Union[SourceOperator, str]) -> DatasetDict:
    """Run a source operator and return its output converted via ``to_dataset()``.

    Args:
        source: Either a ``SourceOperator`` instance, or a string that is
            passed to ``fetch_artifact`` to obtain the operator.

    Returns:
        The result of calling the operator and then ``to_dataset()`` on
        what it returns.

    Raises:
        AssertionError: If ``source`` is neither a ``SourceOperator`` nor a
            string (only when assertions are enabled).
    """
    assert isinstance(source, (SourceOperator, str)), (
        "source must be a SourceOperator or a string"
    )

    if not isinstance(source, str):
        operator = source
    else:
        # fetch_artifact returns a pair; only the first element is used here.
        operator, _unused = fetch_artifact(source)

    stream = operator()
    return stream.to_dataset()


def load_dataset(dataset_query: str) -> DatasetDict:
    """Resolve a dataset query to an artifact, run it, and convert to a dataset.

    Every occurrence of ``"sys_prompt"`` in the query is rewritten to
    ``"instruction"`` before the lookup.
    NOTE(review): presumably this keeps an older query keyword working;
    confirm against callers.

    Args:
        dataset_query: Query string handed to ``get_dataset_artifact`` after
            the rewrite described above.

    Returns:
        The result of calling the resolved artifact and then
        ``to_dataset()`` on what it returns.
    """
    normalized_query = dataset_query.replace("sys_prompt", "instruction")
    recipe = get_dataset_artifact(normalized_query)
    stream = recipe()
    return stream.to_dataset()


def evaluate(predictions, data) -> List[Dict[str, Any]]:
    """Score ``predictions`` against ``data`` by delegating to ``_compute``.

    Args:
        predictions: Forwarded unchanged as the ``predictions`` argument.
        data: Forwarded unchanged as the ``references`` argument.

    Returns:
        Whatever ``_compute`` returns for the given inputs.
    """
    scored = _compute(references=data, predictions=predictions)
    return scored


# maxsize=128 is the lru_cache default, written out explicitly.
@lru_cache(maxsize=128)
def _get_produce_with_cache(recipe_query):
    """Return the ``produce`` attribute of the artifact for ``recipe_query``.

    Results are memoized per ``recipe_query`` so that repeated calls with the
    same query do not call ``get_dataset_artifact`` again.
    """
    artifact = get_dataset_artifact(recipe_query)
    return artifact.produce


def produce(instance_or_instances, recipe_query):
    """Run the cached ``produce`` callable for ``recipe_query`` on the input.

    Args:
        instance_or_instances: Either a ``list`` of instances or a single
            instance. Anything that is not a ``list`` is treated as a single
            instance.
        recipe_query: Query used to look up the cached ``produce`` callable.

    Returns:
        For a ``list`` input, the callable's result as is. For a single
        instance, element ``0`` of the result obtained from a one-item list.
    """
    if isinstance(instance_or_instances, list):
        return _get_produce_with_cache(recipe_query)(instance_or_instances)

    # Single instance: wrap it for the call, then unwrap the lone result.
    batch = [instance_or_instances]
    produced = _get_produce_with_cache(recipe_query)(batch)
    return produced[0]