Run fmt
marcromeyn committed Aug 22, 2024
1 parent fb81aeb commit 476bcee
Showing 16 changed files with 525 additions and 302 deletions.
21 changes: 6 additions & 15 deletions examples/entrypoint/experiment.py
@@ -7,6 +7,7 @@
@dataclass
class Model:
"""Dummy model config"""

hidden_size: int
num_layers: int
activation: str
@@ -15,18 +16,14 @@ class Model:
@dataclass
class Optimizer:
"""Dummy optimizer config"""

learning_rate: float
weight_decay: float
betas: List[float]


@run.cli.entrypoint
def train_model(
model: Model,
optimizer: Optimizer,
epochs: int = 10,
batch_size: int = 32
):
def train_model(model: Model, optimizer: Optimizer, epochs: int = 10, batch_size: int = 32):
"""
Train a model using the specified configuration.
@@ -51,11 +48,7 @@ def train_model(

@run.cli.factory
@run.autoconvert
def my_model(
hidden_size: int = 256,
num_layers: int = 3,
activation: str = 'relu'
) -> Model:
def my_model(hidden_size: int = 256, num_layers: int = 3, activation: str = "relu") -> Model:
"""
Create a model configuration.
"""
@@ -65,9 +58,7 @@ def my_model(
@run.cli.factory
@run.autoconvert
def my_optimizer(
learning_rate: float = 0.001,
weight_decay: float = 1e-5,
betas: List[float] = [0.9, 0.999]
learning_rate: float = 0.001, weight_decay: float = 1e-5, betas: List[float] = [0.9, 0.999]
) -> Optimizer:
"""
Create an optimizer configuration.
Expand All @@ -87,7 +78,7 @@ def train_models_experiment(
models: List[Model] = [my_model(), my_model(hidden_size=512)],
optimizers: List[Optimizer] = [my_optimizer(), my_optimizer(learning_rate=0.01)],
epochs: int = 10,
batch_size: int = 32
batch_size: int = 32,
):
"""
Run an experiment to train multiple models with different configurations.
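A minimal sketch (not part of this commit) of how the reformatted factories above compose: calling a factory returns a config object, which is exactly how the default arguments of train_models_experiment use them in the diff. The import path is hypothetical.

```python
# Hedged illustration only; the module path `experiment` is an assumption.
from experiment import my_model, my_optimizer

small = my_model()                       # default config: hidden_size=256, num_layers=3, activation="relu"
large = my_model(hidden_size=512)        # same factory with one field overridden
fast = my_optimizer(learning_rate=0.01)  # mirrors the second optimizer default in train_models_experiment
```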
20 changes: 8 additions & 12 deletions examples/entrypoint/task.py
@@ -7,6 +7,7 @@
@dataclass
class Model:
"""Dummy model config"""

hidden_size: int
num_layers: int
activation: str
@@ -15,20 +16,15 @@ class Model:
@dataclass
class Optimizer:
"""Dummy optimizer config"""

learning_rate: float
weight_decay: float
betas: List[float]




@run.cli.factory
@run.autoconvert
def my_model(
hidden_size: int = 256,
num_layers: int = 3,
activation: str = 'relu'
) -> Model:
def my_model(hidden_size: int = 256, num_layers: int = 3, activation: str = "relu") -> Model:
"""
Create a model configuration.
"""
@@ -37,20 +33,20 @@ def my_model(

@run.cli.factory
def my_optimizer(
learning_rate: float = 0.001,
weight_decay: float = 1e-5,
betas: List[float] = [0.9, 0.999]
learning_rate: float = 0.001, weight_decay: float = 1e-5, betas: List[float] = [0.9, 0.999]
) -> run.Config[Optimizer]:
"""Create an optimizer configuration."""
return run.Config(Optimizer, learning_rate=learning_rate, weight_decay=weight_decay, betas=betas)
return run.Config(
Optimizer, learning_rate=learning_rate, weight_decay=weight_decay, betas=betas
)


@run.cli.entrypoint
def train_model(
model: Model = my_model(),
optimizer: Optimizer = my_optimizer(),
epochs: int = 10,
batch_size: int = 32
batch_size: int = 32,
):
"""
Train a model using the specified configuration.
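For orientation, a hedged sketch of how a file like task.py is typically exposed on the command line. The run.cli.main hookup and the dotted-override syntax are assumptions based on the parse_cli_args import visible elsewhere in this commit, not something this diff adds.

```python
# Hedged sketch, assuming run.cli.main is the standard way to serve an
# @run.cli.entrypoint; the module path `task` is hypothetical.
import nemo_run as run

from task import train_model

if __name__ == "__main__":
    # Assumed invocation style, e.g.:
    #   python task.py model.hidden_size=1024 optimizer.learning_rate=0.01
    run.cli.main(train_model)
```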
3 changes: 1 addition & 2 deletions src/nemo_run/__init__.py
@@ -16,8 +16,7 @@
from nemo_run import cli
from nemo_run.api import autoconvert, dryrun_fn
from nemo_run.config import Config, Partial, Script
from nemo_run.core.execution.base import (Executor, ExecutorMacros,
FaultTolerance, Torchrun)
from nemo_run.core.execution.base import Executor, ExecutorMacros, FaultTolerance, Torchrun
from nemo_run.core.execution.local import LocalExecutor
from nemo_run.core.execution.skypilot import SkypilotExecutor
from nemo_run.core.execution.slurm import SlurmExecutor
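The reflowed import above is purely cosmetic; a quick, illustrative way to convince yourself the package root still re-exports the same names (this check is not part of the commit):

```python
# Illustrative check only: every name touched by the reformatted import line
# should still be reachable from the package root.
import nemo_run as run

for name in ("Executor", "ExecutorMacros", "FaultTolerance", "Torchrun", "LocalExecutor"):
    assert hasattr(run, name), name
```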
27 changes: 18 additions & 9 deletions src/nemo_run/api.py
@@ -14,9 +14,22 @@
# limitations under the License.

from functools import wraps
from typing import (Any, Callable, Concatenate, List, Literal, Optional,
ParamSpec, Protocol, Type, TypeVar, Union, cast, overload,
runtime_checkable)
from typing import (
Any,
Callable,
Concatenate,
List,
Literal,
Optional,
ParamSpec,
Protocol,
Type,
TypeVar,
Union,
cast,
overload,
runtime_checkable,
)

import fiddle as fdl
from fiddle.experimental import auto_config as _auto_config
@@ -105,9 +118,7 @@ def autoconvert(


def autoconvert(
fn: Optional[
Callable[P, T] | Callable[P, Config[T]] | Callable[P, Partial[T]]
] = None,
fn: Optional[Callable[P, T] | Callable[P, Config[T]] | Callable[P, Partial[T]]] = None,
*,
partial: bool = False,
to_buildable_fn: Callable[
@@ -199,9 +210,7 @@ def dryrun_fn(

fn = configured_fn.__fn_or_cls__
console = CONSOLE
console.print(
f"[bold cyan]Dry run for task {fn.__module__}:{fn.__name__}[/bold cyan]"
)
console.print(f"[bold cyan]Dry run for task {fn.__module__}:{fn.__name__}[/bold cyan]")

table_resolved_args = Table(show_header=True, header_style="bold magenta")
table_resolved_args.add_column("Argument Name", style="dim", width=20)
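The autoconvert overloads reflowed in this hunk accept either a bare function or keyword options. A hedged usage sketch follows, reusing the Model dataclass from the example files above; the partial=True behaviour is inferred from the signature, not shown in this diff.

```python
# Hedged sketch of the two call forms implied by the autoconvert signature above.
from dataclasses import dataclass

import nemo_run as run


@dataclass
class Model:
    hidden_size: int
    num_layers: int
    activation: str


@run.autoconvert
def base_model() -> Model:
    # Written as a plain constructor; autoconvert is expected to lift the
    # return value into a buildable config (run.Config[Model]).
    return Model(hidden_size=256, num_layers=3, activation="relu")


@run.autoconvert(partial=True)
def partial_model(hidden_size: int = 256) -> Model:
    # With partial=True the wrapper is assumed to produce run.Partial[Model] instead.
    return Model(hidden_size=hidden_size, num_layers=3, activation="relu")
```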
13 changes: 10 additions & 3 deletions src/nemo_run/cli/__init__.py
@@ -13,9 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_run.cli.api import (RunContext, create_cli, entrypoint, factory,
list_entrypoints, list_factories, main,
resolve_factory)
from nemo_run.cli.api import (
RunContext,
create_cli,
entrypoint,
factory,
list_entrypoints,
list_factories,
main,
resolve_factory,
)
from nemo_run.cli.cli_parser import parse_cli_args, parse_config, parse_partial

__all__ = [
79 changes: 53 additions & 26 deletions src/nemo_run/cli/api.py
@@ -20,9 +20,22 @@
import sys
from dataclasses import dataclass, field
from functools import cache, wraps
from typing import (Any, Callable, Generic, List, Literal, Optional, Protocol,
Tuple, Type, TypeVar, Union, get_args, overload,
runtime_checkable)
from typing import (
Any,
Callable,
Generic,
List,
Literal,
Optional,
Protocol,
Tuple,
Type,
TypeVar,
Union,
get_args,
overload,
runtime_checkable,
)

import catalogue
import fiddle as fdl
@@ -38,10 +51,15 @@
from nemo_run.cli import devspace as devspace_cli
from nemo_run.cli import experiment as experiment_cli
from nemo_run.cli.cli_parser import parse_cli_args, parse_factory
from nemo_run.config import (NEMORUN_HOME, Config, Partial, Script,
get_type_namespace, get_underlying_types)
from nemo_run.core.execution import (LocalExecutor, SkypilotExecutor,
SlurmExecutor)
from nemo_run.config import (
NEMORUN_HOME,
Config,
Partial,
Script,
get_type_namespace,
get_underlying_types,
)
from nemo_run.core.execution import LocalExecutor, SkypilotExecutor, SlurmExecutor
from nemo_run.core.execution.base import Executor
from nemo_run.run.experiment import Experiment
from nemo_run.run.plugin import ExperimentPlugin as Plugin
@@ -67,7 +85,7 @@ def entrypoint(
require_confirmation: bool = True,
enable_executor: bool = True,
entrypoint_cls: Optional[Type["Entrypoint"]] = None,
type: Literal["task", "experiment"] = "task"
type: Literal["task", "experiment"] = "task",
) -> F | Callable[[F], F]:
"""
Decorator to register a function as a CLI entrypoint in the NeMo Run framework.
@@ -153,7 +171,7 @@ def wrapper(f: F) -> F:
help_str=help,
require_confirmation=require_confirmation,
enable_executor=enable_executor,
type=type
type=type,
)

if _namespace:
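The keyword arguments visible in this hunk (require_confirmation, enable_executor, type) are forwarded into Entrypoint further down the file. A hedged sketch of a non-default registration; the ctx parameter for experiment-style entrypoints is an assumption mirroring the validation in Entrypoint.__init__ below.

```python
# Hedged sketch only; parameter names are taken from the decorator signature above.
import nemo_run as run


@run.cli.entrypoint(require_confirmation=False, type="experiment")
def sweep(ctx: run.cli.RunContext, epochs: int = 10):
    # Experiment entrypoints are assumed to receive a RunContext as `ctx`,
    # matching the check in Entrypoint.__init__.
    ...
```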
@@ -300,8 +318,11 @@ def wrapper(fn: Callable[Params, T]) -> Callable[Params, T]:
if not target and not hasattr(fn, "__auto_config__"):
return_type = _get_return_type(fn)
if not (
isinstance(return_type, (Config, Partial)) or
(hasattr(return_type, "__origin__") and issubclass(return_type.__origin__, (Config, Partial)))
isinstance(return_type, (Config, Partial))
or (
hasattr(return_type, "__origin__")
and issubclass(return_type.__origin__, (Config, Partial))
)
):
raise ValueError(
f"Factory function {fn} has a return type which is not a subclass of Config or Partial. "
@@ -322,7 +343,7 @@ def as_factory(*args: Params.args, **kwargs: Params.kwargs) -> T:
target_arg=target_arg,
name=name,
namespace=namespace,
is_target_default=is_target_default
is_target_default=is_target_default,
)

return as_factory
@@ -331,7 +352,8 @@ def as_factory(*args: Params.args, **kwargs: Params.kwargs) -> T:


def resolve_factory(
target: Type[T] | str, name: str,
target: Type[T] | str,
name: str,
) -> Callable[..., Config[T] | Partial[T]]:
"""
Helper function to resolve the factory for the given type or namespace.
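As far as this hunk shows, resolve_factory takes a target type (or namespace string) and a factory name and returns a callable producing a config. A hedged sketch, with the Optimizer type and factory name borrowed from the example files; the import from the example module is hypothetical.

```python
# Hedged sketch based only on the signature above.
from nemo_run.cli import resolve_factory

from task import Optimizer  # hypothetical import of the example dataclass

make_optimizer = resolve_factory(Optimizer, "my_optimizer")
optimizer_cfg = make_optimizer(learning_rate=0.01)  # expected: a Config/Partial of Optimizer
```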
@@ -532,9 +554,9 @@ def _register_factory(
_namespace = get_type_namespace(target)
else:
_return_type = _get_return_type(fn)
if (
isinstance(_return_type, (Config, Partial)) or
(hasattr(_return_type, "__origin__") and issubclass(_return_type.__origin__, (Config, Partial)))
if isinstance(_return_type, (Config, Partial)) or (
hasattr(_return_type, "__origin__")
and issubclass(_return_type.__origin__, (Config, Partial))
):
_return_type = get_args(_return_type)[0]

@@ -642,10 +664,7 @@ def add(
)

def run(
self,
fn: Callable,
args: List[str],
entrypoint_type: Literal["task", "experiment"] = "task"
self, fn: Callable, args: List[str], entrypoint_type: Literal["task", "experiment"] = "task"
):
_, run_args, filtered_args = _parse_prefixed_args(args, "run")
self.parse_args(run_args)
@@ -723,7 +742,7 @@ def _execute_experiment(self, fn: Callable, experiment_args: List[str]):
sequential=self.sequential,
detach=self.detach,
direct=self.direct or self.executor is None,
tail_logs=self.tail_logs
tail_logs=self.tail_logs,
)

def _should_continue(self, require_confirmation: bool) -> bool:
@@ -806,14 +825,18 @@ def __init__(
help_str=None,
enable_executor: bool = True,
require_confirmation: bool = True,
type: Literal["task", "experiment"] = "task"
type: Literal["task", "experiment"] = "task",
):
if type == "task":
if "executor" in inspect.signature(fn).parameters:
raise ValueError("The function cannot have an argument named `executor` as it is a reserved keyword.")
raise ValueError(
"The function cannot have an argument named `executor` as it is a reserved keyword."
)
elif type in ("sequential_experiment", "parallel_experiment"):
if "ctx" not in inspect.signature(fn).parameters:
raise ValueError("The function must have an argument named `ctx` as it is a required argument for experiments.")
raise ValueError(
"The function must have an argument named `ctx` as it is a required argument for experiments."
)

self.fn = fn
self.arg_types = {}
@@ -1007,7 +1030,9 @@ def format_help(self, ctx, formatter):
return out


def _parse_prefixed_args(args: List[str], prefix: str) -> Tuple[Optional[str], List[str], List[str]]:
def _parse_prefixed_args(
args: List[str], prefix: str
) -> Tuple[Optional[str], List[str], List[str]]:
"""
Parse arguments to separate prefixed args from others.
@@ -1033,7 +1058,9 @@ def _parse_prefixed_args(args: List[str], prefix: str) -> Tuple[Optional[str], List[str], List[str]]:
prefixed_arg_value = arg.split("=")[1]
else:
if not arg.startswith(f"{prefix}.") and not arg.startswith(f"{prefix}["):
raise ValueError(f"{prefix.capitalize()} overwrites must start with '{prefix}.'. Got {arg}")
raise ValueError(
f"{prefix.capitalize()} overwrites must start with '{prefix}.'. Got {arg}"
)
if arg.startswith(f"{prefix}."):
prefixed_args.append(arg.replace(f"{prefix}.", ""))
elif arg.startswith(f"{prefix}["):
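Reading the branches in this final hunk, _parse_prefixed_args appears to collect prefix-qualified arguments with the prefix stripped and pass everything else through. The concrete values below are reconstructed from the visible code, not from test output.

```python
# Hedged reconstruction of the expected split; _parse_prefixed_args is a private
# helper, imported here for illustration only.
from nemo_run.cli.api import _parse_prefixed_args

value, run_args, other_args = _parse_prefixed_args(
    ["run.direct=True", "run.tail_logs=True", "model.hidden_size=128"], "run"
)
# Assumed result: value is None, run_args == ["direct=True", "tail_logs=True"],
# other_args == ["model.hidden_size=128"]
```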