diff --git a/src/otx/engine/hpo/hpo_api.py b/src/otx/engine/hpo/hpo_api.py
index 227010d82b5..03280517b70 100644
--- a/src/otx/engine/hpo/hpo_api.py
+++ b/src/otx/engine/hpo/hpo_api.py
@@ -106,10 +106,10 @@ def execute_hpo(
     )
 
     best_trial = hpo_algo.get_best_config()
-    if best_trial is None:
-        best_config = None
-        best_hpo_weight = None
-    else:
+
+    best_config = None
+    best_hpo_weight = None
+    if best_trial is not None:
         best_config = best_trial["configuration"]
         if (trial_file := find_trial_file(hpo_workdir, best_trial["id"])) is not None:
             best_hpo_weight = get_best_hpo_weight(get_hpo_weight_dir(hpo_workdir, best_trial["id"]), trial_file)
diff --git a/src/otx/engine/hpo/utils.py b/src/otx/engine/hpo/utils.py
index 28586406ca4..a869af8d301 100644
--- a/src/otx/engine/hpo/utils.py
+++ b/src/otx/engine/hpo/utils.py
@@ -114,6 +114,9 @@ def get_metric(callbacks: list[Callback] | Callback) -> str:
     for callback in callbacks:
         if isinstance(callback, ModelCheckpoint):
-            return callback.monitor
+            if (metric := callback.monitor) is None:
+                msg = "Failed to find a metric. 'monitor' value of ModelCheckpoint callback is set to None."
+                raise ValueError(msg)
+            return metric
 
     msg = "Failed to find a metric. There is no ModelCheckpoint in callback list."
     raise RuntimeError(msg)
diff --git a/tests/unit/engine/hpo/__init__.py b/tests/unit/engine/hpo/__init__.py
new file mode 100644
index 00000000000..916f3a44b27
--- /dev/null
+++ b/tests/unit/engine/hpo/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
diff --git a/tests/unit/engine/hpo/test_hpo_api.py b/tests/unit/engine/hpo/test_hpo_api.py
new file mode 100644
index 00000000000..9895f4082e1
--- /dev/null
+++ b/tests/unit/engine/hpo/test_hpo_api.py
@@ -0,0 +1,328 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Tests for HPO API functions."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from unittest.mock import MagicMock
+
+import pytest
+from otx.core.config.hpo import HpoConfig
+from otx.core.optimizer.callable import OptimizerCallableSupportHPO
+from otx.core.schedulers import LinearWarmupSchedulerCallable, SchedulerCallableSupportHPO
+from otx.engine.hpo import hpo_api as target_file
+from otx.engine.hpo.hpo_api import (
+    HPOConfigurator,
+    _adjust_train_args,
+    _remove_unused_model_weights,
+    _update_hpo_progress,
+    execute_hpo,
+)
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+HPO_NAME_MAP: dict[str, str] = {
+    "lr": "model.optimizer_callable.optimizer_kwargs.lr",
+    "bs": "datamodule.config.train_subset.batch_size",
+}
+
+
+@pytest.fixture()
+def engine_work_dir(tmp_path: Path) -> Path:
+    return tmp_path
+
+
+@pytest.fixture()
+def dataset_size() -> int:
+    return 10
+
+
+@pytest.fixture()
+def default_bs() -> int:
+    return 8
+
+
+@pytest.fixture()
+def default_lr() -> float:
+    return 0.001
+
+
+@pytest.fixture()
+def mock_engine(engine_work_dir: Path, dataset_size: int, default_bs: int, default_lr: float) -> MagicMock:
+    engine = MagicMock()
+    engine.work_dir = engine_work_dir
+    engine.datamodule.subsets = {engine.datamodule.config.train_subset.subset_name: range(dataset_size)}
+    engine.datamodule.config.train_subset.batch_size = default_bs
+    engine.model.optimizer_callable = MagicMock(spec=OptimizerCallableSupportHPO)
+    engine.model.optimizer_callable.lr = default_lr
+    engine.model.optimizer_callable.optimizer_kwargs = {"lr": default_lr}
+    return engine
+
+
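+# The fixtures below stub the HPO algorithm and its configurator so that
+# execute_hpo() can run end to end without a real optimization loop.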
+@pytest.fixture()
+def mock_hpo_algo() -> MagicMock:
+    hpo_algo = MagicMock()
+    hpo_algo.get_best_config.return_value = {"configuration": "best_config", "id": "best_id"}
+    return hpo_algo
+
+
+@pytest.fixture()
+def mock_hpo_configurator(mocker, mock_hpo_algo: MagicMock) -> HPOConfigurator:
+    hpo_configurator = MagicMock()
+    hpo_configurator.get_hpo_algo.return_value = mock_hpo_algo
+    mocker.patch.object(target_file, "HPOConfigurator", return_value=hpo_configurator)
+    return hpo_configurator
+
+
+@pytest.fixture()
+def mock_run_hpo_loop(mocker) -> MagicMock:
+    return mocker.patch.object(target_file, "run_hpo_loop")
+
+
+@pytest.fixture()
+def mock_thread(mocker) -> MagicMock:
+    return mocker.patch.object(target_file, "Thread")
+
+
+@pytest.fixture()
+def mock_get_best_hpo_weight(mocker) -> MagicMock:
+    return mocker.patch.object(target_file, "get_best_hpo_weight")
+
+
+@pytest.fixture()
+def mock_find_trial_file(mocker) -> MagicMock:
+    return mocker.patch.object(target_file, "find_trial_file")
+
+
+@pytest.fixture()
+def hpo_config() -> HpoConfig:
+    return HpoConfig(metric_name="val/accuracy")
+
+
+@pytest.fixture()
+def mock_progress_update_callback() -> MagicMock:
+    return MagicMock()
+
+
+def test_execute_hpo(
+    mock_engine: MagicMock,
+    hpo_config: HpoConfig,
+    engine_work_dir: Path,
+    mock_run_hpo_loop: MagicMock,
+    mock_thread: MagicMock,
+    mock_hpo_configurator: HPOConfigurator,  # noqa: ARG001
+    mock_hpo_algo: MagicMock,
+    mock_get_best_hpo_weight: MagicMock,
+    mock_find_trial_file: MagicMock,  # noqa: ARG001
+    mock_progress_update_callback: MagicMock,
+):
+    best_config, best_hpo_weight = execute_hpo(
+        engine=mock_engine,
+        max_epochs=10,
+        hpo_config=hpo_config,
+        progress_update_callback=mock_progress_update_callback,
+    )
+
+    # check the HPO work directory exists
+    assert (engine_work_dir / "hpo").exists()
+    # check the case where progress_update_callback is given
+    mock_thread.assert_called_once()
+    assert mock_thread.call_args.kwargs["target"] == _update_hpo_progress
+    assert mock_thread.call_args.kwargs["args"][0] == mock_progress_update_callback
+    assert mock_thread.call_args.kwargs["daemon"] is True
+    mock_thread.return_value.start.assert_called_once()
+    # check run_hpo_loop is called with the HPO algorithm
+    mock_run_hpo_loop.assert_called_once()
+    assert mock_run_hpo_loop.call_args.args[0] == mock_hpo_algo
+    # print_result is called after HPO is done
+    mock_hpo_algo.print_result.assert_called_once()
+    # best_config and best_hpo_weight are returned correctly
+    assert best_config == "best_config"
+    assert best_hpo_weight == mock_get_best_hpo_weight.return_value
+
+
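+# HPOConfigurator assembles the configuration dict handed to the HPO algorithm,
+# including the search space and the prior hyper parameters to try first.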
{HPO_NAME_MAP["lr"]: default_lr, HPO_NAME_MAP["bs"]: default_bs} + + def test_get_default_search_space(self, mock_engine: MagicMock, hpo_config: HpoConfig): + hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config) + search_sapce = hpo_configurator._get_default_search_space() + + for hp_name in HPO_NAME_MAP.values(): + assert hp_name in search_sapce + + def test_align_lr_bs_name(self, mock_engine: MagicMock, hpo_config: HpoConfig): + """Check learning rate and batch size names are aligned well.""" + search_space = { + "model.optimizer.lr": { + "type": "loguniform", + "min": 0.0001, + "max": 0.1, + }, + "data.config.train_subset.batch_size": { + "type": "quniform", + "min": 2, + "max": 512, + "step": 1, + }, + } + hpo_config.search_space = search_space + + hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config) + + for new_name in HPO_NAME_MAP.values(): + assert new_name in hpo_configurator.hpo_config["search_space"] + + def test_align_scheduler_callable_support_hpo_name(self, mock_engine: MagicMock, hpo_config: HpoConfig): + """Check scheduler name is aligned well if class of scheduler is SchedulerCallableSupportHPO.""" + mock_engine.model.scheduler_callable = MagicMock(spec=SchedulerCallableSupportHPO) + mock_engine.model.scheduler_callable.factor = 0.001 + mock_engine.model.scheduler_callable.scheduler_kwargs = {"factor": 0.001} + search_space = { + "model.scheduler.factor": { + "type": "loguniform", + "min": 0.0001, + "max": 0.1, + }, + } + hpo_config.search_space = search_space + + hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config) + + assert "model.scheduler_callable.scheduler_kwargs.factor" in hpo_configurator.hpo_config["search_space"] + + def test_align_linear_warmup_scheduler_callable_name(self, mock_engine: MagicMock, hpo_config: HpoConfig): + """Check scheduler name is aligned well if class of scheduler is LinearWarmupSchedulerCallable.""" + scheduler_callable = MagicMock(spec=LinearWarmupSchedulerCallable) + scheduler_callable.num_warmup_steps = 0.001 + main_scheduler_callable = MagicMock() + main_scheduler_callable.factor = 0.001 + main_scheduler_callable.scheduler_kwargs = {"factor": 0.001} + scheduler_callable.main_scheduler_callable = main_scheduler_callable + mock_engine.model.scheduler_callable = scheduler_callable + search_space = { + "model.scheduler.num_warmup_steps": { + "type": "loguniform", + "min": 0.0001, + "max": 0.1, + }, + "model.scheduler.main_scheduler_callable.factor": { + "type": "loguniform", + "min": 0.0001, + "max": 0.1, + }, + } + hpo_config.search_space = search_space + + hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config) + + assert "model.scheduler_callable.num_warmup_steps" in hpo_configurator.hpo_config["search_space"] + assert ( + "model.scheduler_callable.main_scheduler_callable.scheduler_kwargs.factor" + in hpo_configurator.hpo_config["search_space"] + ) + + def test_remove_wrong_search_space(self, mock_engine: MagicMock, hpo_config: HpoConfig): + hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config) + wrong_search_space = { + "wrong_choice": { + "type": "choice", + "choice_list": [], # choice shouldn't be empty + }, + "wrong_quniform": { + "type": "quniform", + "min": 2, + "max": 3, # max should be larger than min + step + "step": 2, + }, + } + hpo_configurator._remove_wrong_search_space(wrong_search_space) + assert wrong_search_space == {} + + def test_get_hpo_algo(self, mocker, mock_engine: MagicMock, hpo_config: HpoConfig): + hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config) + 
+    def test_get_hpo_algo(self, mocker, mock_engine: MagicMock, hpo_config: HpoConfig):
+        hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config)
+        mock_hyper_band = mocker.patch.object(target_file, "HyperBand")
+        hpo_configurator.get_hpo_algo()
+
+        mock_hyper_band.assert_called_once()
+        assert mock_hyper_band.call_args.kwargs == hpo_configurator.hpo_config
+
+
+def test_update_hpo_progress(mocker, mock_progress_update_callback: MagicMock):
+    mock_hpo_algo = MagicMock()
+    mock_hpo_algo.is_done.side_effect = [False, False, False, True]
+    progress_arr = [0.3, 0.6, 1]
+    mock_hpo_algo.get_progress.side_effect = progress_arr
+    mocker.patch.object(target_file, "time")  # skip real sleeping during the loop
+
+    _update_hpo_progress(mock_progress_update_callback, mock_hpo_algo)
+
+    mock_progress_update_callback.assert_called()
+    for i in range(3):
+        assert mock_progress_update_callback.call_args_list[i].args[0] == pytest.approx(progress_arr[i] * 100)
+
+
+def test_adjust_train_args():
+    new_train_args = _adjust_train_args(
+        {
+            "self": "self",
+            "run_hpo": "run_hpo",
+            "kwargs": {
+                "kwargs_1": "kwargs_1",
+                "kwargs_2": "kwargs_2",
+            },
+        },
+    )
+
+    assert "self" not in new_train_args
+    assert "run_hpo" not in new_train_args
+    assert "kwargs" not in new_train_args
+    assert "kwargs_1" in new_train_args
+    assert "kwargs_2" in new_train_args
+
+
+@pytest.fixture()
+def mock_hpo_workdir(tmp_path: Path) -> Path:
+    (tmp_path / "1.ckpt").touch()
+    sub_dir = tmp_path / "a"
+    sub_dir.mkdir()
+    (sub_dir / "2.ckpt").touch()
+    return tmp_path
+
+
+def test_remove_unused_model_weights(mock_hpo_workdir: Path):
+    best_weight = mock_hpo_workdir / "3.ckpt"
+    best_weight.touch()
+
+    _remove_unused_model_weights(mock_hpo_workdir, best_weight)
+
+    ckpt_files = list(mock_hpo_workdir.rglob("*.ckpt"))
+    assert len(ckpt_files) == 1
+    assert ckpt_files[0] == best_weight
diff --git a/tests/unit/engine/hpo/test_hpo_trial.py b/tests/unit/engine/hpo/test_hpo_trial.py
new file mode 100644
index 00000000000..5d310d78d8e
--- /dev/null
+++ b/tests/unit/engine/hpo/test_hpo_trial.py
@@ -0,0 +1,230 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Tests for HPO trial functions."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+from unittest.mock import MagicMock
+
+import pytest
+from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+from otx.algo.callbacks.adaptive_train_scheduling import AdaptiveTrainScheduling
+from otx.engine.hpo import hpo_trial as target_file
+from otx.engine.hpo.hpo_trial import (
+    HPOCallback,
+    _register_hpo_callback,
+    _set_to_validate_every_epoch,
+    run_hpo_trial,
+    update_hyper_parameter,
+)
+from otx.engine.hpo.utils import get_hpo_weight_dir
+from otx.hpo import TrialStatus
+from torch import tensor
+
+if TYPE_CHECKING:
+    from lightning import Callback
+
+
+@pytest.fixture()
+def mock_engine() -> MagicMock:
+    engine = MagicMock()
+
+    def train_side_effect(*args, **kwargs) -> None:  # noqa: ARG001
+        if isinstance(engine.work_dir, str):
+            work_dir = Path(engine.work_dir)
+            for i in range(3):
+                (work_dir / f"epoch_{i}.ckpt").touch()
+            (work_dir / "last.ckpt").write_text("last_ckpt")
+
+    engine.train.side_effect = train_side_effect
+
+    return engine
+
+
+def test_update_hyper_parameter(mock_engine):
+    hyper_parameter = {
+        "a.b.c": 1,
+        "d.e.f": 2,
+    }
+
+    update_hyper_parameter(mock_engine, hyper_parameter)
+
+    assert mock_engine.a.b.c == 1
+    assert mock_engine.d.e.f == 2
+
+
+@pytest.fixture()
+def mock_report_func() -> MagicMock:
+    return MagicMock()
+
+
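+# HPOCallback reports the monitored metric to the HPO algorithm after every
+# training epoch and stops training when the report function returns STOP.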
+class TestHPOCallback:
+    @pytest.fixture()
+    def metric(self) -> str:
+        return "metric"
+
+    def test_init(self, mock_report_func, metric):
+        HPOCallback(mock_report_func, metric)
+
+    @pytest.fixture()
+    def hpo_callback(self, mock_report_func, metric) -> HPOCallback:
+        return HPOCallback(mock_report_func, metric)
+
+    def test_on_train_epoch_end(self, hpo_callback: HPOCallback, mock_report_func, metric):
+        cur_eph = 10
+        score = 0.1
+
+        mock_trainer = MagicMock()
+        mock_trainer.current_epoch = cur_eph
+        mock_trainer.callback_metrics = {metric: tensor(score)}
+        mock_report_func.return_value = TrialStatus.STOP
+
+        hpo_callback.on_train_epoch_end(mock_trainer, MagicMock())
+
+        mock_report_func.assert_called_once_with(pytest.approx(score), cur_eph + 1)
+        assert mock_trainer.should_stop is True  # if report_func returns STOP, it should be set to True
+
+    def test_on_train_epoch_end_no_score(self, hpo_callback: HPOCallback, mock_report_func):
+        mock_trainer = MagicMock()
+        mock_trainer.callback_metrics = {}
+
+        hpo_callback.on_train_epoch_end(mock_trainer, MagicMock())
+
+        mock_report_func.assert_not_called()
+
+
+@pytest.fixture()
+def mock_checkpoint_callback() -> MagicMock:
+    return MagicMock(spec=ModelCheckpoint)
+
+
+@pytest.fixture()
+def mock_adaptive_schedule_hook() -> MagicMock:
+    return MagicMock(spec=AdaptiveTrainScheduling)
+
+
+@pytest.fixture()
+def mock_callbacks(mock_checkpoint_callback, mock_adaptive_schedule_hook) -> list[Callback]:
+    return [mock_checkpoint_callback, mock_adaptive_schedule_hook]
+
+
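+# End-to-end check of a single HPO trial: the given hyper parameters are applied,
+# training resumes from the last checkpoint, and weights are reorganized afterwards.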
+def test_run_hpo_trial(mocker, mock_callbacks, mock_report_func, tmp_path, mock_engine, mock_checkpoint_callback):
+    trial_id = "0"
+    max_epochs = 10
+    hp_config = {
+        "id": trial_id,
+        "configuration": {
+            "iterations": max_epochs,
+            "a.b.c": 1,
+            "d.e.f": 2,
+        },
+    }
+    hpo_weight_dir = get_hpo_weight_dir(tmp_path, trial_id)
+    last_weight = hpo_weight_dir / "last.ckpt"  # the last checkpoint so far, used to resume training
+    last_weight.write_text("prev_weight")
+    best_weight = hpo_weight_dir / "epoch_2.ckpt"
+    mocker.patch.object(target_file, "find_trial_file", return_value=Path("fake.json"))
+    mocker.patch.object(target_file, "get_best_hpo_weight", return_value=best_weight)
+
+    run_hpo_trial(
+        hp_config=hp_config,
+        report_func=mock_report_func,
+        hpo_workdir=tmp_path,
+        engine=mock_engine,
+        callbacks=mock_callbacks,
+        metric_name="metric",
+    )
+
+    train_work_dir = mock_engine.work_dir
+    mock_engine.train.assert_called_once()
+    # HPOCallback should be added to the callback list
+    for callback in mock_engine.train.call_args.kwargs["callbacks"]:
+        if isinstance(callback, HPOCallback):
+            break
+    else:
+        msg = "There is no HPOCallback in the callback list."
+        raise AssertionError(msg)
+    # check training is resumed if a model checkpoint exists
+    assert mock_engine.train.call_args.kwargs["checkpoint"] == last_weight
+    assert mock_engine.train.call_args.kwargs["resume"] is True
+    # check the given hyper parameters are set correctly
+    assert mock_engine.train.call_args.kwargs["max_epochs"] == 10
+    assert mock_engine.a.b.c == 1
+    assert mock_engine.d.e.f == 2
+    # check the train work directory is changed correctly
+    assert mock_checkpoint_callback.dirpath == train_work_dir
+    # check the final report is made with done=True
+    mock_report_func.assert_called_once()
+    assert mock_report_func.call_args.kwargs["done"] is True
+    # check all model weights in the train directory are moved to the HPO weight directory
+    assert len(list(Path(train_work_dir).rglob("*.ckpt"))) == 0
+    # check all model checkpoints are removed except the last and best weights
+    hpo_weights = list(hpo_weight_dir.rglob("*.ckpt"))
+    assert len(hpo_weights) == 2
+    assert best_weight in hpo_weights
+    assert last_weight in hpo_weights
+    assert last_weight.read_text() == "last_ckpt"
+
+
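+# _register_hpo_callback accepts no callback, a single callback, or a list of
+# callbacks, and always returns a list that includes an HPOCallback.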
+def test_register_hpo_callback(mock_report_func):
+    """Check it returns a list including only an HPOCallback if no callback is passed."""
+    callbacks = _register_hpo_callback(
+        report_func=mock_report_func,
+        metric_name="metric",
+    )
+    assert len(callbacks) == 1
+    assert isinstance(callbacks[0], HPOCallback)
+
+
+def test_register_hpo_callback_given_callback(mock_report_func, mock_checkpoint_callback):
+    """Check it returns a list including an HPOCallback if a single callback is passed."""
+    new_callbacks = _register_hpo_callback(
+        report_func=mock_report_func,
+        callbacks=mock_checkpoint_callback,
+        metric_name="metric",
+    )
+    assert len(new_callbacks) == 2
+    for callback in new_callbacks:
+        if isinstance(callback, HPOCallback):
+            break
+    else:
+        msg = "There is no HPOCallback in the callback list."
+        raise AssertionError(msg)
+    assert mock_checkpoint_callback in new_callbacks
+
+
+def test_register_hpo_callback_given_callbacks_arr(mock_report_func, mock_checkpoint_callback, mock_callbacks):
+    """Check it returns a list including an HPOCallback if a callback array is passed."""
+    new_callbacks = _register_hpo_callback(
+        report_func=mock_report_func,
+        callbacks=mock_callbacks,
+        metric_name="metric",
+    )
+    assert len(new_callbacks) == 3
+    for callback in new_callbacks:
+        if isinstance(callback, HPOCallback):
+            break
+    else:
+        msg = "There is no HPOCallback in the callback list."
+        raise AssertionError(msg)
+    assert mock_checkpoint_callback in new_callbacks
+
+
+def test_set_to_validate_every_epoch(mock_callbacks, mock_adaptive_schedule_hook):
+    """Check AdaptiveTrainScheduling.max_interval is changed if AdaptiveTrainScheduling is in the callback list."""
+    train_args = {}
+    _set_to_validate_every_epoch(mock_callbacks, train_args)
+
+    assert mock_adaptive_schedule_hook.max_interval == 1
+    assert train_args == {}
+
+
+def test_set_to_validate_every_epoch_no_adap_schedule():
+    """Check check_val_every_n_epoch is added to train_args if AdaptiveTrainScheduling isn't in the callback list."""
+    train_args = {}
+    _set_to_validate_every_epoch(callbacks=[], train_args=train_args)
+
+    assert train_args["check_val_every_n_epoch"] == 1
diff --git a/tests/unit/engine/hpo/test_utils.py b/tests/unit/engine/hpo/test_utils.py
new file mode 100644
index 00000000000..cf9b67c9e5b
--- /dev/null
+++ b/tests/unit/engine/hpo/test_utils.py
@@ -0,0 +1,117 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Tests for HPO utility functions."""
+
+import json
+from unittest.mock import MagicMock
+
+import pytest
+from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+from otx.engine.hpo.utils import (
+    find_trial_file,
+    get_best_hpo_weight,
+    get_callable_args_name,
+    get_hpo_weight_dir,
+    get_metric,
+)
+
+
+@pytest.fixture()
+def trial_id():
+    return "1"
+
+
+@pytest.fixture()
+def trial_file(tmp_path, trial_id):
+    trial_file = tmp_path / "hpo" / "trial" / "0" / f"{trial_id}.json"
+    trial_file.parent.mkdir(parents=True)
+    with trial_file.open("w") as f:
+        json.dump(
+            {
+                "id": trial_id,
+                "rung": 0,
+                "configuration": {"lr": 0.1, "iterations": 10},
+                "score": {"1": 0.1, "2": 0.2, "3": 0.5, "4": 0.4, "5": 0.5},
+            },
+            f,
+        )
+    return trial_file
+
+
+def test_find_trial_file(tmp_path, trial_file, trial_id):
+    assert trial_file == find_trial_file(tmp_path, trial_id)
+
+
+def test_find_trial_file_file_not_exist(tmp_path, trial_file):  # noqa: ARG001
+    assert find_trial_file(tmp_path, "2") is None
+
+
+@pytest.fixture()
+def hpo_weight_dir(tmp_path, trial_id):
+    weight_dir = tmp_path / "weight" / trial_id
+    weight_dir.mkdir(parents=True)
+    return weight_dir
+
+
+def test_get_best_hpo_weight(trial_file, hpo_weight_dir):
+    for eph in range(1, 6):
+        (hpo_weight_dir / f"epoch_{eph}.ckpt").touch()
+
+    assert hpo_weight_dir / "epoch_4.ckpt" == get_best_hpo_weight(hpo_weight_dir, trial_file)
+
+
+def test_get_absent_best_hpo_weight(trial_file, hpo_weight_dir):
+    assert get_best_hpo_weight(hpo_weight_dir, trial_file) is None
+
+
+def test_get_hpo_weight_dir(tmp_path, hpo_weight_dir, trial_id):
+    assert hpo_weight_dir == get_hpo_weight_dir(tmp_path, trial_id)
+
+
+def test_get_absent_hpo_weight_dir(tmp_path, hpo_weight_dir, trial_id):
+    hpo_weight_dir.rmdir()
+    assert hpo_weight_dir == get_hpo_weight_dir(tmp_path, trial_id)
+    assert hpo_weight_dir.exists()
+
+
+def test_get_callable_args_name():
+    def func(arg1, arg2) -> None:  # noqa: ARG001
+        pass
+
+    assert get_callable_args_name(func) == ["arg1", "arg2"]
+
+
+def test_get_callable_args_name_no_args():
+    def func() -> None:
+        pass
+
+    assert get_callable_args_name(func) == []
+
+
+@pytest.fixture()
+def mock_model_ckpt_hook() -> MagicMock:
+    model_ckpt_hook = MagicMock(spec=ModelCheckpoint)
+    model_ckpt_hook.monitor = "val/accuracy"
+    return model_ckpt_hook
+
+
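+# get_metric accepts either a single ModelCheckpoint callback or a list of
+# callbacks; the tests below cover both call styles plus the error paths.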
+def test_get_metric(mock_model_ckpt_hook):
+    assert get_metric(mock_model_ckpt_hook) == "val/accuracy"
+
+
+def test_get_metric_list_callback(mock_model_ckpt_hook):
+    callbacks = [mock_model_ckpt_hook]
+    assert get_metric(callbacks) == "val/accuracy"
+
+
+def test_get_metric_no_model_ckpt_callback():
+    callbacks = [MagicMock()]
+    with pytest.raises(RuntimeError, match="Failed to find a metric"):
+        get_metric(callbacks)
+
+
+def test_get_metric_monitor_value_none(mock_model_ckpt_hook):
+    mock_model_ckpt_hook.monitor = None
+    with pytest.raises(ValueError, match="Failed to find a metric"):
+        get_metric(mock_model_ckpt_hook)