File size: 1,856 Bytes
b5ae7e6
36f3d38
1849dad
36f3d38
1849dad
 
b5ae7e6
1849dad
 
 
 
b5ae7e6
1849dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5ae7e6
 
 
 
1849dad
b5ae7e6
1849dad
 
 
 
 
b5ae7e6
 
 
 
 
 
 
1849dad
 
b5ae7e6
1849dad
b5ae7e6
1849dad
b5ae7e6
1849dad
b5ae7e6
 
1849dad
36f3d38
b5ae7e6
1849dad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import importlib
import inspect
import os

from .artifact import Artifact, Artifactories
from .catalog import LocalCatalog, GithubCatalog, PATHS_SEP
from .utils import Singleton


# Environment variable holding extra local catalog locations, separated by PATHS_SEP.
UNITXT_ARTIFACTORIES_ENV_VAR = 'UNITXT_ARTIFACTORIES'

# Module files in this package that _register_all_artifacts() must NOT import
# when scanning for Artifact subclasses to register.
non_registered_files = [
    "__init__.py",
    "artifact.py",
    "utils.py",
    "register.py",
    "metric.py",
    "dataset.py",
    "blocks.py",
]


def _register_all_catalogs():
    """Register all artifact catalogs with ``Artifactories()``.

    Registration order:
      1. the default ``LocalCatalog()``;
      2. one ``LocalCatalog(location=path)`` per non-empty entry of the
         ``UNITXT_ARTIFACTORIES`` environment variable, split on ``PATHS_SEP``;
      3. ``GithubCatalog()``.
    """
    Artifactories().register_atrifactory(LocalCatalog())

    extra_locations = os.environ.get(UNITXT_ARTIFACTORIES_ENV_VAR, "")
    for path in extra_locations.split(PATHS_SEP):
        # An empty variable, or a leading/trailing/doubled separator, yields ""
        # entries from split(); skip them so no catalog gets location="".
        if not path:
            continue
        Artifactories().register_atrifactory(LocalCatalog(location=path))

    Artifactories().register_atrifactory(GithubCatalog())

def _register_all_artifacts():
    """Import eligible sibling modules and register their ``Artifact`` subclasses.

    A file is eligible when it:
      - lives in this file's directory;
      - ends with ``.py``;
      - is not listed in ``non_registered_files``;
      - is not this file itself.

    Each eligible file is imported relative to ``__package__``. Every class
    found in that module's namespace is passed to ``Artifact.register_class``
    if it is a strict subclass of ``Artifact``. "Every class" includes
    classes the module only imports.
    """
    package_dir = os.path.dirname(__file__)
    this_file = os.path.basename(__file__)

    # Sort so the registration order is deterministic; os.listdir returns
    # entries in arbitrary, filesystem-dependent order.
    for file in sorted(os.listdir(package_dir)):
        if (
            not file.endswith(".py")
            or file in non_registered_files
            or file == this_file
        ):
            continue

        # Strip only the trailing extension; str.replace(".py", "") would also
        # remove ".py" occurring elsewhere in the file name.
        module_name = os.path.splitext(file)[0]
        module = importlib.import_module("." + module_name, __package__)

        # NOTE(review): getmembers also returns classes a module merely imports,
        # so a class re-exported by several modules is passed to register_class
        # once per module; presumably register_class tolerates repeats -- confirm.
        for _, obj in inspect.getmembers(module, inspect.isclass):
            # Register subclasses of Artifact, but not Artifact itself.
            if issubclass(obj, Artifact) and obj is not Artifact:
                Artifact.register_class(obj)


class ProjectArtifactRegisterer(metaclass=Singleton):
    """Run catalog and artifact registration at most once per instance.

    The ``_registered`` flag guards against repeating the registration if
    ``__init__`` is invoked again on the same instance.
    """

    def __init__(self):
        # Reuse an existing flag if present; otherwise start as "not registered".
        done = getattr(self, "_registered", False)
        self._registered = done
        if done:
            return

        _register_all_catalogs()
        _register_all_artifacts()
        self._registered = True


def register_all_artifacts():
    """Public entry point: trigger project-wide catalog and artifact registration.

    Instantiates ``ProjectArtifactRegisterer``, whose ``__init__`` performs the
    registration unless it has already been done. Returns ``None``.
    """
    _registerer = ProjectArtifactRegisterer()  # instance discarded; side effect only