Skip to content

Commit

Permalink
Add custom record_step for es variants
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Feb 3, 2025
1 parent 14868c0 commit f37f74d
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/evox/algorithms/es_variants/ars.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ARS(Algorithm):
More information about evosax can be found at the following URL:
GitHub Link: https://github.com/RobertTLange/evosax
"""

def __init__(
self,
pop_size: int,
Expand Down Expand Up @@ -95,3 +96,6 @@ def step(self):
self.lr,
)
self.center = center

def record_step(self):
return {"center": self.center}
8 changes: 8 additions & 0 deletions src/evox/algorithms/es_variants/asebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ASEBO(Algorithm):
More information about evosax can be found at the following URL:
GitHub Link: https://github.com/RobertTLange/evosax
"""

def __init__(
self,
pop_size: int,
Expand Down Expand Up @@ -154,3 +155,10 @@ def step(self):
self.center = center
self.sigma = sigma
self.alpha = alpha

def record_step(self):
return {
"center": self.center,
"sigma": self.sigma,
"alpha": self.alpha,
}
6 changes: 6 additions & 0 deletions src/evox/algorithms/es_variants/cma_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,9 @@ def _decomposition(
D = torch.sqrt(D)
D = B @ D
return B.T, D, C_invsqrt

def record_step(self):
return {
"mean": self.mean,
"sigma": self.sigma,
}
7 changes: 7 additions & 0 deletions src/evox/algorithms/es_variants/des.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class DES(Algorithm):
More information about evosax can be found at the following URL:
GitHub Link: https://github.com/RobertTLange/evosax
"""

def __init__(
self,
pop_size: int,
Expand Down Expand Up @@ -72,3 +73,9 @@ def step(self):

self.center = center
self.sigma = sigma

def record_step(self):
return {
"center": self.center,
"sigma": self.sigma,
}
4 changes: 4 additions & 0 deletions src/evox/algorithms/es_variants/esmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ESMC(Algorithm):
More information about evosax can be found at the following URL:
GitHub Link: https://github.com/RobertTLange/evosax
"""

def __init__(
self,
pop_size: int,
Expand Down Expand Up @@ -106,3 +107,6 @@ def step(self):

sigma = torch.maximum(self.sigma * self.sigma_decay, self.sigma_limit)
self.sigma = sigma

def record_step(self):
return {"center": self.center, "sigma": self.sigma}
4 changes: 4 additions & 0 deletions src/evox/algorithms/es_variants/guided_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class GuidedES(Algorithm):
More information about evosax can be found at the following URL:
GitHub Link: https://github.com/RobertTLange/evosax
"""

def __init__(
self,
pop_size: int,
Expand Down Expand Up @@ -119,3 +120,6 @@ def step(self):

sigma = torch.maximum(self.sigma_decay * self.sigma, self.sigma_limit)
self.sigma = sigma

def record_step(self):
return {"center": self.center, "sigma": self.sigma}
8 changes: 8 additions & 0 deletions src/evox/algorithms/es_variants/nes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class XNES(Algorithm):
Exponential Natural Evolution Strategies
(https://dl.acm.org/doi/abs/10.1145/1830483.1830557)
"""

def __init__(
self,
init_mean: torch.Tensor,
Expand Down Expand Up @@ -114,6 +115,9 @@ def step(self):
self.mean = mean
self.B = B

def record_step(self):
return {"mean": self.mean, "sigma": self.sigma, "B": self.B}


@jit_class
class SeparableNES(Algorithm):
Expand All @@ -123,6 +127,7 @@ class SeparableNES(Algorithm):
Natural Evolution Strategies
(https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf)
"""

def __init__(
self,
init_mean: torch.Tensor,
Expand Down Expand Up @@ -204,3 +209,6 @@ def step(self):

self.mean = mean
self.sigma = sigma

def record_step(self):
return {"mean": self.mean, "sigma": self.sigma}
4 changes: 4 additions & 0 deletions src/evox/algorithms/es_variants/noise_reuse_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class NoiseReuseES(Algorithm):
More information about evosax can be found at the following URL:
GitHub Link: https://github.com/RobertTLange/evosax
"""

def __init__(
self,
pop_size: int,
Expand Down Expand Up @@ -114,3 +115,6 @@ def step(self):

sigma = torch.maximum(self.sigma_decay * self.sigma, self.sigma_limit)
self.sigma = sigma

def record_step(self):
return {"center": self.center, "sigma": self.sigma}
3 changes: 3 additions & 0 deletions src/evox/algorithms/es_variants/open_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,6 @@ def step(self):
self.learning_rate,
)
self.center = center

def record_step(self):
return {"center": self.center}
4 changes: 4 additions & 0 deletions src/evox/algorithms/es_variants/persistent_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PersistentES(Algorithm):
More information about evosax can be found at the following URL:
GitHub Link: https://github.com/RobertTLange/evosax
"""

def __init__(
self,
pop_size: int,
Expand Down Expand Up @@ -109,3 +110,6 @@ def step(self):

self.sigma = sigma
self.pert_accum = pert_accum

def record_step(self):
return {"center": self.center, "sigma": self.sigma}
6 changes: 6 additions & 0 deletions src/evox/algorithms/es_variants/snes.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,9 @@ def step(self):

self.center = center
self.sigma = sigma

def record_step(self):
return {
"center": self.center,
"sigma": self.sigma,
}

0 comments on commit f37f74d

Please sign in to comment.