diff --git a/.gitignore b/.gitignore index 5c443473f..13eb8b573 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,4 @@ cython_debug/ # Test tests **/_future/* +data/MNIST \ No newline at end of file diff --git a/src/evox/problems/neuroevolution/__init__.py b/src/evox/problems/neuroevolution/__init__.py index 1429f6f79..e69de29bb 100644 --- a/src/evox/problems/neuroevolution/__init__.py +++ b/src/evox/problems/neuroevolution/__init__.py @@ -1,4 +0,0 @@ -__all__ = ["BraxProblem", "SupervisedLearningProblem"] - -from .brax import BraxProblem -from .supervised_learning import SupervisedLearningProblem diff --git a/src/evox/problems/neuroevolution/brax.py b/src/evox/problems/neuroevolution/brax.py index 67210e8c0..9ae4fba2b 100644 --- a/src/evox/problems/neuroevolution/brax.py +++ b/src/evox/problems/neuroevolution/brax.py @@ -1,13 +1,26 @@ +__all__ = ["BraxProblem"] + from typing import Callable, Dict, Tuple import jax import jax.numpy as jnp import torch import torch.nn as nn +import torch.utils.dlpack from brax import envs +from torch.utils.dlpack import from_dlpack, to_dlpack from ...core import Problem, jit_class -from .utils import from_jax_array, get_vmap_model_state_forward, to_jax_array +from .utils import get_vmap_model_state_forward + + +def to_jax_array(x: torch.Tensor) -> jax.Array: + return jax.dlpack.from_dlpack(to_dlpack(x.detach())) + + +def from_jax_array(x: jax.Array) -> torch.Tensor: + return from_dlpack(jax.dlpack.to_dlpack(x, take_ownership=True)) + __brax_data__: Dict[ int, @@ -113,7 +126,6 @@ def __init__( def __del__(self): global __brax_data__ __brax_data__.pop(self._hash_id_, None) - super().__del__() def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor: """Evaluate the final rewards of a population (batch) of model parameters. diff --git a/src/evox/problems/neuroevolution/supervised_learning.py b/src/evox/problems/neuroevolution/supervised_learning.py index eb7459dfe..1b8ca9f0a 100644 --- a/src/evox/problems/neuroevolution/supervised_learning.py +++ b/src/evox/problems/neuroevolution/supervised_learning.py @@ -1,3 +1,5 @@ +__all__ = ["SupervisedLearningProblem"] + from typing import Dict, Iterable, Iterator, Tuple import torch @@ -122,7 +124,6 @@ def __init__( def __del__(self): global __supervised_data__ __supervised_data__.pop(self._hash_id_, None) - super().__del__() @torch.jit.ignore def _data_loader_reset(self) -> None: diff --git a/src/evox/problems/neuroevolution/utils.py b/src/evox/problems/neuroevolution/utils.py index 2e81db5a0..1c552775e 100644 --- a/src/evox/problems/neuroevolution/utils.py +++ b/src/evox/problems/neuroevolution/utils.py @@ -2,11 +2,8 @@ import types from typing import Any, Callable, Dict -import jax import torch import torch.nn as nn -import torch.utils.dlpack -from torch.utils.dlpack import from_dlpack, to_dlpack from ...core import jit, use_state, vmap from ...core.module import assign_load_state_dict @@ -87,11 +84,3 @@ def get_vmap_model_state_forward( param_to_state_key_map, model_buffers, ) - - -def to_jax_array(x: torch.Tensor) -> jax.Array: - return jax.dlpack.from_dlpack(to_dlpack(x.detach())) - - -def from_jax_array(x: jax.Array) -> torch.Tensor: - return from_dlpack(jax.dlpack.to_dlpack(x, take_ownership=True)) diff --git a/unit_test/core/test_vmap_fix.py b/unit_test/core/test_vmap_fix.py index 19e89d569..1410e02cd 100644 --- a/unit_test/core/test_vmap_fix.py +++ b/unit_test/core/test_vmap_fix.py @@ -51,7 +51,3 @@ def test_distance_fn_with_mask(self): def test_distance_fn_without_mask(self): distances = jit(distance_fn, trace=True, lazy=False, example_inputs=(self.costs,)) self.assertIsNotNone(distances(self.costs)) - - def test_distance_fn_with_none(self): - distances = jit(distance_fn, trace=True, lazy=False, example_inputs=(self.costs, None)) - self.assertIsNotNone(distances(self.costs, None)) diff --git a/unit_test/problems/test_brax.py b/unit_test/problems/test_brax.py index 468f79e35..fcd11f10c 100644 --- a/unit_test/problems/test_brax.py +++ b/unit_test/problems/test_brax.py @@ -5,7 +5,7 @@ import torch.nn as nn from evox.algorithms import PSO -from evox.problems.neuroevolution import BraxProblem +from evox.problems.neuroevolution.brax import BraxProblem from evox.utils import ParamsAndVector from evox.workflows import EvalMonitor, StdWorkflow diff --git a/unit_test/problems/test_supervised_learning.py b/unit_test/problems/test_supervised_learning.py index 03bdf38cf..40877ee29 100644 --- a/unit_test/problems/test_supervised_learning.py +++ b/unit_test/problems/test_supervised_learning.py @@ -9,7 +9,7 @@ from evox.algorithms import PSO from evox.core import Algorithm, Parameter, jit_class -from evox.problems.neuroevolution import SupervisedLearningProblem +from evox.problems.neuroevolution.supervised_learning import SupervisedLearningProblem from evox.utils import ParamsAndVector from evox.workflows import EvalMonitor, StdWorkflow diff --git a/unit_test/utils/test_jit_fix.py b/unit_test/utils/test_jit_fix.py index 4171942dc..f24808e15 100644 --- a/unit_test/utils/test_jit_fix.py +++ b/unit_test/utils/test_jit_fix.py @@ -18,6 +18,6 @@ def test_basic_switch(self): def test_vmap_switch(self): x = torch.randint(low=0, high=10, size=(2, 10), dtype=torch.int) y = [torch.rand(2, 10) for _ in range(10)] - vmap_switch = jit(switch, trace=False) + vmap_switch = jit(switch, trace=False, lazy=True) z = vmap_switch(x, y) self.assertIsNotNone(z)