Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Graham <[email protected]>
  • Loading branch information
marksgraham committed Oct 27, 2023
1 parent 8a51d79 commit 7e6fa86
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,32 +79,29 @@ class LayerFactory(ComponentStore):
callables. These functions are referred to by name and can be added at any time.
"""

def add_factory_callable(self, func: Callable, name: str | None = None, desc: str | None = None) -> None:
name = name if name is not None else getattr(func, "__name__", "???")
def __init__(self, name: str, description: str) -> None:
super().__init__(name, description)
self.__doc__ = (
f"Layer Factory '{name}': {description}\n".strip()
+ "\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
+ "\n\nThe supported members are:"
)

def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None:
"""
Add the factory function to this object under the given name, with optional description.
"""
description: str = desc or func.__doc__ or ""
self.add(name.upper(), description, func)
self.__doc__ = (
"The supported member"
+ ("s are: " if len(self.names) > 1 else " is: ")
+ ", ".join(f"``{name}``" for name in self.names)
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
)
# append name to the docstring
assert self.__doc__ is not None
self.__doc__ += f"{', ' if len(self.names)>1 else ' '}``{name}``"

def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None:
"""
Adds a factory function which returns the supplied class under the given name, with optional description.
"""
description: str = desc or cls.__doc__ or ""
self.add(name.upper(), description, lambda x=None: cls)
self.__doc__ = (
"The supported member"
+ ("s are: " if len(self.names) > 1 else " is: ")
+ ", ".join(f"``{name}``" for name in self.names)
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
)
self.add_factory_callable(name, lambda x=None: cls, desc)

def factory_function(self, name: str) -> Callable:
"""
Expand Down

0 comments on commit 7e6fa86

Please sign in to comment.