Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generic typing to registry #69

Draft
wants to merge 4 commits into
base: v3
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions catalogue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union
from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic, Type
from types import ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType
from typing import List
import inspect
import importlib.metadata
Expand All @@ -11,9 +12,9 @@


InFunc = TypeVar("InFunc")
S = TypeVar('S')


def create(*namespace: str, entry_points: bool = False) -> "Registry":
def create(*namespace: str, entry_points: bool = False, generic_type: Optional[Type[S]] = None) -> "Registry[S]":
"""Create a new registry.

*namespace (str): The namespace, e.g. "spacy" or "spacy", "architectures".
Expand All @@ -22,10 +23,14 @@ def create(*namespace: str, entry_points: bool = False) -> "Registry":
"""
if check_exists(*namespace):
raise RegistryError(f"Namespace already exists: {namespace}")
return Registry(namespace, entry_points=entry_points)

if generic_type is None:
return Registry[Any](namespace, entry_points=entry_points)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't exactly right, but I couldn't figure out how to get mypy to accept a direct default for the argument.

else:
return Registry[S](namespace, entry_points=entry_points)


class Registry(object):
class Registry(Generic[InFunc]):
def __init__(self, namespace: Sequence[str], entry_points: bool = False) -> None:
"""Initialize a new registry.

Expand All @@ -43,27 +48,27 @@ def __contains__(self, name: str) -> bool:
RETURNS (bool): Whether the name is in the registry.
"""
namespace = tuple(list(self.namespace) + [name])
has_entry_point = self.entry_points and self.get_entry_point(name)
has_entry_point = self.entry_points and self.get_entry_point(name) is not None
return has_entry_point or namespace in REGISTRY

def __call__(
self, name: str, func: Optional[Any] = None
) -> Callable[[InFunc], InFunc]:
self, name: str, func: Optional[InFunc] = None
) -> Union[Callable[[InFunc], InFunc], InFunc]:
"""Register a function for a given namespace. Same as Registry.register.

name (str): The name to register under the namespace.
func (Any): Optional function to register (if not used as decorator).
func (InFunc): Optional function to register (if not used as decorator).
RETURNS (Callable): The decorator.
"""
return self.register(name, func=func)

def register(
self, name: str, *, func: Optional[Any] = None
) -> Callable[[InFunc], InFunc]:
self, name: str, *, func: Optional[InFunc] = None
) -> Union[Callable[[InFunc], InFunc], InFunc]:
"""Register a function for a given namespace.

name (str): The name to register under the namespace.
func (Any): Optional function to register (if not used as decorator).
func (InFunc): Optional function to register (if not used as decorator).
RETURNS (Callable): The decorator.
"""

Expand All @@ -75,11 +80,11 @@ def do_registration(func):
return do_registration(func)
return do_registration

def get(self, name: str) -> Any:
def get(self, name: str) -> InFunc:
"""Get the registered function for a given name.

name (str): The name.
RETURNS (Any): The registered function.
RETURNS (InFunc): The registered function.
"""
if self.entry_points:
from_entry_point = self.get_entry_point(name)
Expand All @@ -94,10 +99,10 @@ def get(self, name: str) -> Any:
)
return _get(namespace)

def get_all(self) -> Dict[str, Any]:
def get_all(self) -> Dict[str, InFunc]:
"""Get all functions belonging to this registry's namespace.

RETURNS (Dict[str, Any]): The functions, keyed by name.
RETURNS (Dict[str, InFunc]): The functions, keyed by name.
"""
global REGISTRY
result = {}
Expand All @@ -109,33 +114,34 @@ def get_all(self) -> Dict[str, Any]:
result[keys[-1]] = value
return result

def get_entry_points(self) -> Dict[str, Any]:
def get_entry_points(self) -> Dict[str, InFunc]:
"""Get registered entry points from other packages for this namespace.

RETURNS (Dict[str, Any]): Entry points, keyed by name.
RETURNS (Dict[str, InFunc]): Entry points, keyed by name.
"""
result = {}
for entry_point in self._get_entry_points():
result[entry_point.name] = entry_point.load()
return result

def get_entry_point(self, name: str, default: Optional[Any] = None) -> Any:
def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> Optional[InFunc]:
"""Check if registered entry point is available for a given name in the
namespace and load it. Otherwise, return the default value.

name (str): Name of entry point to load.
default (Any): The default value to return.
default (InFunc): The default value to return.
RETURNS (Any): The loaded entry point or the default value.
"""
for entry_point in self._get_entry_points():
if entry_point.name == name:
return entry_point.load()
return default

def _get_entry_points(self) -> List[importlib.metadata.EntryPoint]:
def _get_entry_points(self) -> Union[List[importlib.metadata.EntryPoint], importlib.metadata.EntryPoints]:
if hasattr(AVAILABLE_ENTRY_POINTS, "select"):
return AVAILABLE_ENTRY_POINTS.select(group=self.entry_point_namespace)
else: # dict
assert isinstance(AVAILABLE_ENTRY_POINTS, dict)
return AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, [])

def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]:
Expand All @@ -152,6 +158,9 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]:
line_no: Optional[int] = None
file_name: Optional[str] = None
try:
if not isinstance(func, (ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType, type)):
raise TypeError(f"func type {type(func)} is not a valid type for inspect.getsourcelines()")

_, line_no = inspect.getsourcelines(func)
file_name = inspect.getfile(func)
except (TypeError, ValueError):
Expand All @@ -164,7 +173,6 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]:
"docstring": inspect.cleandoc(docstring) if docstring else None,
}


def check_exists(*namespace: str) -> bool:
"""Check if a namespace exists.

Expand Down
28 changes: 28 additions & 0 deletions catalogue/tests/test_catalogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,31 @@ def a():
assert info["file"] == str(Path(__file__))
assert info["docstring"] == "This is a registered function."
assert info["line_no"]

def test_registry_find_module():
    """Register the stdlib ``json`` module and check what ``find`` reports.

    The lookup is expected to give the module name, the module's file, its
    docstring with surrounding newlines removed, and a line number of 0.
    """
    import json

    registry = catalogue.create("test_registry_find_module")
    registry.register("json", func=json)

    found = registry.find("json")

    expected_doc = json.__doc__.strip('\n')
    assert found["module"] == "json"
    assert found["file"] == json.__file__
    assert found["docstring"] == expected_doc
    assert found["line_no"] == 0

def test_registry_find_class():
    """Register a locally defined class and check what ``find`` reports.

    The lookup is expected to give this test module's dotted name, this
    file's path, the class docstring, and a truthy line number.
    """
    registry = catalogue.create("test_registry_find_class")

    class TestClass:
        """This is a registered class."""

    registry.register("test_class", func=TestClass)
    found = registry.find("test_class")

    this_file = str(Path(__file__))
    assert found["module"] == "catalogue.tests.test_catalogue"
    assert found["file"] == this_file
    assert found["docstring"] == TestClass.__doc__
    assert found["line_no"]
Loading