Skip to content

Commit

Permalink
Fix some test fail due to incorrect usage of JIT; change Brax import …
Browse files Browse the repository at this point in the history
…to prevent unnecessary import JAX
  • Loading branch information
sses7757 committed Jan 9, 2025
1 parent 2c354ef commit 5e7ad66
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@ cython_debug/
# Test
tests
**/_future/*
data/MNIST
4 changes: 0 additions & 4 deletions src/evox/problems/neuroevolution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +0,0 @@
__all__ = ["BraxProblem", "SupervisedLearningProblem"]

from .brax import BraxProblem
from .supervised_learning import SupervisedLearningProblem
16 changes: 14 additions & 2 deletions src/evox/problems/neuroevolution/brax.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
__all__ = ["BraxProblem"]

from typing import Callable, Dict, Tuple

import jax
import jax.numpy as jnp
import torch
import torch.nn as nn
import torch.utils.dlpack
from brax import envs
from torch.utils.dlpack import from_dlpack, to_dlpack

from ...core import Problem, jit_class
from .utils import from_jax_array, get_vmap_model_state_forward, to_jax_array
from .utils import get_vmap_model_state_forward


def to_jax_array(x: torch.Tensor) -> jax.Array:
return jax.dlpack.from_dlpack(to_dlpack(x.detach()))


def from_jax_array(x: jax.Array) -> torch.Tensor:
return from_dlpack(jax.dlpack.to_dlpack(x, take_ownership=True))


__brax_data__: Dict[
int,
Expand Down Expand Up @@ -113,7 +126,6 @@ def __init__(
def __del__(self):
global __brax_data__
__brax_data__.pop(self._hash_id_, None)
super().__del__()

def evaluate(self, pop_params: Dict[str, nn.Parameter]) -> torch.Tensor:
"""Evaluate the final rewards of a population (batch) of model parameters.
Expand Down
3 changes: 2 additions & 1 deletion src/evox/problems/neuroevolution/supervised_learning.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["SupervisedLearningProblem"]

from typing import Dict, Iterable, Iterator, Tuple

import torch
Expand Down Expand Up @@ -122,7 +124,6 @@ def __init__(
def __del__(self):
global __supervised_data__
__supervised_data__.pop(self._hash_id_, None)
super().__del__()

@torch.jit.ignore
def _data_loader_reset(self) -> None:
Expand Down
11 changes: 0 additions & 11 deletions src/evox/problems/neuroevolution/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
import types
from typing import Any, Callable, Dict

import jax
import torch
import torch.nn as nn
import torch.utils.dlpack
from torch.utils.dlpack import from_dlpack, to_dlpack

from ...core import jit, use_state, vmap
from ...core.module import assign_load_state_dict
Expand Down Expand Up @@ -87,11 +84,3 @@ def get_vmap_model_state_forward(
param_to_state_key_map,
model_buffers,
)


def to_jax_array(x: torch.Tensor) -> jax.Array:
return jax.dlpack.from_dlpack(to_dlpack(x.detach()))


def from_jax_array(x: jax.Array) -> torch.Tensor:
return from_dlpack(jax.dlpack.to_dlpack(x, take_ownership=True))
4 changes: 0 additions & 4 deletions unit_test/core/test_vmap_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,3 @@ def test_distance_fn_with_mask(self):
def test_distance_fn_without_mask(self):
distances = jit(distance_fn, trace=True, lazy=False, example_inputs=(self.costs,))
self.assertIsNotNone(distances(self.costs))

def test_distance_fn_with_none(self):
distances = jit(distance_fn, trace=True, lazy=False, example_inputs=(self.costs, None))
self.assertIsNotNone(distances(self.costs, None))
2 changes: 1 addition & 1 deletion unit_test/problems/test_brax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn as nn

from evox.algorithms import PSO
from evox.problems.neuroevolution import BraxProblem
from evox.problems.neuroevolution.brax import BraxProblem
from evox.utils import ParamsAndVector
from evox.workflows import EvalMonitor, StdWorkflow

Expand Down
2 changes: 1 addition & 1 deletion unit_test/problems/test_supervised_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from evox.algorithms import PSO
from evox.core import Algorithm, Parameter, jit_class
from evox.problems.neuroevolution import SupervisedLearningProblem
from evox.problems.neuroevolution.supervised_learning import SupervisedLearningProblem
from evox.utils import ParamsAndVector
from evox.workflows import EvalMonitor, StdWorkflow

Expand Down
2 changes: 1 addition & 1 deletion unit_test/utils/test_jit_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ def test_basic_switch(self):
def test_vmap_switch(self):
x = torch.randint(low=0, high=10, size=(2, 10), dtype=torch.int)
y = [torch.rand(2, 10) for _ in range(10)]
vmap_switch = jit(switch, trace=False)
vmap_switch = jit(switch, trace=False, lazy=True)
z = vmap_switch(x, y)
self.assertIsNotNone(z)

0 comments on commit 5e7ad66

Please sign in to comment.