From f37f74d332b687856c01a5f311ace590fb7cd075 Mon Sep 17 00:00:00 2001 From: Bill Huang Date: Mon, 3 Feb 2025 15:40:25 +0800 Subject: [PATCH] Add custom record_step for es variants --- src/evox/algorithms/es_variants/ars.py | 4 ++++ src/evox/algorithms/es_variants/asebo.py | 8 ++++++++ src/evox/algorithms/es_variants/cma_es.py | 6 ++++++ src/evox/algorithms/es_variants/des.py | 7 +++++++ src/evox/algorithms/es_variants/esmc.py | 4 ++++ src/evox/algorithms/es_variants/guided_es.py | 4 ++++ src/evox/algorithms/es_variants/nes.py | 8 ++++++++ src/evox/algorithms/es_variants/noise_reuse_es.py | 4 ++++ src/evox/algorithms/es_variants/open_es.py | 3 +++ src/evox/algorithms/es_variants/persistent_es.py | 4 ++++ src/evox/algorithms/es_variants/snes.py | 6 ++++++ 11 files changed, 58 insertions(+) diff --git a/src/evox/algorithms/es_variants/ars.py b/src/evox/algorithms/es_variants/ars.py index 7168be93d..30333426e 100644 --- a/src/evox/algorithms/es_variants/ars.py +++ b/src/evox/algorithms/es_variants/ars.py @@ -17,6 +17,7 @@ class ARS(Algorithm): More information about evosax can be found at the following URL: GitHub Link: https://github.com/RobertTLange/evosax """ + def __init__( self, pop_size: int, @@ -95,3 +96,6 @@ def step(self): self.lr, ) self.center = center + + def record_step(self): + return {"center": self.center} diff --git a/src/evox/algorithms/es_variants/asebo.py b/src/evox/algorithms/es_variants/asebo.py index 52dd90c40..283f06782 100644 --- a/src/evox/algorithms/es_variants/asebo.py +++ b/src/evox/algorithms/es_variants/asebo.py @@ -18,6 +18,7 @@ class ASEBO(Algorithm): More information about evosax can be found at the following URL: GitHub Link: https://github.com/RobertTLange/evosax """ + def __init__( self, pop_size: int, @@ -154,3 +155,10 @@ def step(self): self.center = center self.sigma = sigma self.alpha = alpha + + def record_step(self): + return { + "center": self.center, + "sigma": self.sigma, + "alpha": self.alpha, + } diff --git a/src/evox/algorithms/es_variants/cma_es.py b/src/evox/algorithms/es_variants/cma_es.py index 9ad036914..1e40e171e 100644 --- a/src/evox/algorithms/es_variants/cma_es.py +++ b/src/evox/algorithms/es_variants/cma_es.py @@ -182,3 +182,9 @@ def _decomposition( D = torch.sqrt(D) D = B @ D return B.T, D, C_invsqrt + + def record_step(self): + return { + "mean": self.mean, + "sigma": self.sigma, + } diff --git a/src/evox/algorithms/es_variants/des.py b/src/evox/algorithms/es_variants/des.py index 4afda1ece..d8f2c72e2 100644 --- a/src/evox/algorithms/es_variants/des.py +++ b/src/evox/algorithms/es_variants/des.py @@ -16,6 +16,7 @@ class DES(Algorithm): More information about evosax can be found at the following URL: GitHub Link: https://github.com/RobertTLange/evosax """ + def __init__( self, pop_size: int, @@ -72,3 +73,9 @@ def step(self): self.center = center self.sigma = sigma + + def record_step(self): + return { + "center": self.center, + "sigma": self.sigma, + } diff --git a/src/evox/algorithms/es_variants/esmc.py b/src/evox/algorithms/es_variants/esmc.py index 581986f37..412fd2592 100644 --- a/src/evox/algorithms/es_variants/esmc.py +++ b/src/evox/algorithms/es_variants/esmc.py @@ -18,6 +18,7 @@ class ESMC(Algorithm): More information about evosax can be found at the following URL: GitHub Link: https://github.com/RobertTLange/evosax """ + def __init__( self, pop_size: int, @@ -106,3 +107,6 @@ def step(self): sigma = torch.maximum(self.sigma * self.sigma_decay, self.sigma_limit) self.sigma = sigma + + def record_step(self): + return {"center": self.center, "sigma": self.sigma} diff --git a/src/evox/algorithms/es_variants/guided_es.py b/src/evox/algorithms/es_variants/guided_es.py index 6bb02e427..4e735495c 100644 --- a/src/evox/algorithms/es_variants/guided_es.py +++ b/src/evox/algorithms/es_variants/guided_es.py @@ -18,6 +18,7 @@ class GuidedES(Algorithm): More information about evosax can be found at the following URL: GitHub Link: https://github.com/RobertTLange/evosax """ + def __init__( self, pop_size: int, @@ -119,3 +120,6 @@ def step(self): sigma = torch.maximum(self.sigma_decay * self.sigma, self.sigma_limit) self.sigma = sigma + + def record_step(self): + return {"center": self.center, "sigma": self.sigma} diff --git a/src/evox/algorithms/es_variants/nes.py b/src/evox/algorithms/es_variants/nes.py index b8285ee0c..8b3be4b84 100644 --- a/src/evox/algorithms/es_variants/nes.py +++ b/src/evox/algorithms/es_variants/nes.py @@ -13,6 +13,7 @@ class XNES(Algorithm): Exponential Natural Evolution Strategies (https://dl.acm.org/doi/abs/10.1145/1830483.1830557) """ + def __init__( self, init_mean: torch.Tensor, @@ -114,6 +115,9 @@ def step(self): self.mean = mean self.B = B + def record_step(self): + return {"mean": self.mean, "sigma": self.sigma, "B": self.B} + @jit_class class SeparableNES(Algorithm): @@ -123,6 +127,7 @@ class SeparableNES(Algorithm): Natural Evolution Strategies (https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) """ + def __init__( self, init_mean: torch.Tensor, @@ -204,3 +209,6 @@ def step(self): self.mean = mean self.sigma = sigma + + def record_step(self): + return {"mean": self.mean, "sigma": self.sigma} diff --git a/src/evox/algorithms/es_variants/noise_reuse_es.py b/src/evox/algorithms/es_variants/noise_reuse_es.py index ab6b8ad78..7ec592164 100644 --- a/src/evox/algorithms/es_variants/noise_reuse_es.py +++ b/src/evox/algorithms/es_variants/noise_reuse_es.py @@ -18,6 +18,7 @@ class NoiseReuseES(Algorithm): More information about evosax can be found at the following URL: GitHub Link: https://github.com/RobertTLange/evosax """ + def __init__( self, pop_size: int, @@ -114,3 +115,6 @@ def step(self): sigma = torch.maximum(self.sigma_decay * self.sigma, self.sigma_limit) self.sigma = sigma + + def record_step(self): + return {"center": self.center, "sigma": self.sigma} diff --git a/src/evox/algorithms/es_variants/open_es.py b/src/evox/algorithms/es_variants/open_es.py index acc6a3e02..3baa6ebf5 100644 --- a/src/evox/algorithms/es_variants/open_es.py +++ b/src/evox/algorithms/es_variants/open_es.py @@ -81,3 +81,6 @@ def step(self): self.learning_rate, ) self.center = center + + def record_step(self): + return {"center": self.center} diff --git a/src/evox/algorithms/es_variants/persistent_es.py b/src/evox/algorithms/es_variants/persistent_es.py index 10f09777b..e213bda23 100644 --- a/src/evox/algorithms/es_variants/persistent_es.py +++ b/src/evox/algorithms/es_variants/persistent_es.py @@ -18,6 +18,7 @@ class PersistentES(Algorithm): More information about evosax can be found at the following URL: GitHub Link: https://github.com/RobertTLange/evosax """ + def __init__( self, pop_size: int, @@ -109,3 +110,6 @@ def step(self): self.sigma = sigma self.pert_accum = pert_accum + + def record_step(self): + return {"center": self.center, "sigma": self.sigma} diff --git a/src/evox/algorithms/es_variants/snes.py b/src/evox/algorithms/es_variants/snes.py index 380576c1d..c72af3cea 100644 --- a/src/evox/algorithms/es_variants/snes.py +++ b/src/evox/algorithms/es_variants/snes.py @@ -92,3 +92,9 @@ def step(self): self.center = center self.sigma = sigma + + def record_step(self): + return { + "center": self.center, + "sigma": self.sigma, + }