Skip to content

Commit

Permalink
Feature/hydra config modules (#9783)
Browse files Browse the repository at this point in the history
* hydra: enable config modules

Lets the user specify Hydra configs in installed Python modules as well
as inside directories. Adds `hydra.config_module` to `.dvc/config`.

Fixes #9740

Existing tests still pass.

Add tests for config modules.

* hydra: fix linter errors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* hydra: fix mypy error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* hydra: fix tests

* Update dvc/utils/hydra.py

Co-authored-by: David de la Iglesia Castro <[email protected]>

* address comments

* convince mypy not to worry

* split long line

* fix test

* check for raised error in test

---------

Co-authored-by: Dom Miketa <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David de la Iglesia Castro <[email protected]>
  • Loading branch information
4 people authored Aug 6, 2023
1 parent 43b371c commit 009b238
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 12 deletions.
4 changes: 3 additions & 1 deletion dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
All,
Any,
Coerce,
Exclusive,
Invalid,
Lower,
Optional,
Expand Down Expand Up @@ -324,7 +325,8 @@ def __call__(self, data):
},
"hydra": {
Optional("enabled", default=False): Bool,
"config_dir": str,
Exclusive("config_dir", "config_source"): str,
Exclusive("config_module", "config_source"): str,
"config_name": str,
},
"studio": {
Expand Down
10 changes: 7 additions & 3 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,17 @@ def _update_params(self, params: Dict[str, List[str]]):
hydra_output_file = ParamsDependency.DEFAULT_PARAMS_FILE
for path, overrides in params.items():
if hydra_enabled and path == hydra_output_file:
config_dir = os.path.join(
self.repo.root_dir, hydra_config.get("config_dir", "conf")
)
if (config_module := hydra_config.get("config_module")) is None:
config_dir = os.path.join(
self.repo.root_dir, hydra_config.get("config_dir", "conf")
)
else:
config_dir = None
config_name = hydra_config.get("config_name", "config")
compose_and_dump(
path,
config_dir,
config_module,
config_name,
overrides,
)
Expand Down
20 changes: 16 additions & 4 deletions dvc/utils/hydra.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional

from dvc.exceptions import InvalidArgumentError

Expand All @@ -15,7 +15,8 @@

def compose_and_dump(
output_file: "StrPath",
config_dir: str,
config_dir: Optional[str],
config_module: Optional[str],
config_name: str,
overrides: List[str],
) -> None:
Expand All @@ -25,19 +26,30 @@ def compose_and_dump(
output_file: File where the composed config will be dumped.
config_dir: Folder containing the Hydra config files.
Must be absolute file system path.
config_module: Module containing the Hydra config files.
Ignored if `config_dir` is not `None`.
config_name: Name of the config file containing defaults,
without the .yaml extension.
overrides: List of `Hydra Override`_ patterns.
.. _Hydra Override:
https://hydra.cc/docs/advanced/override_grammar/basic/
"""
from hydra import compose, initialize_config_dir
from hydra import compose, initialize_config_dir, initialize_config_module
from omegaconf import OmegaConf

from .serialize import DUMPERS

with initialize_config_dir(config_dir, version_base=None):
config_source = config_dir or config_module
if not config_source:
raise ValueError("Either `config_dir` or `config_module` should be provided.")
initialize_config = (
initialize_config_dir if config_dir else initialize_config_module
)

with initialize_config( # type: ignore[attr-defined]
config_source, version_base=None
):
cfg = compose(config_name=config_name, overrides=overrides)

OmegaConf.resolve(cfg)
Expand Down
64 changes: 60 additions & 4 deletions tests/func/utils/test_hydra.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from contextlib import nullcontext as does_not_raise

import pytest

from dvc.exceptions import InvalidArgumentError
Expand Down Expand Up @@ -167,16 +169,70 @@ def hydra_setup(tmp_dir, config_dir, config_name):
),
],
)
def test_compose_and_dump(tmp_dir, suffix, overrides, expected):
def test_compose_and_dump_overrides(tmp_dir, suffix, overrides, expected):
from dvc.utils.hydra import compose_and_dump

config_name = "config"
output_file = tmp_dir / f"params.{suffix}"
config_dir = hydra_setup(tmp_dir, "conf", "config")
compose_and_dump(output_file, config_dir, config_name, overrides)
config_module = None
compose_and_dump(output_file, config_dir, config_module, config_name, overrides)
assert output_file.parse() == expected


def hydra_setup_dir_basic(tmp_dir, config_subdir, config_name, config_content):
    """Create a config directory containing a single ``<config_name>.yaml``.

    Args:
        tmp_dir: Base directory under which the config directory is created.
        config_subdir: Name of the directory to create below ``tmp_dir``.
            When ``None``, nothing is created.
        config_name: Config file name, without the ``.yaml`` extension.
        config_content: Data dumped into the config file.

    Returns:
        The created directory as a string, or ``None`` when
        ``config_subdir`` is ``None``.
    """
    if config_subdir is not None:
        target_dir = tmp_dir / config_subdir
        target_dir.mkdir()
        config_file = target_dir / f"{config_name}.yaml"
        config_file.dump(config_content)
        return str(target_dir)
    return None


@pytest.mark.parametrize(
    ("config_subdir", "config_module", "config_content", "error_context"),
    [
        # Only a config directory is given.
        (
            "conf",
            None,
            {"normal_yaml_config": False},
            does_not_raise(),
        ),
        # Only a config module is given; the expected content is presumably
        # what that module's ``config.yaml`` holds -- TODO confirm.
        (
            None,
            "hydra.test_utils.configs",
            {"normal_yaml_config": True},
            does_not_raise(),
        ),
        # Both are given: the expected output is the directory's content.
        (
            "conf",
            "hydra.test_utils.configs",
            {"normal_yaml_config": False},
            does_not_raise(),
        ),
        # Neither is given: a ValueError is expected.
        (
            None,
            None,
            None,
            pytest.raises(
                ValueError,
                match="Either `config_dir` or `config_module` should be provided.",
            ),
        ),
    ],
)
def test_compose_and_dump_dir_module(
    tmp_dir, config_subdir, config_module, config_content, error_context
):
    """Check ``compose_and_dump`` for each config_dir/config_module combination."""
    from dvc.utils.hydra import compose_and_dump

    config_name = "config"
    params_path = tmp_dir / "params.yaml"
    prepared_dir = hydra_setup_dir_basic(
        tmp_dir, config_subdir, config_name, config_content
    )

    with error_context:
        compose_and_dump(params_path, prepared_dir, config_module, config_name, [])
        # Not reached in the failing case: the call above raises first.
        assert params_path.parse() == config_content


def test_compose_and_dump_yaml_handles_string(tmp_dir):
"""Regression test for https://github.com/iterative/dvc/issues/8583"""
from dvc.utils.hydra import compose_and_dump
Expand All @@ -185,7 +241,7 @@ def test_compose_and_dump_yaml_handles_string(tmp_dir):
config.parent.mkdir()
config.write_text("foo: 'no'\n")
output_file = tmp_dir / "params.yaml"
compose_and_dump(output_file, str(config.parent), "config", [])
compose_and_dump(output_file, str(config.parent), None, "config", [])
assert output_file.read_text() == "foo: 'no'\n"


Expand All @@ -197,7 +253,7 @@ def test_compose_and_dump_resolves_interpolation(tmp_dir):
config.parent.mkdir()
config.dump({"data": {"root": "path/to/root", "raw": "${.root}/raw"}})
output_file = tmp_dir / "params.yaml"
compose_and_dump(output_file, str(config.parent), "config", [])
compose_and_dump(output_file, str(config.parent), None, "config", [])
assert output_file.parse() == {
"data": {"root": "path/to/root", "raw": "path/to/root/raw"}
}
Expand Down

0 comments on commit 009b238

Please sign in to comment.