Skip to content

Commit

Permalink
Add patch to test
Browse files Browse the repository at this point in the history
  • Loading branch information
lantiga committed Dec 9, 2024
1 parent 68677d3 commit 48b944c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
11 changes: 2 additions & 9 deletions .actions/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,16 +483,9 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:


if __name__ == "__main__":
import sys

import jsonargparse
from jsonargparse import ArgumentParser

def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
return namespace, args
from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8

if sys.version_info >= (3, 12, 8):
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

jsonargparse.CLI(AssistantCLI, as_positional=False)
19 changes: 13 additions & 6 deletions src/lightning/pytorch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@

_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7")


def patch_jsonargparse_python_3_12_8():
if sys.version_info < (3, 12, 8):
return

def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
return namespace, args

setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)


if _JSONARGPARSE_SIGNATURES_AVAILABLE:
import docstring_parser
from jsonargparse import (
Expand All @@ -48,12 +60,7 @@
set_config_read_mode,
)

def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
return namespace, args

if sys.version_info >= (3, 12, 8):
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
set_config_read_mode(fsspec_enabled=True)
Expand Down
3 changes: 3 additions & 0 deletions tests/parity_fabric/test_parity_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,8 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float

if __name__ == "__main__":
from jsonargparse import CLI
from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8

patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641

CLI(run_parity_test)

0 comments on commit 48b944c

Please sign in to comment.