Skip to content

Commit

Permalink
Using RunContext to simplify things
Browse files Browse the repository at this point in the history
  • Loading branch information
marcromeyn committed Aug 22, 2024
1 parent 15b9f06 commit c2c565d
Show file tree
Hide file tree
Showing 9 changed files with 481 additions and 682 deletions.
12 changes: 5 additions & 7 deletions examples/entrypoint/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ def my_optimizer(
return Optimizer(learning_rate=learning_rate, weight_decay=weight_decay, betas=betas)


@run.cli.entrypoint(type="sequential_experiment")
@run.cli.entrypoint(type="experiment")
def train_and_evaluate(
experiment: run.Experiment,
executor: run.Executor,
ctx: run.cli.RunContext,
model: Model = my_model(),
optimizer: Optimizer = my_optimizer(),
train_epochs: int = 10,
Expand All @@ -98,10 +97,9 @@ def train_and_evaluate(
train = run.Partial(train_model, model=model, optimizer=optimizer, epochs=train_epochs)
evaluate = run.Partial(train_model, model=model, optimizer=optimizer, epochs=eval_epochs)

experiment.add(train, executor=executor, name="train")
experiment.add(evaluate, executor=executor, name="evaluate")

return experiment
ctx.sequential = True
ctx.add(train, executor=ctx.executor, name="train")
ctx.add(evaluate, executor=ctx.executor, name="evaluate")


if __name__ == "__main__":
Expand Down
56 changes: 29 additions & 27 deletions examples/entrypoint/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,6 @@ class Optimizer:
betas: List[float]


@run.cli.entrypoint
def train_model(
model: Model,
optimizer: Optimizer,
epochs: int = 10,
batch_size: int = 32
):
"""
Train a model using the specified configuration.
Args:
model (Model): Configuration for the model.
optimizer (Optimizer): Configuration for the optimizer.
epochs (int, optional): Number of training epochs. Defaults to 10.
batch_size (int, optional): Batch size for training. Defaults to 32.
"""
print(f"Training model with the following configuration:")
print(f"Model: {model}")
print(f"Optimizer: {optimizer}")
print(f"Epochs: {epochs}")
print(f"Batch size: {batch_size}")

# Simulating model training
for epoch in range(epochs):
print(f"Epoch {epoch + 1}/{epochs}")

print("Training completed!")


@run.cli.factory
Expand Down Expand Up @@ -75,5 +48,34 @@ def my_optimizer(
return Optimizer(learning_rate=learning_rate, weight_decay=weight_decay, betas=betas)


@run.cli.entrypoint
def train_model(
model: Model = my_model(),
optimizer: Optimizer = my_optimizer(),
epochs: int = 10,
batch_size: int = 32
):
"""
Train a model using the specified configuration.
Args:
model (Model): Configuration for the model.
optimizer (Optimizer): Configuration for the optimizer.
epochs (int, optional): Number of training epochs. Defaults to 10.
batch_size (int, optional): Batch size for training. Defaults to 32.
"""
print(f"Training model with the following configuration:")
print(f"Model: {model}")
print(f"Optimizer: {optimizer}")
print(f"Epochs: {epochs}")
print(f"Batch size: {batch_size}")

# Simulating model training
for epoch in range(epochs):
print(f"Epoch {epoch + 1}/{epochs}")

print("Training completed!")


if __name__ == "__main__":
run.cli.main(train_model)
26 changes: 14 additions & 12 deletions src/nemo_run/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# limitations under the License.

from functools import wraps
from typing import (Any, Callable, Concatenate, Literal, Optional, ParamSpec,
Protocol, Type, TypeVar, Union, cast, overload,
from typing import (Any, Callable, Concatenate, List, Literal, Optional,
ParamSpec, Protocol, Type, TypeVar, Union, cast, overload,
runtime_checkable)

import fiddle as fdl
Expand Down Expand Up @@ -43,19 +43,21 @@ def default_autoconfig_buildable(
cls: Type[Union[Partial, Config]],
*args: P.args,
**kwargs: P.kwargs,
) -> Config[T] | Partial[T]:
) -> Config[T] | Partial[T] | List[Config[T]] | List[Partial[T]]:
def exemption_policy(cfg):
return cfg in [Partial, Config] or getattr(cfg, "__auto_config__", False)

return fdl.cast(
cls,
_auto_config.auto_config(
fn,
experimental_allow_control_flow=False,
experimental_allow_dataclass_attribute_access=True,
experimental_exemption_policy=exemption_policy,
).as_buildable(*args, **kwargs),
)
_output = _auto_config.auto_config(
fn,
experimental_allow_control_flow=False,
experimental_allow_dataclass_attribute_access=True,
experimental_exemption_policy=exemption_policy,
).as_buildable(*args, **kwargs)

if isinstance(_output, list):
return [fdl.cast(cls, item) for item in _output]

return fdl.cast(cls, _output)


@overload
Expand Down
3 changes: 2 additions & 1 deletion src/nemo_run/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_run.cli.api import (create_cli, entrypoint, factory,
from nemo_run.cli.api import (RunContext, create_cli, entrypoint, factory,
list_entrypoints, list_factories, main,
resolve_factory)

Expand All @@ -25,4 +25,5 @@
"resolve_factory",
"list_entrypoints",
"list_factories",
"RunContext",
]
Loading

0 comments on commit c2c565d

Please sign in to comment.