Skip to content

Commit

Permalink
Fix hebo
Browse files Browse the repository at this point in the history
  • Loading branch information
y0z committed Jun 21, 2024
1 parent 335c0a6 commit 39080b4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
2 changes: 2 additions & 0 deletions package/samplers/hebo/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
numpy
optuna
optunahub
pandas
hebo@git+https://github.com/huawei-noah/HEBO.git#subdirectory=HEBO
17 changes: 16 additions & 1 deletion package/samplers/hebo/sampler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from __future__ import annotations

from typing import Optional
from typing import Sequence

from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalDistribution
from optuna.distributions import FloatDistribution
from optuna.distributions import IntDistribution
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import FrozenTrial, TrialState
import optunahub

import numpy as np
import pandas as pd

from hebo.design_space.design_space import DesignSpace
from hebo.optimizers.hebo import HEBO

Expand All @@ -30,6 +36,15 @@ def sample_relative(
params[name] = params_pd[name].to_numpy()[0]
return params

def after_trial(
self,
study: Study,
trial: FrozenTrial,
state: TrialState,
values: Optional[Sequence[float]],
) -> None:
self._hebo.observe(pd.DataFrame([trial.params]), np.asarray([values]))

def _convert_to_hebo_design_space(
self, search_space: dict[str, BaseDistribution]
) -> DesignSpace:
Expand Down

0 comments on commit 39080b4

Please sign in to comment.