File size: 403 Bytes
8fb5471
77a9fb3
4740821
77a9fb3
8fb5471
 
 
 
 
9b0a562
 
1
2
3
4
5
6
7
8
9
10
11
12
from typing import Union
from datasets import DatasetDict
from .artifact import fetch_artifact
from .operator import StreamSource


class _InvalidSourceError(TypeError, AssertionError):
    """Raised when ``load_dataset`` receives an unsupported ``source``.

    Subclasses ``TypeError`` (the idiomatic error for a wrong argument type)
    and ``AssertionError`` (what the former ``assert`` check raised), so
    callers that already catch ``AssertionError`` keep working.
    """


def load_dataset(source: Union[StreamSource, str]) -> DatasetDict:
    """Build a ``DatasetDict`` from a ``StreamSource`` or an artifact name.

    Args:
        source: Either a ``StreamSource`` instance, or a string that is
            passed to ``fetch_artifact`` to obtain the source object.

    Returns:
        The result of calling the source and converting its output with
        ``to_dataset()``.

    Raises:
        TypeError: If ``source`` is neither a ``StreamSource`` nor a ``str``.
            The raised exception is also an ``AssertionError`` for backward
            compatibility.
    """
    # An explicit raise, unlike ``assert``, is not stripped under ``python -O``.
    if not isinstance(source, (StreamSource, str)):
        raise _InvalidSourceError("source must be a StreamSource or a string")
    if isinstance(source, str):
        # fetch_artifact returns a pair; only the first element (the artifact)
        # is used. NOTE(review): presumably this is a StreamSource — the result
        # is not re-validated here, so confirm against fetch_artifact.
        source, _ = fetch_artifact(source)
    return source().to_dataset()