Skip to content

Commit

Permalink
improve type hint for decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
elpekenin committed Jan 25, 2025
1 parent 711e931 commit f249502
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 17 deletions.
15 changes: 9 additions & 6 deletions docs/api_milc.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ Locate the config file.
#### argument

```python
def argument(*args: Any, **kwargs: Any) -> Callable[..., Any]
def argument(*args: Any,
**kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]]
```

Decorator to call self.add_argument or self.<subcommand>.add_argument.
Expand Down Expand Up @@ -269,8 +270,10 @@ Execute the entrypoint function.
#### entrypoint

```python
def entrypoint(description: str,
deprecated: Optional[str] = None) -> Callable[..., Any]
def entrypoint(
description: str,
deprecated: Optional[str] = None
) -> Callable[[Callable[P, R]], Callable[P, R]]
```

Decorator that marks the entrypoint used when a subcommand is not supplied.
Expand All @@ -288,11 +291,11 @@ Decorator that marks the entrypoint used when a subcommand is not supplied.
#### add\_subcommand

```python
def add_subcommand(handler: Callable[..., Any],
def add_subcommand(handler: Callable[P, R],
description: str,
hidden: bool = False,
deprecated: Optional[str] = None,
**kwargs: Any) -> Callable[..., Any]
**kwargs: Any) -> Callable[P, R]
```

Register a subcommand.
Expand Down Expand Up @@ -320,7 +323,7 @@ Register a subcommand.
```python
def subcommand(description: str,
hidden: bool = False,
**kwargs: Any) -> Callable[..., Any]
**kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]]
```

Decorator to register a subcommand.
Expand Down
24 changes: 13 additions & 11 deletions milc/milc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from platform import platform
from tempfile import NamedTemporaryFile
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union

try:
import threading
Expand All @@ -23,14 +23,16 @@
import colorama
from halo import Halo # type: ignore
from platformdirs import user_config_dir
from typing_extensions import ParamSpec
from spinners.spinners import Spinners # type: ignore

from .ansi import MILCFormatter, ansi_colors, ansi_config, ansi_escape, format_ansi
from .attrdict import AttrDict
from .configuration import Configuration, SubparserWrapper, get_argument_name, get_argument_strings, handle_store_boolean
from ._in_argv import _in_argv, _index_argv

# FIXME: Replace Callable[..., Any] with better definitions
P = ParamSpec("P")
R = TypeVar("R")


class MILC(object):
Expand Down Expand Up @@ -59,7 +61,7 @@ def __init__(self, name: Optional[str] = None, author: Optional[str] = None, ver
self.author = author
self._config_store_true: Sequence[str] = []
self._config_store_false: Sequence[str] = []
self._entrypoint: Callable[[Any], Any] = lambda _: None
self._entrypoint: Callable[..., Any] = lambda _: None
self._spinners: Dict[str, Dict[str, Union[int, Sequence[str]]]] = {}
self._subcommand = None
self._inside_context_manager = False
Expand Down Expand Up @@ -407,13 +409,13 @@ def _handle_arg_parsing(self, config_name: str, arg_name: str, args: Sequence[An
if _in_argv(arg):
self.args_passed[config_name][arg_name] = True

def argument(self, *args: Any, **kwargs: Any) -> Callable[..., Any]:
def argument(self, *args: Any, **kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator to call self.add_argument or self.<subcommand>.add_argument.
"""
if self._inside_context_manager:
raise RuntimeError('You must run this before the with statement!')

def argument_function(handler: Callable[..., Any]) -> Callable[..., Any]:
def argument_function(handler: Callable[P, R]) -> Callable[P, R]:
config_name = handler.__name__
subcommand_name = config_name.replace("_", "-")
arg_name = get_argument_name(self._arg_parser, *args, **kwargs)
Expand Down Expand Up @@ -618,7 +620,7 @@ def __call__(self) -> Any:

raise RuntimeError('No entrypoint provided!')

def entrypoint(self, description: str, deprecated: Optional[str] = None) -> Callable[..., Any]:
def entrypoint(self, description: str, deprecated: Optional[str] = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator that marks the entrypoint used when a subcommand is not supplied.
Args:
description
Expand All @@ -634,7 +636,7 @@ def entrypoint(self, description: str, deprecated: Optional[str] = None) -> Call
self.description = description
self.release_lock()

def entrypoint_func(handler: Callable[..., Any]) -> Callable[..., Any]:
def entrypoint_func(handler: Callable[P, R]) -> Callable[P, R]:
self.acquire_lock()

if deprecated:
Expand All @@ -650,12 +652,12 @@ def entrypoint_func(handler: Callable[..., Any]) -> Callable[..., Any]:

def add_subcommand(
self,
handler: Callable[..., Any],
handler: Callable[P, R],
description: str,
hidden: bool = False,
deprecated: Optional[str] = None,
**kwargs: Any,
) -> Callable[..., Any]:
) -> Callable[P, R]:
"""Register a subcommand.
Args:
Expand Down Expand Up @@ -699,7 +701,7 @@ def add_subcommand(

return handler

def subcommand(self, description: str, hidden: bool = False, **kwargs: Any) -> Callable[..., Any]:
def subcommand(self, description: str, hidden: bool = False, **kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator to register a subcommand.
Args:
Expand All @@ -710,7 +712,7 @@ def subcommand(self, description: str, hidden: bool = False, **kwargs: Any) -> C
hidden
When True don't display this command in --help
"""
def subcommand_function(handler: Callable[..., Any]) -> Callable[..., Any]:
def subcommand_function(handler: Callable[P, R]) -> Callable[P, R]:
return self.add_subcommand(handler, description, hidden=hidden, **kwargs)

return subcommand_function
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,6 @@
"platformdirs",
"spinners",
"types-colorama",
"typing_extensions",
],
)

0 comments on commit f249502

Please sign in to comment.