From ffe1f8dca4bce6aedc6c8f6d507d1ccf6fa82006 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 15 Feb 2024 23:13:33 -0800 Subject: [PATCH 1/4] add generic --- catalogue/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/catalogue/__init__.py b/catalogue/__init__.py index 8c13003..6f84612 100644 --- a/catalogue/__init__.py +++ b/catalogue/__init__.py @@ -1,4 +1,4 @@ -from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union +from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic from typing import List import inspect import importlib.metadata @@ -25,7 +25,7 @@ def create(*namespace: str, entry_points: bool = False) -> "Registry": return Registry(namespace, entry_points=entry_points) -class Registry(object): +class Registry(Generic[InFunc]): def __init__(self, namespace: Sequence[str], entry_points: bool = False) -> None: """Initialize a new registry. @@ -47,7 +47,7 @@ def __contains__(self, name: str) -> bool: return has_entry_point or namespace in REGISTRY def __call__( - self, name: str, func: Optional[Any] = None + self, name: str, func: Optional[InFunc] = None ) -> Callable[[InFunc], InFunc]: """Register a function for a given namespace. Same as Registry.register. @@ -58,7 +58,7 @@ def __call__( return self.register(name, func=func) def register( - self, name: str, *, func: Optional[Any] = None + self, name: str, *, func: Optional[InFunc] = None ) -> Callable[[InFunc], InFunc]: """Register a function for a given namespace. @@ -75,7 +75,7 @@ def do_registration(func): return do_registration(func) return do_registration - def get(self, name: str) -> Any: + def get(self, name: str) -> InFunc: """Get the registered function for a given name. name (str): The name. @@ -94,7 +94,7 @@ def get(self, name: str) -> Any: ) return _get(namespace) - def get_all(self) -> Dict[str, Any]: + def get_all(self) -> Dict[str, InFunc]: """Get all functions belonging to this registry's namespace. RETURNS (Dict[str, Any]): The functions, keyed by name. @@ -109,7 +109,7 @@ def get_all(self) -> Dict[str, Any]: result[keys[-1]] = value return result - def get_entry_points(self) -> Dict[str, Any]: + def get_entry_points(self) -> Dict[str, InFunc]: """Get registered entry points from other packages for this namespace. RETURNS (Dict[str, Any]): Entry points, keyed by name. @@ -119,7 +119,7 @@ def get_entry_points(self) -> Dict[str, Any]: result[entry_point.name] = entry_point.load() return result - def get_entry_point(self, name: str, default: Optional[Any] = None) -> Any: + def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> InFunc: """Check if registered entry point is available for a given name in the namespace and load it. Otherwise, return the default value. From fde93ba9d1e4815459fb0e580bbfc63664c95ed4 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 15 Feb 2024 23:36:36 -0800 Subject: [PATCH 2/4] fix --- catalogue/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/catalogue/__init__.py b/catalogue/__init__.py index 6f84612..f075acf 100644 --- a/catalogue/__init__.py +++ b/catalogue/__init__.py @@ -43,7 +43,7 @@ def __contains__(self, name: str) -> bool: RETURNS (bool): Whether the name is in the registry. """ namespace = tuple(list(self.namespace) + [name]) - has_entry_point = self.entry_points and self.get_entry_point(name) + has_entry_point = self.entry_points and self.get_entry_point(name) is not None return has_entry_point or namespace in REGISTRY def __call__( @@ -52,7 +52,7 @@ def __call__( """Register a function for a given namespace. Same as Registry.register. name (str): The name to register under the namespace. - func (Any): Optional function to register (if not used as decorator). + func (InFunc): Optional function to register (if not used as decorator). RETURNS (Callable): The decorator. """ return self.register(name, func=func) @@ -63,7 +63,7 @@ def register( """Register a function for a given namespace. name (str): The name to register under the namespace. - func (Any): Optional function to register (if not used as decorator). + func (InFunc): Optional function to register (if not used as decorator). RETURNS (Callable): The decorator. """ @@ -79,7 +79,7 @@ def get(self, name: str) -> InFunc: """Get the registered function for a given name. name (str): The name. - RETURNS (Any): The registered function. + RETURNS (InFunc): The registered function. """ if self.entry_points: from_entry_point = self.get_entry_point(name) @@ -97,7 +97,7 @@ def get(self, name: str) -> InFunc: def get_all(self) -> Dict[str, InFunc]: """Get all functions belonging to this registry's namespace. - RETURNS (Dict[str, Any]): The functions, keyed by name. + RETURNS (Dict[str, InFunc]): The functions, keyed by name. """ global REGISTRY result = {} @@ -112,19 +112,19 @@ def get_all(self) -> Dict[str, InFunc]: def get_entry_points(self) -> Dict[str, InFunc]: """Get registered entry points from other packages for this namespace. - RETURNS (Dict[str, Any]): Entry points, keyed by name. + RETURNS (Dict[str, InFunc]): Entry points, keyed by name. """ result = {} for entry_point in self._get_entry_points(): result[entry_point.name] = entry_point.load() return result - def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> InFunc: + def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> Optional[InFunc]: """Check if registered entry point is available for a given name in the namespace and load it. Otherwise, return the default value. name (str): Name of entry point to load. - default (Any): The default value to return. + default (InFunc): The default value to return. RETURNS (Any): The loaded entry point or the default value. """ for entry_point in self._get_entry_points(): From a7395c5d9c544df136427bf473cd784b5e6a0807 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 16 Feb 2024 00:11:26 -0800 Subject: [PATCH 3/4] get mypy to pass --- catalogue/__init__.py | 4 ++++ catalogue/tests/test_catalogue.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/catalogue/__init__.py b/catalogue/__init__.py index f075acf..d23d33f 100644 --- a/catalogue/__init__.py +++ b/catalogue/__init__.py @@ -1,4 +1,5 @@ from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic +from types import ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType from typing import List import inspect import importlib.metadata @@ -152,6 +153,9 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: line_no: Optional[int] = None file_name: Optional[str] = None try: + if not isinstance(func, (ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType, type)): + raise TypeError(f"func type {type(func)} is not a valid type for inspect.getsourcelines()") + _, line_no = inspect.getsourcelines(func) file_name = inspect.getfile(func) except (TypeError, ValueError): diff --git a/catalogue/tests/test_catalogue.py b/catalogue/tests/test_catalogue.py index 5e910ac..56ee469 100644 --- a/catalogue/tests/test_catalogue.py +++ b/catalogue/tests/test_catalogue.py @@ -156,3 +156,31 @@ def a(): assert info["file"] == str(Path(__file__)) assert info["docstring"] == "This is a registered function." assert info["line_no"] + +def test_registry_find_module(): + import json + + test_registry = catalogue.create("test_registry_find_module") + + test_registry.register("json", func=json) + + info = test_registry.find("json") + assert info["module"] == "json" + assert info["file"] == json.__file__ + assert info["docstring"] == json.__doc__.strip('\n') + assert info["line_no"] == 0 + +def test_registry_find_class(): + test_registry = catalogue.create("test_registry_find_class") + + class TestClass: + """This is a registered class.""" + pass + + test_registry.register("test_class", func=TestClass) + + info = test_registry.find("test_class") + assert info["module"] == "catalogue.tests.test_catalogue" + assert info["file"] == str(Path(__file__)) + assert info["docstring"] == TestClass.__doc__ + assert info["line_no"] \ No newline at end of file From b37aede76d200507017bdf1d4c8f7df5ebc978f7 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sun, 18 Feb 2024 01:11:01 -0800 Subject: [PATCH 4/4] finish --- catalogue/__init__.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/catalogue/__init__.py b/catalogue/__init__.py index d23d33f..702e797 100644 --- a/catalogue/__init__.py +++ b/catalogue/__init__.py @@ -1,4 +1,4 @@ -from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic +from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic, Type from types import ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType from typing import List import inspect @@ -12,9 +12,9 @@ InFunc = TypeVar("InFunc") +S = TypeVar('S') - -def create(*namespace: str, entry_points: bool = False) -> "Registry": +def create(*namespace: str, entry_points: bool = False, generic_type: Optional[Type[S]] = None) -> "Registry[S]": """Create a new registry. *namespace (str): The namespace, e.g. "spacy" or "spacy", "architectures". @@ -23,7 +23,11 @@ def create(*namespace: str, entry_points: bool = False) -> "Registry": """ if check_exists(*namespace): raise RegistryError(f"Namespace already exists: {namespace}") - return Registry(namespace, entry_points=entry_points) + + if generic_type is None: + return Registry[Any](namespace, entry_points=entry_points) + else: + return Registry[S](namespace, entry_points=entry_points) class Registry(Generic[InFunc]): @@ -49,7 +53,7 @@ def __contains__(self, name: str) -> bool: def __call__( self, name: str, func: Optional[InFunc] = None - ) -> Callable[[InFunc], InFunc]: + ) -> Union[Callable[[InFunc], InFunc], InFunc]: """Register a function for a given namespace. Same as Registry.register. name (str): The name to register under the namespace. @@ -60,7 +64,7 @@ def __call__( def register( self, name: str, *, func: Optional[InFunc] = None - ) -> Callable[[InFunc], InFunc]: + ) -> Union[Callable[[InFunc], InFunc], InFunc]: """Register a function for a given namespace. name (str): The name to register under the namespace. @@ -133,10 +137,11 @@ def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> Option return entry_point.load() return default - def _get_entry_points(self) -> List[importlib.metadata.EntryPoint]: + def _get_entry_points(self) -> Union[List[importlib.metadata.EntryPoint], importlib.metadata.EntryPoints]: if hasattr(AVAILABLE_ENTRY_POINTS, "select"): return AVAILABLE_ENTRY_POINTS.select(group=self.entry_point_namespace) else: # dict + assert isinstance(AVAILABLE_ENTRY_POINTS, dict) return AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []) def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: @@ -168,7 +173,6 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: "docstring": inspect.cleandoc(docstring) if docstring else None, } - def check_exists(*namespace: str) -> bool: """Check if a namespace exists.