Add unit test for engine/hpo (#3406)
* implement util unit test

* implement draft of test_hpo_api

* implement engine/hpo unit test

* update hpo_api unit test

* update hpo_util unit test

* align with pre-commit

* fix typo

* fix type hint
eunwoosh authored Apr 29, 2024
1 parent ff2cf14 commit 96a7cd5
Showing 6 changed files with 685 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/otx/engine/hpo/hpo_api.py
@@ -106,10 +106,10 @@ def execute_hpo(
     )
 
     best_trial = hpo_algo.get_best_config()
-    if best_trial is None:
-        best_config = None
-        best_hpo_weight = None
-    else:
+
+    best_config = None
+    best_hpo_weight = None
+    if best_trial is not None:
         best_config = best_trial["configuration"]
         if (trial_file := find_trial_file(hpo_workdir, best_trial["id"])) is not None:
             best_hpo_weight = get_best_hpo_weight(get_hpo_weight_dir(hpo_workdir, best_trial["id"]), trial_file)
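
The hunk above flattens the if/else into a "default, then guard" flow: both results start as None and are overwritten only when a best trial exists, so the no-result case needs no branch of its own. A minimal sketch of the pattern with illustrative names (not OTX code):

def summarize_best_trial(best_trial: dict | None) -> tuple[str | None, str | None]:
    best_config = None  # defaults cover the "HPO found nothing" case
    best_hpo_weight = None
    if best_trial is not None:  # overwrite only when a result exists
        best_config = best_trial["configuration"]
        best_hpo_weight = best_trial.get("weight")
    return best_config, best_hpo_weight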
5 changes: 4 additions & 1 deletion src/otx/engine/hpo/utils.py
@@ -114,6 +114,9 @@ def get_metric(callbacks: list[Callback] | Callback) -> str:

     for callback in callbacks:
         if isinstance(callback, ModelCheckpoint):
-            return callback.monitor
+            if (metric := callback.monitor) is None:
+                msg = "Failed to find a metric. 'monitor' value of ModelCheckpoint callback is set to None."
+                raise ValueError(msg)
+            return metric
     msg = "Failed to find a metric. There is no ModelCheckpoint in callback list."
     raise RuntimeError(msg)
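
For context, a small sketch of how the updated get_metric behaves, assuming lightning's ModelCheckpoint import path (the example itself is not part of the diff):

from lightning.pytorch.callbacks import ModelCheckpoint

from otx.engine.hpo.utils import get_metric

ckpt = ModelCheckpoint(monitor="val/accuracy")
assert get_metric([ckpt]) == "val/accuracy"

# a ModelCheckpoint whose monitor is None now raises ValueError
# instead of yielding None as the metric name
try:
    get_metric([ModelCheckpoint()])
except ValueError as err:
    print(err)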
2 changes: 2 additions & 0 deletions tests/unit/engine/hpo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
328 changes: 328 additions & 0 deletions tests/unit/engine/hpo/test_hpo_api.py
@@ -0,0 +1,328 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Tests for HPO API utility functions."""

from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import MagicMock

import pytest
from otx.core.config.hpo import HpoConfig
from otx.core.optimizer.callable import OptimizerCallableSupportHPO
from otx.core.schedulers import LinearWarmupSchedulerCallable, SchedulerCallableSupportHPO
from otx.engine.hpo import hpo_api as target_file
from otx.engine.hpo.hpo_api import (
    HPOConfigurator,
    _adjust_train_args,
    _remove_unused_model_weights,
    _update_hpo_progress,
    execute_hpo,
)

if TYPE_CHECKING:
    from pathlib import Path

HPO_NAME_MAP: dict[str, str] = {
    "lr": "model.optimizer_callable.optimizer_kwargs.lr",
    "bs": "datamodule.config.train_subset.batch_size",
}


@pytest.fixture()
def engine_work_dir(tmp_path: Path) -> Path:
    return tmp_path


@pytest.fixture()
def dataset_size() -> int:
    return 10


@pytest.fixture()
def default_bs() -> int:
    return 8


@pytest.fixture()
def default_lr() -> float:
    return 0.001


@pytest.fixture()
def mock_engine(engine_work_dir: Path, dataset_size: int, default_bs: int, default_lr: float) -> MagicMock:
    engine = MagicMock()
    engine.work_dir = engine_work_dir
    engine.datamodule.subsets = {engine.datamodule.config.train_subset.subset_name: range(dataset_size)}
    engine.datamodule.config.train_subset.batch_size = default_bs
    engine.model.optimizer_callable = MagicMock(spec=OptimizerCallableSupportHPO)
    engine.model.optimizer_callable.lr = default_lr
    engine.model.optimizer_callable.optimizer_kwargs = {"lr": default_lr}
    return engine
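

# Note: MagicMock(spec=...) is what lets mock_engine stand in for real OTX
# classes: a spec'd mock passes isinstance checks against the spec class and
# raises AttributeError for attributes the real class lacks. A self-contained
# sketch of that behavior (the Optimizer class is made up for illustration):
def _sketch_spec_mock_behavior() -> None:
    class Optimizer:
        lr = 0.1

    mock = MagicMock(spec=Optimizer)
    assert isinstance(mock, Optimizer)  # spec'd mocks pass isinstance checks
    mock.lr = 0.001  # attributes defined on the spec work as usual
    try:
        _ = mock.unknown_attr  # attributes missing from the spec raise
    except AttributeError:
        pass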


@pytest.fixture()
def mock_hpo_algo() -> MagicMock:
    hpo_algo = MagicMock()
    hpo_algo.get_best_config.return_value = {"configuration": "best_config", "id": "best_id"}
    return hpo_algo


@pytest.fixture()
def mock_hpo_configurator(mocker, mock_hpo_algo: MagicMock) -> HPOConfigurator:
    hpo_configurator = MagicMock()
    hpo_configurator.get_hpo_algo.return_value = mock_hpo_algo
    mocker.patch.object(target_file, "HPOConfigurator", return_value=hpo_configurator)
    return hpo_configurator


@pytest.fixture()
def mock_run_hpo_loop(mocker) -> MagicMock:
    return mocker.patch.object(target_file, "run_hpo_loop")


@pytest.fixture()
def mock_thread(mocker) -> MagicMock:
    return mocker.patch.object(target_file, "Thread")


@pytest.fixture()
def mock_get_best_hpo_weight(mocker) -> MagicMock:
    return mocker.patch.object(target_file, "get_best_hpo_weight")


@pytest.fixture()
def mock_find_trial_file(mocker) -> MagicMock:
    return mocker.patch.object(target_file, "find_trial_file")


@pytest.fixture()
def hpo_config() -> HpoConfig:
    return HpoConfig(metric_name="val/accuracy")


@pytest.fixture()
def mock_progress_update_callback() -> MagicMock:
    return MagicMock()


def test_execute_hpo(
    mock_engine: MagicMock,
    hpo_config: HpoConfig,
    engine_work_dir: Path,
    mock_run_hpo_loop: MagicMock,
    mock_thread: MagicMock,
    mock_hpo_configurator: HPOConfigurator,  # noqa: ARG001
    mock_hpo_algo: MagicMock,
    mock_get_best_hpo_weight: MagicMock,
    mock_find_trial_file: MagicMock,  # noqa: ARG001
    mock_progress_update_callback: MagicMock,
):
    best_config, best_hpo_weight = execute_hpo(
        engine=mock_engine,
        max_epochs=10,
        hpo_config=hpo_config,
        progress_update_callback=mock_progress_update_callback,
    )

    # check the hpo workdir exists
    assert (engine_work_dir / "hpo").exists()
    # check the case where progress_update_callback exists
    mock_thread.assert_called_once()
    assert mock_thread.call_args.kwargs["target"] == _update_hpo_progress
    assert mock_thread.call_args.kwargs["args"][0] == mock_progress_update_callback
    assert mock_thread.call_args.kwargs["daemon"] is True
    mock_thread.return_value.start.assert_called_once()
    # check run_hpo_loop is called correctly
    mock_run_hpo_loop.assert_called_once()
    assert mock_run_hpo_loop.call_args.args[0] == mock_hpo_algo
    # print_result is called after HPO is done
    mock_hpo_algo.print_result.assert_called_once()
    # best_config and best_hpo_weight are returned correctly
    assert best_config == "best_config"
    assert best_hpo_weight == mock_get_best_hpo_weight.return_value


class TestHPOConfigurator:
    def test_init(self, mock_engine: MagicMock, hpo_config: HpoConfig):
        HPOConfigurator(mock_engine, 10, hpo_config)

    def test_hpo_config(
        self,
        mock_engine: MagicMock,
        hpo_config: HpoConfig,
        dataset_size: int,
        default_lr: float,
        default_bs: int,
    ):
        max_epochs = 10
        hpo_configurator = HPOConfigurator(mock_engine, max_epochs, hpo_config=hpo_config)
        hpo_config = hpo_configurator.hpo_config

        # check the default hpo config is set correctly
        assert hpo_config["save_path"] == str(mock_engine.work_dir / "hpo")
        assert hpo_config["num_full_iterations"] == max_epochs
        assert hpo_config["full_dataset_size"] == dataset_size
        # check search_space is set automatically
        assert hpo_config["search_space"] is not None
        # check the max range of batch size isn't bigger than the dataset size
        assert hpo_config["search_space"][HPO_NAME_MAP["bs"]]["max"] == dataset_size
        # check the current hyper parameters will be tested first
        assert hpo_config["prior_hyper_parameters"] == {HPO_NAME_MAP["lr"]: default_lr, HPO_NAME_MAP["bs"]: default_bs}

    def test_get_default_search_space(self, mock_engine: MagicMock, hpo_config: HpoConfig):
        hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config)
        search_space = hpo_configurator._get_default_search_space()

        for hp_name in HPO_NAME_MAP.values():
            assert hp_name in search_space

    def test_align_lr_bs_name(self, mock_engine: MagicMock, hpo_config: HpoConfig):
        """Check learning rate and batch size names are aligned correctly."""
        search_space = {
            "model.optimizer.lr": {
                "type": "loguniform",
                "min": 0.0001,
                "max": 0.1,
            },
            "data.config.train_subset.batch_size": {
                "type": "quniform",
                "min": 2,
                "max": 512,
                "step": 1,
            },
        }
        hpo_config.search_space = search_space

        hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config)

        for new_name in HPO_NAME_MAP.values():
            assert new_name in hpo_configurator.hpo_config["search_space"]

    def test_align_scheduler_callable_support_hpo_name(self, mock_engine: MagicMock, hpo_config: HpoConfig):
        """Check the scheduler name is aligned correctly when the scheduler is a SchedulerCallableSupportHPO."""
        mock_engine.model.scheduler_callable = MagicMock(spec=SchedulerCallableSupportHPO)
        mock_engine.model.scheduler_callable.factor = 0.001
        mock_engine.model.scheduler_callable.scheduler_kwargs = {"factor": 0.001}
        search_space = {
            "model.scheduler.factor": {
                "type": "loguniform",
                "min": 0.0001,
                "max": 0.1,
            },
        }
        hpo_config.search_space = search_space

        hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config)

        assert "model.scheduler_callable.scheduler_kwargs.factor" in hpo_configurator.hpo_config["search_space"]

    def test_align_linear_warmup_scheduler_callable_name(self, mock_engine: MagicMock, hpo_config: HpoConfig):
        """Check the scheduler name is aligned correctly when the scheduler is a LinearWarmupSchedulerCallable."""
        scheduler_callable = MagicMock(spec=LinearWarmupSchedulerCallable)
        scheduler_callable.num_warmup_steps = 0.001
        main_scheduler_callable = MagicMock()
        main_scheduler_callable.factor = 0.001
        main_scheduler_callable.scheduler_kwargs = {"factor": 0.001}
        scheduler_callable.main_scheduler_callable = main_scheduler_callable
        mock_engine.model.scheduler_callable = scheduler_callable
        search_space = {
            "model.scheduler.num_warmup_steps": {
                "type": "loguniform",
                "min": 0.0001,
                "max": 0.1,
            },
            "model.scheduler.main_scheduler_callable.factor": {
                "type": "loguniform",
                "min": 0.0001,
                "max": 0.1,
            },
        }
        hpo_config.search_space = search_space

        hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config)

        assert "model.scheduler_callable.num_warmup_steps" in hpo_configurator.hpo_config["search_space"]
        assert (
            "model.scheduler_callable.main_scheduler_callable.scheduler_kwargs.factor"
            in hpo_configurator.hpo_config["search_space"]
        )

    def test_remove_wrong_search_space(self, mock_engine: MagicMock, hpo_config: HpoConfig):
        hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config)
        wrong_search_space = {
            "wrong_choice": {
                "type": "choice",
                "choice_list": [],  # choice shouldn't be empty
            },
            "wrong_quniform": {
                "type": "quniform",
                "min": 2,
                "max": 3,  # max should be larger than min + step
                "step": 2,
            },
        }
        hpo_configurator._remove_wrong_search_space(wrong_search_space)
        assert wrong_search_space == {}

    def test_get_hpo_algo(self, mocker, mock_engine: MagicMock, hpo_config: HpoConfig):
        hpo_configurator = HPOConfigurator(mock_engine, 10, hpo_config)
        mock_hyper_band = mocker.patch.object(target_file, "HyperBand")
        hpo_configurator.get_hpo_algo()

        mock_hyper_band.assert_called_once()
        assert mock_hyper_band.call_args.kwargs == hpo_configurator.hpo_config
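

# Note: the wrong_search_space entries in test_remove_wrong_search_space above
# encode two validation rules: a choice space needs a non-empty choice_list,
# and a quniform space needs max larger than min + step. A hedged sketch of
# equivalent checks (illustrative only, not the actual
# _remove_wrong_search_space implementation):
def _sketch_search_space_entry_is_valid(entry: dict) -> bool:
    if entry["type"] == "choice":
        return len(entry["choice_list"]) > 0  # at least one candidate to pick
    if entry["type"] == "quniform":
        return entry["max"] > entry["min"] + entry["step"]  # room for one step
    return True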


def test_update_hpo_progress(mocker, mock_progress_update_callback: MagicMock):
    mock_hpo_algo = MagicMock()
    mock_hpo_algo.is_done.side_effect = [False, False, False, True]
    progress_arr = [0.3, 0.6, 1]
    mock_hpo_algo.get_progress.side_effect = progress_arr
    mocker.patch.object(target_file, "time")

    _update_hpo_progress(mock_progress_update_callback, mock_hpo_algo)

    mock_progress_update_callback.assert_called()
    for i in range(3):
        assert mock_progress_update_callback.call_args_list[i].args[0] == pytest.approx(progress_arr[i] * 100)
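

# Note: test_update_hpo_progress drives the polling loop purely through Mock
# side_effect sequencing: when side_effect is a list, each call returns the
# next item, so is_done yields False three times and then True (a call past
# the end would raise StopIteration). A compact illustration:
def _sketch_side_effect_sequencing() -> None:
    poll = MagicMock()
    poll.side_effect = [False, False, True]
    assert [poll(), poll(), poll()] == [False, False, True]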


def test_adjust_train_args():
    new_train_args = _adjust_train_args(
        {
            "self": "self",
            "run_hpo": "run_hpo",
            "kwargs": {
                "kwargs_1": "kwargs_1",
                "kwargs_2": "kwargs_2",
            },
        },
    )

    assert "self" not in new_train_args
    assert "run_hpo" not in new_train_args
    assert "kwargs" not in new_train_args
    assert "kwargs_1" in new_train_args
    assert "kwargs_2" in new_train_args


@pytest.fixture()
def mock_hpo_workdir(tmp_path: Path) -> Path:
    (tmp_path / "1.ckpt").touch()
    sub_dir = tmp_path / "a"
    sub_dir.mkdir()
    (sub_dir / "2.ckpt").touch()
    return tmp_path


def test_remove_unused_model_weights(mock_hpo_workdir: Path):
    best_weight = mock_hpo_workdir / "3.ckpt"
    best_weight.touch()

    _remove_unused_model_weights(mock_hpo_workdir, best_weight)

    ckpt_files = list(mock_hpo_workdir.rglob("*.ckpt"))
    assert len(ckpt_files) == 1
    assert ckpt_files[0] == best_weight
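

# Note: test_remove_unused_model_weights relies on Path.rglob searching
# recursively, which is why mock_hpo_workdir plants a checkpoint in a
# subdirectory as well. A quick standalone illustration:
def _sketch_rglob_is_recursive(tmp_path: Path) -> None:
    (tmp_path / "a").mkdir()
    (tmp_path / "top.ckpt").touch()
    (tmp_path / "a" / "nested.ckpt").touch()
    assert len(list(tmp_path.rglob("*.ckpt"))) == 2  # walks the whole tree
    assert len(list(tmp_path.glob("*.ckpt"))) == 1  # top level only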