From 009b238807fef62d40ca2019d383c58bab429f24 Mon Sep 17 00:00:00 2001 From: Dom Miketa Date: Sun, 6 Aug 2023 09:39:09 +0100 Subject: [PATCH] Feature/hydra config modules (#9783) * hydra: enable config modules Lets the user specify Hydra configs in installed Python modules as well as inside directories. Adds `hydra.config_module` to `.dvc/config`. Fixes #9740 old tests green tests for config modules * hydra: fix linter errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * hydra: fix mypy error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * hydra: fix tests * Update dvc/utils/hydra.py Co-authored-by: David de la Iglesia Castro * address comments * convince mypy not to worry * split long line * fix test * check for raised error in test --------- Co-authored-by: Dom Miketa Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: David de la Iglesia Castro --- dvc/config_schema.py | 4 +- dvc/repo/experiments/queue/base.py | 10 +++-- dvc/utils/hydra.py | 20 ++++++++-- tests/func/utils/test_hydra.py | 64 ++++++++++++++++++++++++++++-- 4 files changed, 86 insertions(+), 12 deletions(-) diff --git a/dvc/config_schema.py b/dvc/config_schema.py index cb3e9cdbe7..38610a5029 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -8,6 +8,7 @@ All, Any, Coerce, + Exclusive, Invalid, Lower, Optional, @@ -324,7 +325,8 @@ def __call__(self, data): }, "hydra": { Optional("enabled", default=False): Bool, - "config_dir": str, + Exclusive("config_dir", "config_source"): str, + Exclusive("config_module", "config_source"): str, "config_name": str, }, "studio": { diff --git a/dvc/repo/experiments/queue/base.py b/dvc/repo/experiments/queue/base.py index 452cba2a71..474980ffa8 100644 --- a/dvc/repo/experiments/queue/base.py +++ b/dvc/repo/experiments/queue/base.py @@ -476,13 +476,17 @@ def _update_params(self, params: Dict[str, List[str]]): hydra_output_file = ParamsDependency.DEFAULT_PARAMS_FILE for path, overrides in params.items(): if hydra_enabled and path == hydra_output_file: - config_dir = os.path.join( - self.repo.root_dir, hydra_config.get("config_dir", "conf") - ) + if (config_module := hydra_config.get("config_module")) is None: + config_dir = os.path.join( + self.repo.root_dir, hydra_config.get("config_dir", "conf") + ) + else: + config_dir = None config_name = hydra_config.get("config_name", "config") compose_and_dump( path, config_dir, + config_module, config_name, overrides, ) diff --git a/dvc/utils/hydra.py b/dvc/utils/hydra.py index aceaa16623..91eecff0b0 100644 --- a/dvc/utils/hydra.py +++ b/dvc/utils/hydra.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from dvc.exceptions import InvalidArgumentError @@ -15,7 +15,8 @@ def compose_and_dump( output_file: "StrPath", - config_dir: str, + config_dir: Optional[str], + config_module: Optional[str], config_name: str, overrides: List[str], ) -> None: @@ -25,6 +26,8 @@ def compose_and_dump( output_file: File where the composed config will be dumped. config_dir: Folder containing the Hydra config files. Must be absolute file system path. + config_module: Module containing the Hydra config files. + Ignored if `config_dir` is not `None`. config_name: Name of the config file containing defaults, without the .yaml extension. overrides: List of `Hydra Override`_ patterns. @@ -32,12 +35,21 @@ def compose_and_dump( .. _Hydra Override: https://hydra.cc/docs/advanced/override_grammar/basic/ """ - from hydra import compose, initialize_config_dir + from hydra import compose, initialize_config_dir, initialize_config_module from omegaconf import OmegaConf from .serialize import DUMPERS - with initialize_config_dir(config_dir, version_base=None): + config_source = config_dir or config_module + if not config_source: + raise ValueError("Either `config_dir` or `config_module` should be provided.") + initialize_config = ( + initialize_config_dir if config_dir else initialize_config_module + ) + + with initialize_config( # type: ignore[attr-defined] + config_source, version_base=None + ): cfg = compose(config_name=config_name, overrides=overrides) OmegaConf.resolve(cfg) diff --git a/tests/func/utils/test_hydra.py b/tests/func/utils/test_hydra.py index cc677a1ac4..6a448e2a93 100644 --- a/tests/func/utils/test_hydra.py +++ b/tests/func/utils/test_hydra.py @@ -1,3 +1,5 @@ +from contextlib import nullcontext as does_not_raise + import pytest from dvc.exceptions import InvalidArgumentError @@ -167,16 +169,70 @@ def hydra_setup(tmp_dir, config_dir, config_name): ), ], ) -def test_compose_and_dump(tmp_dir, suffix, overrides, expected): +def test_compose_and_dump_overrides(tmp_dir, suffix, overrides, expected): from dvc.utils.hydra import compose_and_dump config_name = "config" output_file = tmp_dir / f"params.{suffix}" config_dir = hydra_setup(tmp_dir, "conf", "config") - compose_and_dump(output_file, config_dir, config_name, overrides) + config_module = None + compose_and_dump(output_file, config_dir, config_module, config_name, overrides) assert output_file.parse() == expected +def hydra_setup_dir_basic(tmp_dir, config_subdir, config_name, config_content): + if config_subdir is None: + return None + + config_dir = tmp_dir / config_subdir + config_dir.mkdir() + (config_dir / f"{config_name}.yaml").dump(config_content) + return str(config_dir) + + +@pytest.mark.parametrize( + "config_subdir,config_module,config_content,error_context", + [ + ("conf", None, {"normal_yaml_config": False}, does_not_raise()), + ( + None, + "hydra.test_utils.configs", + {"normal_yaml_config": True}, + does_not_raise(), + ), + ( + "conf", + "hydra.test_utils.configs", + {"normal_yaml_config": False}, + does_not_raise(), + ), + ( + None, + None, + None, + pytest.raises( + ValueError, + match="Either `config_dir` or `config_module` should be provided.", + ), + ), + ], +) +def test_compose_and_dump_dir_module( + tmp_dir, config_subdir, config_module, config_content, error_context +): + from dvc.utils.hydra import compose_and_dump + + output_file = tmp_dir / "params.yaml" + config_name = "config" + config_dir = hydra_setup_dir_basic( + tmp_dir, config_subdir, config_name, config_content + ) + + with error_context: + compose_and_dump(output_file, config_dir, config_module, config_name, []) + assert output_file.parse() == config_content + + def test_compose_and_dump_yaml_handles_string(tmp_dir): """Regression test for https://github.com/iterative/dvc/issues/8583""" from dvc.utils.hydra import compose_and_dump @@ -185,7 +241,7 @@ def test_compose_and_dump_yaml_handles_string(tmp_dir): config.parent.mkdir() config.write_text("foo: 'no'\n") output_file = tmp_dir / "params.yaml" - compose_and_dump(output_file, str(config.parent), "config", []) + compose_and_dump(output_file, str(config.parent), None, "config", []) assert output_file.read_text() == "foo: 'no'\n" @@ -197,7 +253,7 @@ def test_compose_and_dump_resolves_interpolation(tmp_dir): config.parent.mkdir() config.dump({"data": {"root": "path/to/root", "raw": "${.root}/raw"}}) output_file = tmp_dir / "params.yaml" - compose_and_dump(output_file, str(config.parent), "config", []) + compose_and_dump(output_file, str(config.parent), None, "config", []) assert output_file.parse() == { "data": {"root": "path/to/root", "raw": "path/to/root/raw"} }