diff --git a/src/evox/algorithms/mo/rvea.py b/src/evox/algorithms/mo/rvea.py
index 840918f3a..8a7600759 100644
--- a/src/evox/algorithms/mo/rvea.py
+++ b/src/evox/algorithms/mo/rvea.py
@@ -121,7 +121,9 @@ def _trace_mating_pool(self):
         no_nan_pop = ~torch.isnan(self.pop).all(dim=1)
         max_idx = torch.sum(no_nan_pop, dtype=torch.int32)
         mating_pool = torch.randint(0, max_idx, (self.pop_size,), device=self.device)
-        pop = self.pop[torch.nonzero(no_nan_pop)[mating_pool].squeeze()]
+        pop_index = torch.where(no_nan_pop, torch.arange(self.pop_size), torch.inf)
+        pop_index = torch.argsort(pop_index, stable=True)
+        pop = self.pop[pop_index[mating_pool].squeeze()]
         return pop
 
     def _update_pop_and_rv(self, survivor: torch.Tensor, survivor_fit: torch.Tensor):
diff --git a/src/evox/core/_vmap_fix.py b/src/evox/core/_vmap_fix.py
index fee962c82..6fc2e356a 100644
--- a/src/evox/core/_vmap_fix.py
+++ b/src/evox/core/_vmap_fix.py
@@ -84,7 +84,7 @@ def wrap_batch_tensor(tensor: torch.Tensor, in_dims: int | Tuple[int, ...]) -> torch.Tensor:
     """
     assert get_level(tensor) <= 0, f"Expect vmap level of tensor to be none, got {get_level(tensor)}"
     if not isinstance(in_dims, Sequence):
-        in_dims = tuple(in_dims)
+        in_dims = (in_dims,)
     for level, dim in enumerate(in_dims, 1):
         tensor = add_batch_dim(tensor, dim, level)
     return tensor
@@ -302,7 +302,7 @@ def _batch_getitem(tensor: torch.Tensor, indices, dim=0):
     if level is None or level <= 0:
         return _original_get_item(tensor, indices)
     # else
-    if isinstance(indices, torch.Tensor) and indices.ndim <= 1:
+    if isinstance(indices, torch.Tensor) and indices.dtype == torch.int64 and indices.ndim <= 1:
         tensor = torch.index_select(tensor, dim, indices)
         if indices.ndim == 0:
             tensor = tensor.__getitem__(*(([slice(None)] * dim) + [0]))
diff --git a/src/evox/metrics/igd.py b/src/evox/metrics/igd.py
index 7c82c3581..58c6ee314 100644
--- a/src/evox/metrics/igd.py
+++ b/src/evox/metrics/igd.py
@@ -16,5 +16,9 @@ def igd(objs: torch.Tensor, pf: torch.Tensor, p: int = 1):
 
     :note: The IGD score is lower when the approximation is closer to the Pareto front.
     """
-    min_dis = torch.cdist(pf, objs).min(dim=1).values
+    nan_idx = torch.any(torch.isnan(objs), dim=1)
+    objs = torch.nan_to_num(objs)
+    dis = torch.cdist(pf, objs)
+    dis = torch.where(nan_idx[None, :], torch.inf, dis)
+    min_dis = dis.min(dim=1).values
     return (min_dis.pow(p).mean()).pow(1 / p)
diff --git a/src/evox/problems/hpo_wrapper.py b/src/evox/problems/hpo_wrapper.py
index 9f5f4eb14..f13f5af48 100644
--- a/src/evox/problems/hpo_wrapper.py
+++ b/src/evox/problems/hpo_wrapper.py
@@ -1,18 +1,39 @@
-from abc import ABC
 from typing import Callable, Dict, Optional
 
 import torch
 from torch import nn
 
-from ..core import Monitor, Mutable, Problem, Workflow, jit, jit_class, use_state, vmap
+from ..core import Monitor, Mutable, Problem, Workflow, jit, jit_class, use_state, vmap, vmap_impl
+from ..core._vmap_fix import unwrap_batch_tensor, wrap_batch_tensor
 from ..core.module import _WrapClassBase
 
 
-class HPOMonitor(Monitor, ABC):
+class HPOMonitor(Monitor):
     """The base class for hyper parameter optimization (HPO) monitors used in `HPOProblem.workflow.monitor`."""
 
-    def __init__(self):
+    def __init__(self, num_repeats: int = 1):
+        """
+        Initialize the HPO monitor.
+
+        :param num_repeats: The number of workflow repeats to be executed in the optimization process. Defaults to 1.
+ """ super().__init__() + self.num_repeats = num_repeats + + def vmap_fitness_unwrap_repeats(self, fitness: torch.Tensor): + """ + Unwrap the `num_repeats` batch dimension from the given fitness tensor. + + :param fitness: The fitness tensor to be unwrapped. + :return: The unwrapped fitness tensor. + """ + assert 1 <= fitness.ndim <= 2 + original_fitness, batch_dims, _ = unwrap_batch_tensor(fitness) + assert batch_dims == (0,) + new_fitness = original_fitness.view(self.num_repeats, -1, *fitness.size()) + pop_size = new_fitness.size(1) + new_fitness = wrap_batch_tensor(new_fitness, (1,)) + return new_fitness, pop_size def tell_fitness(self) -> torch.Tensor: """Get the best fitness found so far in the optimization process that this monitor is monitoring. @@ -22,23 +43,52 @@ def tell_fitness(self) -> torch.Tensor: raise NotImplementedError("`tell_fitness` function is not implemented. It must be overwritten.") +@jit_class class HPOFitnessMonitor(HPOMonitor): """The monitor for hyper parameter optimization (HPO) that records the best fitness found so far in the optimization process.""" - def __init__(self, multi_obj_metric: Optional[Callable] = None): + def __init__( + self, + *, + num_repeats: int = 1, + fit_aggregation: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = torch.mean, + multi_obj_metric: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ): """ Initialize the HPO fitness monitor. - :param multi_obj_metric: The metric function to use for multi-objective optimization, unused in single-objective optimization. - Currently we only support "IGD" or "HV" for multi-objective optimization. Defaults to `None`. + :param num_repeats: The number of workflow repeats to be executed in the optimization process. Defaults to 1. + :param fit_aggregation: The aggregation function to use for fitness aggregation when `num_repeats` > 1. Defaults to `torch.mean`. + :param multi_obj_metric: The metric function to use for multi-objective optimization, unused in single-objective optimization. Defaults to `None`. """ - super().__init__() + super().__init__(num_repeats=num_repeats) assert multi_obj_metric is None or callable(multi_obj_metric), ( f"Expect `multi_obj_metric` to be `None` or callable, got {multi_obj_metric}" ) + self.fit_aggregation = fit_aggregation self.multi_obj_metric = multi_obj_metric self.best_fitness = Mutable(torch.tensor(torch.inf)) + def _fitness_unwrap(self, fitness: torch.Tensor): + return fitness + + @vmap_impl(_fitness_unwrap) + def _vmap_fitness_unwrap(self, fitness: torch.Tensor): + fitness, pop_size = self.vmap_fitness_unwrap_repeats(fitness) + original_best = unwrap_batch_tensor(self.best_fitness)[0] + if original_best.size(0) != pop_size: + self.best_fitness = wrap_batch_tensor(original_best[:pop_size], 0) + return fitness + + def _fitness_wrap(self, best_fitness: torch.Tensor): + return best_fitness + + @vmap_impl(_fitness_wrap) + def _vmap_fitness_wrap(self, best_fitness: torch.Tensor): + original_best = unwrap_batch_tensor(best_fitness)[0] + original_best = original_best.repeat(self.num_repeats) + return wrap_batch_tensor(original_best, 0) + def pre_tell(self, fitness: torch.Tensor): """Update the best fitness value found so far based on the provided fitness tensor and multi-objective metric. @@ -46,13 +96,20 @@ def pre_tell(self, fitness: torch.Tensor): :raises AssertionError: If the dimensionality of the fitness tensor is not 1 or 2. 
""" - assert 1 <= fitness.ndim <= 2 - if fitness.ndim == 1: + fitness = self._fitness_unwrap(fitness) + has_repeat = 1 if self.num_repeats > 1 else 0 + assert 1 <= fitness.ndim - has_repeat <= 2 + if fitness.ndim - has_repeat == 1: # single-objective - self.best_fitness = torch.min(torch.min(fitness), self.best_fitness) + if self.num_repeats > 1: + fitness = self.fit_aggregation(fitness, dim=0) + best_fitness = torch.min(torch.min(fitness), self.best_fitness) else: # multi-objective - self.best_fitness = torch.min(self.multi_obj_metric(fitness), self.best_fitness) + if self.num_repeats > 1: + fitness = self.fit_aggregation(fitness, dim=0) + best_fitness = torch.min(self.multi_obj_metric(fitness), self.best_fitness) + self.best_fitness = self._fitness_wrap(best_fitness) def tell_fitness(self) -> torch.Tensor: """Get the best fitness found so far in the optimization process that this monitor is monitoring. @@ -61,6 +118,11 @@ def tell_fitness(self) -> torch.Tensor: """ return self.best_fitness + @vmap_impl(tell_fitness) + def _vmap_tell_fitness(self) -> torch.Tensor: + actual_best = unwrap_batch_tensor(self.best_fitness)[0].view(self.num_repeats, -1)[0] + return wrap_batch_tensor(actual_best, 0) + @jit_class class HPOProblemWrapper(Problem): @@ -83,7 +145,9 @@ class HPOProblemWrapper(Problem): ``` """ - def __init__(self, iterations: int, num_instances: int, workflow: Workflow, copy_init_state: bool = True): + def __init__( + self, iterations: int, num_instances: int, workflow: Workflow, num_repeats: int = 1, copy_init_state: bool = True + ): """Initialize the HPO problem wrapper. :param iterations: The number of iterations to be executed in the optimization process. @@ -96,6 +160,7 @@ def __init__(self, iterations: int, num_instances: int, workflow: Workflow, copy assert num_instances > 0, f"`num_instances` should be greater than 0, got {num_instances}" self.iterations = iterations self.num_instances = num_instances + self.num_repeats = num_repeats self.copy_init_state = copy_init_state # compile workflow steps assert isinstance(workflow, _WrapClassBase), f"Expect `workflow` to be wrapped by `jit_class`, got {type(workflow)}" @@ -103,6 +168,7 @@ def __init__(self, iterations: int, num_instances: int, workflow: Workflow, copy # check monitor monitor = workflow.get_submodule("monitor") assert isinstance(monitor, HPOMonitor), f"Expect workflow monitor to be `HPOMonitor`, got {type(monitor)}" + monitor.num_repeats = num_repeats monitor_state = monitor.state_dict(keep_vars=True) state_step = use_state(lambda: workflow.step) # get monitor's corresponding keys in init_state @@ -128,9 +194,9 @@ def get_monitor_fitness(x: Dict[str, torch.Tensor]): # JIT workflow step vmap_state_step = vmap(state_step) - init_state = vmap_state_step.init_state(self.num_instances) + init_state = vmap_state_step.init_state(self.num_instances * self.num_repeats) self._workflow_step_: torch.jit.ScriptFunction = jit(vmap_state_step, trace=True, example_inputs=(init_state,)) - self._get_monitor_fitness_ = jit(get_monitor_fitness, trace=True, example_inputs=(init_state,)) + self._get_monitor_fitness_ = jit(vmap(get_monitor_fitness, trace=False), trace=True, example_inputs=(init_state,)) monitor.load_state_dict(monitor_state) # if no init step if type(workflow).init_step == Workflow.init_step: @@ -140,11 +206,24 @@ def get_monitor_fitness(x: Dict[str, torch.Tensor]): # otherwise, JIT workflow init step state_init_step = use_state(lambda: workflow.init_step) vmap_state_init_step = vmap(state_init_step) - self._init_state_ = 
+            self._init_state_ = vmap_state_init_step.init_state(self.num_instances * self.num_repeats)
             self._workflow_init_step_: torch.jit.ScriptFunction = jit(
                 vmap_state_init_step, trace=True, example_inputs=(self._init_state_,)
             )
 
+        # JIT expand hyperparameters
+        def expand_hyper_params(hyper_parameters: Dict[str, torch.Tensor]):
+            if num_repeats > 1:
+                return {k: v.repeat(num_repeats, *([1] * (v.ndim - 1))) for k, v in hyper_parameters.items()}
+            else:
+                return hyper_parameters
+
+        self._expand_hp_ = jit(
+            expand_hyper_params,
+            trace=True,
+            example_inputs=({k: v for k, v in self._init_state_.items() if k in hyper_param_keys},),
+        )
+
     def evaluate(self, hyper_parameters: Dict[str, nn.Parameter]):
         """
         Evaluate the fitness (given by the internal workflow's monitor) of the batch of hyper parameters by running the internal workflow.
@@ -166,6 +245,7 @@ def evaluate(self, hyper_parameters: Dict[str, nn.Parameter]):
                 state[k] = v.clone()
         else:
             state = self._init_state_
+        hyper_parameters = self._expand_hp_(hyper_parameters)
         state.update(hyper_parameters)
         state = self._workflow_init_step_(state)
         for _ in range(self.iterations - 1):
diff --git a/src/evox/problems/numerical/__init__.py b/src/evox/problems/numerical/__init__.py
index 420e6dd67..4ee2f8efc 100644
--- a/src/evox/problems/numerical/__init__.py
+++ b/src/evox/problems/numerical/__init__.py
@@ -13,8 +13,14 @@
     "DTLZ5",
     "DTLZ6",
     "DTLZ7",
+    "ZDT1",
+    "ZDT2",
+    "ZDT3",
+    "ZDT4",
+    "ZDT6",
 ]
 
 from .basic import Ackley, Griewank, Rastrigin, Rosenbrock, Schwefel, Sphere
 from .cec2022 import CEC2022
 from .dtlz import DTLZ1, DTLZ2, DTLZ3, DTLZ4, DTLZ5, DTLZ6, DTLZ7
+from .zdt import ZDT1, ZDT2, ZDT3, ZDT4, ZDT6
diff --git a/src/evox/problems/numerical/zdt.py b/src/evox/problems/numerical/zdt.py
new file mode 100644
index 000000000..68cffb487
--- /dev/null
+++ b/src/evox/problems/numerical/zdt.py
@@ -0,0 +1,106 @@
+from functools import partial
+
+import torch
+
+from ...core import Problem, jit_class
+
+
+def _generic_zdt(f1, g, h, x):
+    f1_x = f1(x)
+    g_x = g(x)
+    return torch.stack([f1_x, g_x * h(f1_x, g_x)], dim=1)
+
+
+class ZDTTestSuit(Problem):
+    def __init__(self, n: int, ref_num: int = 100):
+        super().__init__()
+        self.n = n
+        self._zdt = None
+        self.ref_num = ref_num
+
+    def evaluate(self, X: torch.Tensor):
+        return self._zdt(X)
+
+    def pf(self):
+        x = torch.linspace(0, 1, self.ref_num)
+        return torch.stack([x, 1 - torch.sqrt(x)], dim=1)
+
+
+@jit_class
+class ZDT1(ZDTTestSuit):
+    def __init__(self, n):
+        super().__init__(n)
+        f1 = lambda x: x[:, 0]
+        g = lambda x: 1 + 9 * torch.mean(x[:, 1:], dim=1)
+        h = lambda f1, g: 1 - torch.sqrt(f1 / g)
+        self._zdt = partial(_generic_zdt, f1, g, h)
+
+
+@jit_class
+class ZDT2(ZDTTestSuit):
+    def __init__(self, n):
+        super().__init__(n)
+        f1 = lambda x: x[:, 0]
+        g = lambda x: 1 + 9 * torch.mean(x[:, 1:], dim=1)
+        h = lambda f1_val, g_val: 1 - (f1_val / g_val) ** 2
+        self._zdt = partial(_generic_zdt, f1, g, h)
+
+    def pf(self):
+        x = torch.linspace(0, 1, self.ref_num)
+        return torch.stack([x, 1 - x**2], dim=1)
+
+
+@jit_class
+class ZDT3(ZDTTestSuit):
+    def __init__(self, n):
+        super().__init__(n)
+        f1 = lambda x: x[:, 0]
+        g = lambda x: 1 + 9 * torch.mean(x[:, 1:], dim=1)
+        h = lambda f1, g: 1 - torch.sqrt(f1 / g) - (f1 / g) * torch.sin(10 * torch.pi * f1)
+        self._zdt = partial(_generic_zdt, f1, g, h)
+
+    def pf(self):
+        r = torch.tensor(
+            [
+                [0.0000, 0.0830],
+                [0.1822, 0.2577],
+                [0.4093, 0.4538],
+                [0.6183, 0.6525],
+                [0.8233, 0.8518],
+            ]
+        )
+
+        pf_points = []
+        segment_size = self.ref_num // len(r)
+        for row in r:
+            x_vals = torch.linspace(row[0].item(), row[1].item(), segment_size)
+            f2_vals = 1 - torch.sqrt(x_vals) - x_vals * torch.sin(10 * torch.pi * x_vals)
+            pf_points.append(torch.stack([x_vals, f2_vals], dim=1))
+        pf = torch.cat(pf_points, dim=0)
+        return pf
+
+
+@jit_class
+class ZDT4(ZDTTestSuit):
+    def __init__(self, n):
+        super().__init__(n)
+        f1 = lambda x: x[:, 0]
+        g = lambda x: 1 + 10 * (self.n - 1) + torch.sum(x[:, 1:] ** 2 - 10.0 * torch.cos(4.0 * torch.pi * x[:, 1:]), dim=1)
+        h = lambda f1_val, g_val: 1 - torch.sqrt(f1_val / g_val)
+        self._zdt = partial(_generic_zdt, f1, g, h)
+
+
+@jit_class
+class ZDT6(ZDTTestSuit):
+    def __init__(self, n):
+        super().__init__(n)
+        f1 = lambda x: 1 - torch.exp(-4.0 * x[:, 0]) * torch.sin(6.0 * torch.pi * x[:, 0]) ** 6
+        g = lambda x: 1 + 9.0 * (torch.sum(x[:, 1:], dim=1) / 9.0) ** 0.25
+        h = lambda f1_val, g_val: 1 - (f1_val / g_val) ** 2
+        self._zdt = partial(_generic_zdt, f1, g, h)
+
+    def pf(self):
+        min_f1 = 0.280775
+        f1_vals = torch.linspace(min_f1, 1.0, self.ref_num)
+        f2_vals = 1.0 - f1_vals**2
+        return torch.stack([f1_vals, f2_vals], dim=1)
diff --git a/unit_test/problems/test_dtlz.py b/unit_test/problems/test_dtlz.py
index f9ceba7bf..56d5a7195 100644
--- a/unit_test/problems/test_dtlz.py
+++ b/unit_test/problems/test_dtlz.py
@@ -5,7 +5,7 @@
 from evox.problems.numerical import DTLZ1, DTLZ2, DTLZ3, DTLZ4, DTLZ5, DTLZ6, DTLZ7
 
 
-class TestBraxProblem(TestCase):
+class TestDTLZ(TestCase):
     def setUp(self):
         d = 12
         m = 3
@@ -26,3 +26,8 @@ def test_dtlz(self):
             assert fit.size() == (2, 3)
             pf = pro.pf()
             assert pf.size(1) == 3
+
+if __name__ == "__main__":
+    test = TestDTLZ()
+    test.setUp()
+    test.test_dtlz()
diff --git a/unit_test/problems/test_zdt.py b/unit_test/problems/test_zdt.py
new file mode 100644
index 000000000..83f09c3ab
--- /dev/null
+++ b/unit_test/problems/test_zdt.py
@@ -0,0 +1,31 @@
+from unittest import TestCase
+
+import torch
+
+from evox.problems.numerical import ZDT1, ZDT2, ZDT3, ZDT4, ZDT6
+
+
+class TestZDT(TestCase):
+
+    def __init__(self, methodName="runTest"):
+        super().__init__(methodName)
+        self.n = 12
+    def setUp(self):
+        self.pro = [
+            ZDT1(n=self.n),
+            ZDT2(n=self.n),
+            ZDT3(n=self.n),
+            ZDT4(n=self.n),
+            ZDT6(n=self.n),
+        ]
+
+    def test_zdt(self):
+        pop = torch.rand(7, self.n)
+        for pro in self.pro:
+            print(f"pro: {pro}")
+            fit = pro.evaluate(pop)
+            print(f"fit.size(): {fit.size()}")
+            assert fit.size() == (7, 2)
+            pf = pro.pf()
+            print(f"pf.size(): {pf.size()}")
+            assert pf.size(1) == 2
\ No newline at end of file
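
A minimal usage sketch for reviewers (not part of the patch): it pairs one of the new ZDT problems with the NaN-aware IGD metric changed above, mirroring the new `unit_test/problems/test_zdt.py`. It assumes `igd` is re-exported from `evox.metrics`; only calls that appear in this diff (`ZDT1(n=...)`, `evaluate`, `pf`, `igd(objs, pf)`) are used.

```python
import torch

from evox.metrics import igd  # assumed re-export of src/evox/metrics/igd.py
from evox.problems.numerical import ZDT1

# Evaluate a random population on ZDT1 and score it against the analytic Pareto front.
prob = ZDT1(n=12)
pop = torch.rand(32, 12)       # decision variables sampled in [0, 1]
fit = prob.evaluate(pop)       # objective values, shape (32, 2)
score = igd(fit, prob.pf())    # lower is better; NaN rows are masked by the patched metric
print(fit.shape, score)
```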