From 48b944cdc3f82b6eb4ff98fe0b6638291df82b14 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Mon, 9 Dec 2024 13:57:27 +0100 Subject: [PATCH] Add patch to test --- .actions/assistant.py | 11 ++--------- src/lightning/pytorch/cli.py | 19 +++++++++++++------ tests/parity_fabric/test_parity_ddp.py | 3 +++ 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/.actions/assistant.py b/.actions/assistant.py index 52f152770209c..7d2dfcd61710d 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -483,16 +483,9 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None: if __name__ == "__main__": - import sys - import jsonargparse - from jsonargparse import ArgumentParser - - def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]: - namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore - return namespace, args + from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8 - if sys.version_info >= (3, 12, 8): - setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch) + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 jsonargparse.CLI(AssistantCLI, as_positional=False) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 330241948c5da..bc44a47b945e8 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -37,6 +37,18 @@ _JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7") + +def patch_jsonargparse_python_3_12_8(): + if sys.version_info < (3, 12, 8): + return + + def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]: + namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore + return namespace, args + + setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch) + + if _JSONARGPARSE_SIGNATURES_AVAILABLE: import docstring_parser from jsonargparse import ( @@ -48,12 +60,7 @@ set_config_read_mode, ) - def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]: - namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore - return namespace, args - - if sys.version_info >= (3, 12, 8): - setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch) + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483 set_config_read_mode(fsspec_enabled=True) diff --git a/tests/parity_fabric/test_parity_ddp.py b/tests/parity_fabric/test_parity_ddp.py index 217d401ad6fba..aebd9064b31fd 100644 --- a/tests/parity_fabric/test_parity_ddp.py +++ b/tests/parity_fabric/test_parity_ddp.py @@ -162,5 +162,8 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float if __name__ == "__main__": from jsonargparse import CLI + from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8 + + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 CLI(run_parity_test)