From 44d79582306521f8961003c6a2cc191abcc6f7ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Mart=C3=ADnez?= <58857054+elpekenin@users.noreply.github.com> Date: Thu, 13 Feb 2025 21:11:14 +0100 Subject: [PATCH] More typing (#76) * some more typing improvements * make comment more readable, do not leak defaults in `@overload` * move `@overload`, try and word the comment better * move things around --- docs/api_milc_interface.md | 11 +++--- docs/api_questions.md | 6 ++-- milc/milc_interface.py | 12 ++++--- milc/questions.py | 72 +++++++++++++++++++++++++++++++++++--- 4 files changed, 85 insertions(+), 16 deletions(-) diff --git a/docs/api_milc_interface.md b/docs/api_milc_interface.md index d206d96..4d5df46 100644 --- a/docs/api_milc_interface.md +++ b/docs/api_milc_interface.md @@ -117,7 +117,8 @@ Release the MILC lock. #### argument ```python -def argument(*args: Any, **kwargs: Any) -> Callable[..., Any] +def argument(*args: Any, + **kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]] ``` Decorator to add an argument to a MILC command or subcommand. @@ -147,8 +148,10 @@ Execute the entrypoint function. #### entrypoint ```python -def entrypoint(description: str, - deprecated: Optional[str] = None) -> Callable[..., Any] +def entrypoint( + description: str, + deprecated: Optional[str] = None +) -> Callable[[Callable[P, R]], Callable[P, R]] ``` Decorator that marks the entrypoint used when a subcommand is not supplied. @@ -168,7 +171,7 @@ Decorator that marks the entrypoint used when a subcommand is not supplied. ```python def subcommand(description: str, hidden: bool = False, - **kwargs: Any) -> Callable[..., Any] + **kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]] ``` Decorator to register a subcommand. diff --git a/docs/api_questions.md b/docs/api_questions.md index 7f7af2e..48ed9b4 100644 --- a/docs/api_questions.md +++ b/docs/api_questions.md @@ -66,9 +66,9 @@ def question(prompt: str, *args: Any, default: Optional[str] = None, confirm: bool = False, - answer_type: Callable[[str], str] = str, - validate: Optional[Callable[..., bool]] = None, - **kwargs: Any) -> Union[str, Any] + answer_type: Optional[Callable[[str], T]] = None, + validate: Optional[Callable[Concatenate[str, P], bool]] = None, + **kwargs: Any) -> Union[str, T, None] ``` Allow the user to type in a free-form string to answer. diff --git a/milc/milc_interface.py b/milc/milc_interface.py index 74378df..ec03e80 100644 --- a/milc/milc_interface.py +++ b/milc/milc_interface.py @@ -7,14 +7,18 @@ from logging import Logger from pathlib import Path from types import TracebackType -from typing import Any, Callable, Dict, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, Optional, Sequence, Type, TypeVar, Union from halo import Halo # type: ignore +from typing_extensions import ParamSpec from .attrdict import AttrDict from .configuration import Configuration from .milc import MILC +P = ParamSpec("P") +R = TypeVar("R") + class MILCInterface: def __init__(self) -> None: @@ -145,7 +149,7 @@ def release_lock(self) -> None: """ return self.milc.release_lock() - def argument(self, *args: Any, **kwargs: Any) -> Callable[..., Any]: + def argument(self, *args: Any, **kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator to add an argument to a MILC command or subcommand. """ return self.milc.argument(*args, **kwargs) @@ -160,7 +164,7 @@ def __call__(self) -> Any: """ return self.milc() - def entrypoint(self, description: str, deprecated: Optional[str] = None) -> Callable[..., Any]: + def entrypoint(self, description: str, deprecated: Optional[str] = None) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator that marks the entrypoint used when a subcommand is not supplied. Args: description @@ -171,7 +175,7 @@ def entrypoint(self, description: str, deprecated: Optional[str] = None) -> Call """ return self.milc.entrypoint(description, deprecated) - def subcommand(self, description: str, hidden: bool = False, **kwargs: Any) -> Callable[..., Any]: + def subcommand(self, description: str, hidden: bool = False, **kwargs: Any) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator to register a subcommand. Args: diff --git a/milc/questions.py b/milc/questions.py index 90c1c79..8bba3e4 100644 --- a/milc/questions.py +++ b/milc/questions.py @@ -1,11 +1,16 @@ """Sometimes you need to ask the user a question. MILC provides basic functions for collecting and validating user input. You can find these in the `milc.questions` module. """ from getpass import getpass -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, TypeVar, Union, overload + +from typing_extensions import Concatenate, ParamSpec from milc import cli from .ansi import format_ansi +T = TypeVar("T") +P = ParamSpec("P") + def yesno(prompt: str, *args: Any, default: Optional[bool] = None, **kwargs: Any) -> bool: """Displays `prompt` to the user and gets a yes or no response. @@ -122,7 +127,7 @@ def password( return None -def _cast_answer(answer_type: Callable[[str], str], answer: str) -> Any: +def _cast_answer(answer_type: Callable[[str], T], answer: str) -> Optional[T]: """Attempt to convert answer to answer_type. """ try: @@ -132,15 +137,52 @@ def _cast_answer(answer_type: Callable[[str], str], answer: str) -> Any: return None +@overload +def question( + prompt: str, + *args: Any, + default: Optional[str] = ..., + confirm: bool = ..., + answer_type: None = ..., + validate: Optional[Callable[Concatenate[str, P], bool]] = ..., + **kwargs: Any, +) -> Optional[str]: + ... + + +@overload +def question( + prompt: str, + *args: Any, + default: Optional[str] = ..., + confirm: bool = ..., + answer_type: Callable[[str], T] = ..., + validate: Optional[Callable[Concatenate[str, P], bool]] = ..., + **kwargs: Any, +) -> Optional[T]: + ... + + +# NOTE: can't have a default value on an argument whose type annotation is a TypeVar +# this means that `answer_type: Callable[[str], T] = str` gives a typing error +# see https://github.com/python/mypy/issues/3737 +# +# due to this, we leave the default as `None`, while the actual implementation +# lives on a private function that receives all of its arguments from the public API. +# by doing this, the default value is "resolved" on callsite instead, making mypy happy. +# +# for better expresiveness, @overload variants are defined, to let the user know: +# a) no `answer_type` provided: return str | None +# b) `answer_type` converts str into T: return T | None def question( prompt: str, *args: Any, default: Optional[str] = None, confirm: bool = False, - answer_type: Callable[[str], str] = str, - validate: Optional[Callable[..., bool]] = None, + answer_type: Optional[Callable[[str], T]] = None, + validate: Optional[Callable[Concatenate[str, P], bool]] = None, **kwargs: Any, -) -> Union[str, Any]: +) -> Union[str, T, None]: """Allow the user to type in a free-form string to answer. | Argument | Description | @@ -151,6 +193,26 @@ def question( | answer_type | Specify a type function for the answer. Will re-prompt the user if the function raises any errors. Common choices here include int, float, and decimal.Decimal. | | validate | This is an optional function that can be used to validate the answer. It should return True or False and have the following signature:

`def function_name(answer, *args, **kwargs):` | """ + return _question( + prompt, + *args, + default=default, + confirm=confirm, + answer_type=answer_type or str, + validate=validate, + **kwargs, + ) + + +def _question( + prompt: str, + *args: Any, + default: Optional[str], + confirm: bool, + answer_type: Callable[[str], T], + validate: Optional[Callable[Concatenate[str, P], bool]], + **kwargs: Any, +) -> Union[str, T, None]: if not cli.interactive: return default