Commit

Add zdt problems, and modify hpo_wrapper, rvea, and related modules.
XU-Boqing committed Jan 23, 2025
1 parent ab806b6 commit 7ad7da7
Showing 8 changed files with 255 additions and 21 deletions.
4 changes: 3 additions & 1 deletion src/evox/algorithms/mo/rvea.py
@@ -121,7 +121,9 @@ def _trace_mating_pool(self):
no_nan_pop = ~torch.isnan(self.pop).all(dim=1)
max_idx = torch.sum(no_nan_pop, dtype=torch.int32)
mating_pool = torch.randint(0, max_idx, (self.pop_size,), device=self.device)
pop = self.pop[torch.nonzero(no_nan_pop)[mating_pool].squeeze()]
pop_index = torch.where(no_nan_pop, torch.arange(self.pop_size), torch.inf)
pop_index = torch.argsort(pop_index, stable=True)
pop = self.pop[pop_index[mating_pool].squeeze()]
return pop

def _update_pop_and_rv(self, survivor: torch.Tensor, survivor_fit: torch.Tensor):
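
The mating-pool change above replaces the `torch.nonzero`-based gather with a masked `torch.where` plus a stable `torch.argsort`, which keeps NaN-padded rows out of the pool while using only fixed-shape operations (likely helpful under tracing/vmap, where the data-dependent output shape of `nonzero` is problematic). A standalone sketch of the same indexing trick, with made-up tensors rather than EvoX state:

import torch

pop = torch.tensor([[1.0, 2.0], [float("nan"), float("nan")], [3.0, 4.0]])
pop_size = pop.size(0)

no_nan_pop = ~torch.isnan(pop).all(dim=1)                # tensor([True, False, True])
max_idx = torch.sum(no_nan_pop, dtype=torch.int32)       # 2 valid rows
mating_pool = torch.randint(0, int(max_idx), (pop_size,))

# Valid rows keep their own index, NaN rows get +inf; a stable argsort then
# moves all valid indices to the front in their original order.
pop_index = torch.where(no_nan_pop, torch.arange(pop_size, dtype=torch.float32), torch.tensor(torch.inf))
pop_index = torch.argsort(pop_index, stable=True)
parents = pop[pop_index[mating_pool]]
print(parents)
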
4 changes: 2 additions & 2 deletions src/evox/core/_vmap_fix.py
@@ -84,7 +84,7 @@ def wrap_batch_tensor(tensor: torch.Tensor, in_dims: int | Tuple[int, ...]) -> t
"""
assert get_level(tensor) <= 0, f"Expect vmap level of tensor to be none, got {get_level(tensor)}"
if not isinstance(in_dims, Sequence):
in_dims = tuple(in_dims)
in_dims = (in_dims,)
for level, dim in enumerate(in_dims, 1):
tensor = add_batch_dim(tensor, dim, level)
return tensor
@@ -302,7 +302,7 @@ def _batch_getitem(tensor: torch.Tensor, indices, dim=0):
if level is None or level <= 0:
return _original_get_item(tensor, indices)
# else
if isinstance(indices, torch.Tensor) and indices.ndim <= 1:
if isinstance(indices, torch.Tensor) and indices.dtype == torch.int64 and indices.ndim <= 1:
tensor = torch.index_select(tensor, dim, indices)
if indices.ndim == 0:
tensor = tensor.__getitem__(*(([slice(None)] * dim) + [0]))
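
Two small fixes in `_vmap_fix.py`: `tuple(in_dims)` raises `TypeError` when `in_dims` is a plain int (an int is not iterable), so the scalar case is now wrapped as `(in_dims,)`; and the batched `__getitem__` fast path via `index_select` is now taken only for `int64` index tensors, so boolean masks and other index dtypes fall back to the original path. A tiny illustration of the first point (plain Python, no EvoX imports):

in_dims = 3
try:
    dims = tuple(in_dims)   # old code path: TypeError, an int is not iterable
except TypeError:
    dims = (in_dims,)       # new code path: the intended one-element tuple (3,)
print(dims)
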
6 changes: 5 additions & 1 deletion src/evox/metrics/igd.py
@@ -16,5 +16,9 @@ def igd(objs: torch.Tensor, pf: torch.Tensor, p: int = 1):
:note:
The IGD score is lower when the approximation is closer to the Pareto front.
"""
min_dis = torch.cdist(pf, objs).min(dim=1).values
nan_idx = torch.any(torch.isnan(objs), dim=1)
objs = torch.nan_to_num(objs)
dis = torch.cdist(pf, objs)
dis = torch.where(nan_idx[None, :], torch.inf, dis)
min_dis = dis.min(dim=1).values
return (min_dis.pow(p).mean()).pow(1 / p)
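
The updated IGD guards against NaN rows in the objective tensor: NaN individuals are zero-filled before `torch.cdist`, and their columns in the distance matrix are then set to +inf so they can never be counted as the nearest neighbour of a Pareto-front point. A small self-contained sketch of the same logic, with invented values and p assumed to be 1:

import torch

pf = torch.tensor([[0.0, 1.0], [1.0, 0.0]])                         # reference front
objs = torch.tensor([[0.1, 0.9], [float("nan"), float("nan")], [0.8, 0.2]])

nan_idx = torch.any(torch.isnan(objs), dim=1)                       # mask NaN individuals
dis = torch.cdist(pf, torch.nan_to_num(objs))                       # (|pf|, |objs|) distances
dis = torch.where(nan_idx[None, :], torch.tensor(torch.inf), dis)   # NaN columns never win
igd_score = dis.min(dim=1).values.mean()                            # p = 1 case
print(igd_score)
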
112 changes: 96 additions & 16 deletions src/evox/problems/hpo_wrapper.py
@@ -1,18 +1,39 @@
from abc import ABC
from typing import Callable, Dict, Optional

import torch
from torch import nn

from ..core import Monitor, Mutable, Problem, Workflow, jit, jit_class, use_state, vmap
from ..core import Monitor, Mutable, Problem, Workflow, jit, jit_class, use_state, vmap, vmap_impl
from ..core._vmap_fix import unwrap_batch_tensor, wrap_batch_tensor
from ..core.module import _WrapClassBase


class HPOMonitor(Monitor, ABC):
class HPOMonitor(Monitor):
"""The base class for hyper parameter optimization (HPO) monitors used in `HPOProblem.workflow.monitor`."""

def __init__(self):
def __init__(self, num_repeats: int = 1):
"""
Initialize the HPO monitor.
:param num_repeats: The number of workflow repeats to be executed in the optimization process. Defaults to 1.
"""
super().__init__()
self.num_repeats = num_repeats

def vmap_fitness_unwrap_repeats(self, fitness: torch.Tensor):
"""
Unwrap the `num_repeats` batch dimension from the given fitness tensor.
:param fitness: The fitness tensor to be unwrapped.
:return: The unwrapped fitness tensor.
"""
assert 1 <= fitness.ndim <= 2
original_fitness, batch_dims, _ = unwrap_batch_tensor(fitness)
assert batch_dims == (0,)
new_fitness = original_fitness.view(self.num_repeats, -1, *fitness.size())
pop_size = new_fitness.size(1)
new_fitness = wrap_batch_tensor(new_fitness, (1,))
return new_fitness, pop_size

def tell_fitness(self) -> torch.Tensor:
"""Get the best fitness found so far in the optimization process that this monitor is monitoring.
@@ -22,37 +22,73 @@ def tell_fitness(self) -> torch.Tensor:
raise NotImplementedError("`tell_fitness` function is not implemented. It must be overwritten.")


@jit_class
class HPOFitnessMonitor(HPOMonitor):
"""The monitor for hyper parameter optimization (HPO) that records the best fitness found so far in the optimization process."""

def __init__(self, multi_obj_metric: Optional[Callable] = None):
def __init__(
self,
*,
num_repeats: int = 1,
fit_aggregation: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = torch.mean,
multi_obj_metric: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
):
"""
Initialize the HPO fitness monitor.
:param multi_obj_metric: The metric function to use for multi-objective optimization, unused in single-objective optimization.
Currently we only support "IGD" or "HV" for multi-objective optimization. Defaults to `None`.
:param num_repeats: The number of workflow repeats to be executed in the optimization process. Defaults to 1.
:param fit_aggregation: The aggregation function to use for fitness aggregation when `num_repeats` > 1. Defaults to `torch.mean`.
:param multi_obj_metric: The metric function to use for multi-objective optimization, unused in single-objective optimization. Defaults to `None`.
"""
super().__init__()
super().__init__(num_repeats=num_repeats)
assert multi_obj_metric is None or callable(multi_obj_metric), (
f"Expect `multi_obj_metric` to be `None` or callable, got {multi_obj_metric}"
)
self.fit_aggregation = fit_aggregation
self.multi_obj_metric = multi_obj_metric
self.best_fitness = Mutable(torch.tensor(torch.inf))

def _fitness_unwrap(self, fitness: torch.Tensor):
return fitness

@vmap_impl(_fitness_unwrap)
def _vmap_fitness_unwrap(self, fitness: torch.Tensor):
fitness, pop_size = self.vmap_fitness_unwrap_repeats(fitness)
original_best = unwrap_batch_tensor(self.best_fitness)[0]
if original_best.size(0) != pop_size:
self.best_fitness = wrap_batch_tensor(original_best[:pop_size], 0)
return fitness

def _fitness_wrap(self, best_fitness: torch.Tensor):
return best_fitness

@vmap_impl(_fitness_wrap)
def _vmap_fitness_wrap(self, best_fitness: torch.Tensor):
original_best = unwrap_batch_tensor(best_fitness)[0]
original_best = original_best.repeat(self.num_repeats)
return wrap_batch_tensor(original_best, 0)

def pre_tell(self, fitness: torch.Tensor):
"""Update the best fitness value found so far based on the provided fitness tensor and multi-objective metric.
:param fitness: A tensor representing fitness values. It can be either a 1D tensor for single-objective optimization or a 2D tensor for multi-objective optimization.
:raises AssertionError: If the dimensionality of the fitness tensor is not 1 or 2.
"""
assert 1 <= fitness.ndim <= 2
if fitness.ndim == 1:
fitness = self._fitness_unwrap(fitness)
has_repeat = 1 if self.num_repeats > 1 else 0
assert 1 <= fitness.ndim - has_repeat <= 2
if fitness.ndim - has_repeat == 1:
# single-objective
self.best_fitness = torch.min(torch.min(fitness), self.best_fitness)
if self.num_repeats > 1:
fitness = self.fit_aggregation(fitness, dim=0)
best_fitness = torch.min(torch.min(fitness), self.best_fitness)
else:
# multi-objective
self.best_fitness = torch.min(self.multi_obj_metric(fitness), self.best_fitness)
if self.num_repeats > 1:
fitness = self.fit_aggregation(fitness, dim=0)
best_fitness = torch.min(self.multi_obj_metric(fitness), self.best_fitness)
self.best_fitness = self._fitness_wrap(best_fitness)

def tell_fitness(self) -> torch.Tensor:
"""Get the best fitness found so far in the optimization process that this monitor is monitoring.
@@ -61,6 +118,11 @@ def tell_fitness(self) -> torch.Tensor:
"""
return self.best_fitness

@vmap_impl(tell_fitness)
def _vmap_tell_fitness(self) -> torch.Tensor:
actual_best = unwrap_batch_tensor(self.best_fitness)[0].view(self.num_repeats, -1)[0]
return wrap_batch_tensor(actual_best, 0)

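The `_vmap_*` variants above exist because, under `vmap`, the monitor sees one flat batch of size `num_repeats * pop_size`; the repeat dimension is peeled off with a `view`, aggregated, and the result is re-wrapped as a batch of `pop_size`. A shape-only sketch of that fold, with made-up sizes and plain tensors instead of EvoX batch wrappers:

import torch

num_repeats, pop_size = 3, 5
flat_fitness = torch.rand(num_repeats * pop_size)       # what the batched monitor receives

per_repeat = flat_fitness.view(num_repeats, pop_size)   # split out the repeat dimension
aggregated = per_repeat.mean(dim=0)                      # fit_aggregation(fitness, dim=0)
best = torch.min(aggregated)                             # scalar best over the population
print(aggregated.shape, best)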

@jit_class
class HPOProblemWrapper(Problem):
@@ -83,7 +145,9 @@ class HPOProblemWrapper(Problem):
```
"""

def __init__(self, iterations: int, num_instances: int, workflow: Workflow, copy_init_state: bool = True):
def __init__(
self, iterations: int, num_instances: int, workflow: Workflow, num_repeats: int = 1, copy_init_state: bool = True
):
"""Initialize the HPO problem wrapper.
:param iterations: The number of iterations to be executed in the optimization process.
@@ -96,13 +160,15 @@ def __init__(self, iterations: int, num_instances: int, workflow: Workflow, copy
assert num_instances > 0, f"`num_instances` should be greater than 0, got {num_instances}"
self.iterations = iterations
self.num_instances = num_instances
self.num_repeats = num_repeats
self.copy_init_state = copy_init_state
# compile workflow steps
assert isinstance(workflow, _WrapClassBase), f"Expect `workflow` to be wrapped by `jit_class`, got {type(workflow)}"
workflow.__sync__()
# check monitor
monitor = workflow.get_submodule("monitor")
assert isinstance(monitor, HPOMonitor), f"Expect workflow monitor to be `HPOMonitor`, got {type(monitor)}"
monitor.num_repeats = num_repeats
monitor_state = monitor.state_dict(keep_vars=True)
state_step = use_state(lambda: workflow.step)
# get monitor's corresponding keys in init_state
@@ -128,9 +194,9 @@ def get_monitor_fitness(x: Dict[str, torch.Tensor]):

# JIT workflow step
vmap_state_step = vmap(state_step)
init_state = vmap_state_step.init_state(self.num_instances)
init_state = vmap_state_step.init_state(self.num_instances * self.num_repeats)
self._workflow_step_: torch.jit.ScriptFunction = jit(vmap_state_step, trace=True, example_inputs=(init_state,))
self._get_monitor_fitness_ = jit(get_monitor_fitness, trace=True, example_inputs=(init_state,))
self._get_monitor_fitness_ = jit(vmap(get_monitor_fitness, trace=False), trace=True, example_inputs=(init_state,))
monitor.load_state_dict(monitor_state)
# if no init step
if type(workflow).init_step == Workflow.init_step:
@@ -140,11 +206,24 @@ def get_monitor_fitness(x: Dict[str, torch.Tensor]):
# otherwise, JIT workflow init step
state_init_step = use_state(lambda: workflow.init_step)
vmap_state_init_step = vmap(state_init_step)
self._init_state_ = vmap_state_init_step.init_state(self.num_instances)
self._init_state_ = vmap_state_init_step.init_state(self.num_instances * self.num_repeats)
self._workflow_init_step_: torch.jit.ScriptFunction = jit(
vmap_state_init_step, trace=True, example_inputs=(self._init_state_,)
)

# JIT expand hyperparameters
def expand_hyper_params(hyper_parameters: Dict[str, torch.Tensor]):
if num_repeats > 1:
return {k: v.repeat(num_repeats, *([1] * (v.ndim - 1))) for k, v in hyper_parameters.items()}
else:
return hyper_parameters

self._expand_hp_ = jit(
expand_hyper_params,
trace=True,
example_inputs=({k: v for k, v in self._init_state_.items() if k in hyper_param_keys},),
)

def evaluate(self, hyper_parameters: Dict[str, nn.Parameter]):
"""
Evaluate the fitness (given by the internal workflow's monitor) of the batch of hyper parameters by running the internal workflow.
@@ -166,6 +245,7 @@ def evaluate(self, hyper_parameters: Dict[str, nn.Parameter]):
state[k] = v.clone()
else:
state = self._init_state_
hyper_parameters = self._expand_hp_(hyper_parameters)
state.update(hyper_parameters)
state = self._workflow_init_step_(state)
for _ in range(self.iterations - 1):
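
Taken together, the wrapper changes mean each hyper-parameter candidate can be evaluated `num_repeats` times: the vmapped workflow is built with `num_instances * num_repeats` state slots, `expand_hyper_params` tiles every candidate across the repeat slots, and the monitor aggregates the repeated fitness values back down to one per candidate. A rough sketch of that tiling and aggregation with ordinary tensors (shapes assumed; this is not the wrapper's actual API):

import torch

num_instances, num_repeats = 4, 3
hyper_parameters = torch.rand(num_instances, 2)                     # 4 candidate configurations

# expand_hyper_params: tile each candidate so every repeat slot gets a copy
expanded = hyper_parameters.repeat(num_repeats, 1)                  # shape (12, 2)

# pretend each copy produces one noisy fitness value
noisy_fitness = expanded.sum(dim=1) + 0.1 * torch.randn(num_instances * num_repeats)

# the monitor folds the repeats back out and aggregates (mean by default)
aggregated = noisy_fitness.view(num_repeats, num_instances).mean(dim=0)
print(aggregated.shape)                                             # torch.Size([4])
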
6 changes: 6 additions & 0 deletions src/evox/problems/numerical/__init__.py
@@ -13,8 +13,14 @@
"DTLZ5",
"DTLZ6",
"DTLZ7",
"ZDT1",
"ZDT2",
"ZDT3",
"ZDT4",
"ZDT6",
]

from .basic import Ackley, Griewank, Rastrigin, Rosenbrock, Schwefel, Sphere
from .cec2022 import CEC2022
from .dtlz import DTLZ1, DTLZ2, DTLZ3, DTLZ4, DTLZ5, DTLZ6, DTLZ7
from .zdt import ZDT1, ZDT2, ZDT3, ZDT4, ZDT6
106 changes: 106 additions & 0 deletions src/evox/problems/numerical/zdt.py
@@ -0,0 +1,106 @@
from functools import partial

import torch

from ...core import Problem, jit_class


def _generic_zdt(f1, g, h, x):
f1_x = f1(x)
g_x = g(x)
return torch.stack([f1_x, g_x * h(f1_x, g_x)], dim=1)


class ZDTTestSuit(Problem):
def __init__(self, n: int, ref_num: int = 100):
super().__init__()
self.n = n
self._zdt = None
self.ref_num = ref_num

def evaluate(self, X: torch.Tensor):
return self._zdt(X)

def pf(self):
x = torch.linspace(0, 1, self.ref_num)
return torch.stack([x, 1 - torch.sqrt(x)], dim=1)


@jit_class
class ZDT1(ZDTTestSuit):
def __init__(self, n):
super().__init__(n)
f1 = lambda x: x[:,0]
# GitHub Actions / Ruff (E731) flags this and the other lambda assignments in this file (zdt.py lines 33-35, 43-45, 57-59, 87): "Do not assign a `lambda` expression, use a `def`".
g = lambda x: 1 + 9 * torch.mean(x[:, 1:], dim=1)
h = lambda f1, g: 1 - torch.sqrt(f1 / g)
self._zdt = partial(_generic_zdt, f1, g, h)


@jit_class
class ZDT2(ZDTTestSuit):
def __init__(self, n):
super().__init__(n)
f1 = lambda x: x[:,0]
g = lambda x: 1 + 9 * torch.mean(x[:, 1:], dim=1)
h = lambda f1_val, g_val: 1 - (f1_val / g_val) ** 2
self._zdt = partial(_generic_zdt, f1, g, h)

def pf(self):
x = torch.linspace(0, 1, self.ref_num)
return torch.stack([x, 1 - x**2], dim=1)


@jit_class
class ZDT3(ZDTTestSuit):
def __init__(self, n):
super().__init__(n)
f1 = lambda x: x[:,0]
g = lambda x: 1 + 9 * torch.mean(x[:, 1:], dim=1)
h = lambda f1, g: 1 - torch.sqrt(f1 / g) - (f1 / g) * torch.sin(10 * torch.pi * f1)
self._zdt = partial(_generic_zdt, f1, g, h)

def pf(self):
r = torch.tensor(
[
[0.0000, 0.0830],
[0.1822, 0.2577],
[0.4093, 0.4538],
[0.6183, 0.6525],
[0.8233, 0.8518],
]
)

pf_points = []
segment_size = self.ref_num // len(r)
for row in r:
x_vals = torch.linspace(row[0].item(), row[1].item(), segment_size)
f2_vals = 1 - torch.sqrt(x_vals) - x_vals * torch.sin(10 * torch.pi * x_vals)
pf_points.append(torch.stack([x_vals, f2_vals], dim=1))
pf = torch.cat(pf_points, dim=0)
return pf


@jit_class
class ZDT4(ZDTTestSuit):
def __init__(self, n):
super().__init__(n)
f1 = lambda x: x[:,0]
g = lambda x: 1 + 10 * (self.n - 1) + torch.sum(x[:, 1:] ** 2 - 10.0 * torch.cos(4.0 * torch.pi * x[:, 1:]), dim=1)
h = lambda f1_val, g_val: 1 - torch.sqrt(f1_val / g_val)
self._zdt = partial(_generic_zdt, f1, g, h)


@jit_class
class ZDT6(ZDTTestSuit):
def __init__(self, n):
super().__init__(n)
f1 = lambda x: 1 - torch.exp(-4.0 * x[:,0]) * torch.sin(6.0 * torch.pi * x[:,0]) ** 6
g = lambda x: 1 + 9.0 * (torch.sum(x[:, 1:], dim=1) / 9.0) ** 0.25
h = lambda f1_val, g_val: 1 - (f1_val / g_val) ** 2
self._zdt = partial(_generic_zdt, f1, g, h)

def pf(self):
min_f1 = 0.280775
f1_vals = torch.linspace(min_f1, 1.0, self.ref_num)
f2_vals = 1.0 - f1_vals**2
return torch.stack([f1_vals, f2_vals], dim=1)
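
For reference, every problem in this file follows the common ZDT structure F(x) = (f1(x), g(x) * h(f1(x), g(x))) set up by `_generic_zdt`; the classes only differ in those three pieces and in the reference front returned by `pf()`. A hedged usage sketch (import path taken from the `__init__.py` change above; shapes assumed from the code shown):

import torch
from evox.problems.numerical import ZDT1

prob = ZDT1(n=30)            # 30 decision variables
X = torch.rand(8, 30)        # a small random population in [0, 1]
F = prob.evaluate(X)         # (8, 2) objective values
front = prob.pf()            # (100, 2) reference Pareto-front points
print(F.shape, front.shape)
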
7 changes: 6 additions & 1 deletion unit_test/problems/test_dtlz.py
@@ -5,7 +5,7 @@
from evox.problems.numerical import DTLZ1, DTLZ2, DTLZ3, DTLZ4, DTLZ5, DTLZ6, DTLZ7


class TestBraxProblem(TestCase):
class TestDTLZ(TestCase):
def setUp(self):
d = 12
m = 3
@@ -26,3 +26,8 @@ def test_dtlz(self):
assert fit.size() == (2, 3)
pf = pro.pf()
assert pf.size(1) == 3

if __name__ == "__main__":
test = TestDTLZ()
test.setUp()
test.test_dtlz()