Skip to content

Commit

Permalink
Fix linting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
marcromeyn committed Aug 22, 2024
1 parent 7e9c6b6 commit 62de001
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/entrypoint/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def train_model(
epochs (int, optional): Number of training epochs. Defaults to 10.
batch_size (int, optional): Batch size for training. Defaults to 32.
"""
print(f"Training model with the following configuration:")
print("Training model with the following configuration:")
print(f"Model: {model}")
print(f"Optimizer: {optimizer}")
print(f"Epochs: {epochs}")
Expand Down
2 changes: 1 addition & 1 deletion examples/entrypoint/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def train_model(
epochs (int, optional): Number of training epochs. Defaults to 10.
batch_size (int, optional): Batch size for training. Defaults to 32.
"""
print(f"Training model with the following configuration:")
print("Training model with the following configuration:")
print(f"Model: {model}")
print(f"Optimizer: {optimizer}")
print(f"Epochs: {epochs}")
Expand Down
6 changes: 3 additions & 3 deletions src/nemo_run/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def parse(self, arg: str) -> Dict[str, Any]:
except ValueError:
raise ArgumentParsingError(f"Invalid operation: {op_str}", arg, {"key": key, "value": value})
return {key: (op, value)}
raise ArgumentParsingError(f"Invalid argument format", arg, {})
raise ArgumentParsingError("Invalid argument format", arg, {})

def parse_value(self, value: str) -> Any:
"""
Expand Down Expand Up @@ -860,7 +860,7 @@ def infer_type(self, value: str) -> Type:
if isinstance(parsed, bool):
return bool
return type(parsed)
except:
except Exception:
return str


Expand Down Expand Up @@ -1114,7 +1114,7 @@ def _args_to_kwargs(fn: Callable, args: List[str]) -> List[str]:
signature = inspect.signature(fn.__class__)
if signature is None:
for arg in args:
if not "=" in arg:
if "=" not in arg:
raise ArgumentParsingError(
"Positional argument found after keyword argument",
arg,
Expand Down
1 change: 0 additions & 1 deletion src/nemo_run/run/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import sys
import time
import traceback
import types
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional, Type, Union
Expand Down
26 changes: 13 additions & 13 deletions test/cli/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
from dataclasses import dataclass
from test.dummy_factory import DummyModel, dummy_entrypoint
from typing import Optional, Union
from unittest.mock import ANY, MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, patch

import fiddle as fdl
import pytest
from importlib_metadata import EntryPoint, EntryPoints
from rich.console import Console
from typer.testing import CliRunner

import nemo_run as run

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note test

Module 'nemo_run' is imported with both 'import' and 'import from'.
Expand Down Expand Up @@ -95,15 +94,15 @@ def func(ctx, a: int, b: str, c: float = 1.0):
def test_run_context_initialization(self):
ctx = RunContext(name="test_run")
assert ctx.name == "test_run"
assert ctx.direct == False
assert ctx.dryrun == False
assert not ctx.direct
assert not ctx.dryrun
assert ctx.factory is None
assert ctx.load is None
assert ctx.repl == False
assert ctx.sequential == True
assert ctx.detach == False
assert ctx.require_confirmation == True
assert ctx.tail_logs == False
assert not ctx.repl
assert ctx.sequential
assert not ctx.detach
assert ctx.require_confirmation
assert not ctx.tail_logs

def test_run_context_parse_args(self):
ctx = RunContext(name="test_run")
Expand Down Expand Up @@ -242,18 +241,20 @@ def test_run_context_with_invalid_entrypoint_type(self, sample_function):
@patch("nemo_run.cli.api.RunContext.run")
def test_run_context_run_task(self, mock_run):
ctx = RunContext(name="test_run")
sample_function = lambda a, b: None
def sample_function(a, b):
return None

ctx.run(sample_function, ["a=10", "b=hello"])

mock_run.assert_called_once_with(sample_function, ["a=10", "b=hello"])

def test_run_context_run_with_sequential(self):
ctx = RunContext(name="test_run", require_confirmation=False)
sample_function = lambda a, b: None
def sample_function(a, b):
return None

ctx.run(sample_function, ["a=10", "b=hello", "run.sequential=False"])
assert ctx.sequential == False
assert not ctx.sequential


@dataclass
Expand Down Expand Up @@ -393,7 +394,6 @@ def test_resolve_entrypoints(self):
)

def test_help(self):
from nemo_run import api

registry_details = []
for t in config.get_underlying_types(Optional[Optimizer]):
Expand Down

0 comments on commit 62de001

Please sign in to comment.