|
import copy |
|
import functools |
|
import importlib.util |
|
import os |
|
from contextlib import contextmanager |
|
from pathlib import Path |
|
from types import ModuleType |
|
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union |
|
import catalogue |
|
T = TypeVar('T') |
|
TypeBoundT = TypeVar('TypeBoundT', bound=type) |
|
CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any]) |
|
|
|
class TypedRegistry(catalogue.Registry, Generic[T]): |
|
"""A thin wrapper around catalogue.Registry to add static typing and. |
|
|
|
descriptions. |
|
""" |
|
|
|
def __init__(self, namespace: Sequence[str], entry_points: bool=False, description: str='') -> None: |
|
super().__init__(namespace, entry_points=entry_points) |
|
self.description = description |
|
|
|
def __call__(self, name: str, func: Optional[T]=None) -> Callable[[T], T]: |
|
return super().__call__(name, func) |
|
|
|
def register(self, name: str, *, func: Optional[T]=None) -> T: |
|
return super().register(name, func=func) |
|
|
|
def register_class(self, name: str, *, func: Optional[TypeBoundT]=None) -> TypeBoundT: |
|
return super().register(name, func=func) |
|
|
|
def get(self, name: str) -> T: |
|
return super().get(name) |
|
|
|
def get_all(self) -> dict[str, T]: |
|
return super().get_all() |
|
|
|
def get_entry_point(self, name: str, default: Optional[T]=None) -> T: |
|
return super().get_entry_point(name, default=default) |
|
|
|
def get_entry_points(self) -> dict[str, T]: |
|
return super().get_entry_points() |
|
S = TypeVar('S') |
|
|
|
def create_registry(*namespace: str, generic_type: type[S], entry_points: bool=False, description: str='') -> 'TypedRegistry[S]': |
|
"""Create a new registry. |
|
|
|
Args: |
|
namespace (str): The namespace, e.g. "llmfoundry.loggers" |
|
generic_type (Type[S]): The type of the registry. |
|
entry_points (bool): Accept registered functions from entry points. |
|
description (str): A description of the registry. |
|
|
|
Returns: |
|
The TypedRegistry object. |
|
""" |
|
if catalogue.check_exists(*namespace): |
|
raise catalogue.RegistryError(f'Namespace already exists: {namespace}') |
|
return TypedRegistry[generic_type](namespace, entry_points=entry_points, description=description) |
|
|
|
def construct_from_registry(name: str, registry: TypedRegistry, partial_function: bool=True, pre_validation_function: Optional[Union[Callable[[Any], None], type]]=None, post_validation_function: Optional[Callable[[Any], None]]=None, kwargs: Optional[dict[str, Any]]=None) -> Any: |
|
"""Helper function to build an item from the registry. |
|
|
|
Args: |
|
name (str): The name of the registered item |
|
registry (catalogue.Registry): The registry to fetch the item from |
|
partial_function (bool, optional): Whether to return a partial function for registered callables. Defaults to True. |
|
pre_validation_function (Optional[Union[Callable[[Any], None], type]], optional): An optional validation function called |
|
before constructing the item to return. This should throw an exception if validation fails. Defaults to None. |
|
post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after |
|
constructing the item to return. This should throw an exception if validation fails. Defaults to None. |
|
kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments. |
|
|
|
Raises: |
|
ValueError: If the validation functions failed or the registered item is invalid |
|
|
|
Returns: |
|
Any: The constructed item from the registry |
|
""" |
|
if kwargs is None: |
|
kwargs = {} |
|
registered_constructor = registry.get(name) |
|
print('construction') |
|
print(registered_constructor) |
|
if pre_validation_function is not None: |
|
if isinstance(pre_validation_function, type): |
|
if not issubclass(registered_constructor, pre_validation_function): |
|
raise ValueError(f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}') |
|
elif isinstance(pre_validation_function, Callable): |
|
pre_validation_function(registered_constructor) |
|
else: |
|
raise ValueError(f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}') |
|
if isinstance(registered_constructor, type) or (callable(registered_constructor) and (not partial_function)): |
|
constructed_item = registered_constructor(**kwargs) |
|
elif callable(registered_constructor): |
|
constructed_item = functools.partial(registered_constructor, **kwargs) |
|
else: |
|
raise ValueError(f'Expected {name} to be a class or function, but got {type(registered_constructor)}') |
|
if post_validation_function is not None: |
|
post_validation_function(constructed_item) |
|
return constructed_item |
|
|
|
def import_file(loc: Union[str, Path]) -> ModuleType: |
|
"""Import module from a file. |
|
|
|
Used to run arbitrary python code. |
|
|
|
Args: |
|
name (str): Name of module to load. |
|
loc (str / Path): Path to the file. |
|
|
|
Returns: |
|
ModuleType: The module object. |
|
""" |
|
if not os.path.exists(loc): |
|
raise FileNotFoundError(f'File {loc} does not exist.') |
|
spec = importlib.util.spec_from_file_location('python_code', str(loc)) |
|
assert spec is not None |
|
assert spec.loader is not None |
|
module = importlib.util.module_from_spec(spec) |
|
try: |
|
spec.loader.exec_module(module) |
|
except Exception as e: |
|
raise RuntimeError(f'Error executing {loc}') from e |
|
return module |
|
|
|
@contextmanager |
|
def save_registry(): |
|
"""Save the registry state and restore after the context manager exits.""" |
|
saved_registry_state = copy.deepcopy(catalogue.REGISTRY) |
|
yield |
|
catalogue.REGISTRY = saved_registry_state |