Skip to content

Commit

Permalink
Change code using fixed subset name in HPO (#3101)
Browse files Browse the repository at this point in the history
* change code using fixed subset name in hpo

* apply comment

* remove non_pure_train_ratio from HPO

* raise error once HPO trial exits abnormally

* update unit test

* align with pre-commit
  • Loading branch information
eunwoosh authored Mar 22, 2024
1 parent a50ac7f commit 77b470c
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 19 deletions.
4 changes: 1 addition & 3 deletions src/otx/engine/hpo/hpo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,12 @@ def hpo_config(self) -> dict[str, Any]:

@hpo_config.setter
def hpo_config(self, hpo_config: HpoConfig | None) -> None:
train_dataset_size = len(self._engine.datamodule.subsets["train"])
val_dataset_size = len(self._engine.datamodule.subsets["val"])
train_dataset_size = len(self._engine.datamodule.train_dataloader())

self._hpo_config: dict[str, Any] = { # default setting
"save_path": str(self._hpo_workdir),
"num_full_iterations": self._max_epoch,
"full_dataset_size": train_dataset_size,
"non_pure_train_ratio": val_dataset_size / (train_dataset_size + val_dataset_size),
}

if hpo_config is not None:
Expand Down
9 changes: 0 additions & 9 deletions src/otx/hpo/hpo_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class HpoBase(ABC):
num_trials (int | None, optional): How many training to conduct for HPO.
num_workers (int, optional): How many trains are executed in parallel.
num_full_iterations (int, optional): epoch for training after HPO.
non_pure_train_ratio (float, optional): ratio of validation time to (train time + validation time)
full_dataset_size (int, optional): train dataset size
expected_time_ratio (int | float | None, optional): Time to use for HPO.
If HPO is configured automatically,
Expand Down Expand Up @@ -64,7 +63,6 @@ def __init__(
num_trials: int | None = None,
num_workers: int = 1,
num_full_iterations: int | float = 1,
non_pure_train_ratio: float = 0.2,
full_dataset_size: int = 0,
expected_time_ratio: int | float | None = None,
maximum_resource: int | float | None = None,
Expand All @@ -78,12 +76,6 @@ def __init__(
check_mode_input(mode)
check_positive(full_dataset_size, "full_dataset_size")
check_positive(num_full_iterations, "num_full_iterations")
if not 0 < non_pure_train_ratio <= 1:
error_msg = (
"non_pure_train_ratio should be greater than 0 and lesser than or equal to 1."
f"Your value is {subset_ratio}"
)
raise ValueError(error_msg)
if maximum_resource is not None:
check_positive(maximum_resource, "maximum_resource")
if num_trials is not None:
Expand All @@ -103,7 +95,6 @@ def __init__(
self.num_trials = num_trials
self.num_workers = num_workers
self.num_full_iterations = num_full_iterations
self.non_pure_train_ratio = non_pure_train_ratio
self.full_dataset_size = full_dataset_size
self.expected_time_ratio = expected_time_ratio
self.maximum_resource: int | float | None = maximum_resource
Expand Down
10 changes: 4 additions & 6 deletions src/otx/hpo/hpo_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(
self._mp = multiprocessing.get_context("spawn")
self._report_queue = self._mp.Queue()
self._uid_index = 0
self._trial_fault_count = 0
self._resource_manager = get_resource_manager(
resource_type,
num_parallel_trial,
Expand All @@ -83,7 +82,7 @@ def run(self) -> None:
"""Run a HPO loop."""
logger.info("HPO loop starts.")
try:
while not self._hpo_algo.is_done() and self._trial_fault_count < 3:
while not self._hpo_algo.is_done():
if self._resource_manager.have_available_resource():
trial = self._hpo_algo.get_next_sample()
if trial is not None:
Expand All @@ -98,9 +97,6 @@ def run(self) -> None:
raise e # noqa: TRY201
logger.info("HPO loop is done.")

if self._trial_fault_count >= 3:
logger.warning("HPO trials exited abnormally more than three times. HPO is suspended.")

self._get_reports()
self._join_all_processes()

Expand Down Expand Up @@ -143,7 +139,9 @@ def _remove_finished_process(self) -> None:
for uid, trial in self._running_trials.items():
if not trial.process.is_alive():
if trial.process.exitcode != 0:
self._trial_fault_count += 1
self._terminate_all_running_processes()
msg = "One of HPO trials exit abnormally."
raise RuntimeError(msg)
trial.queue.close()
trial.process.join()
trial_to_remove.append(uid)
Expand Down
1 change: 0 additions & 1 deletion tests/unit/hpo/test_hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def good_hyperband_args():
"mode": "max",
"num_workers": 1,
"num_full_iterations": 64,
"non_pure_train_ratio": 0.2,
"full_dataset_size": 100,
"maximum_resource": 64,
"minimum_resource": 1,
Expand Down

0 comments on commit 77b470c

Please sign in to comment.