Skip to content

Commit

Permalink
Added the `key` field to the base `OptimizerState`
Browse files (browse the repository at this point in the history)
  • Loading branch information
sukhijab committed Nov 3, 2023
1 parent eb2e344 commit 23e8f62
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 3 deletions.
1 change: 0 additions & 1 deletion mbpo/optimizers/policy_optimizers/bptt_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ class BPTTState(OptimizerState, Generic[DynamicsParams, RewardParams]):
target_critic_params: Any
state_normalizer_state: NormalizerState
reward_normalizer_state: NormalizerState
key: jax.random.PRNGKeyArray


@chex.dataclass
Expand Down
1 change: 0 additions & 1 deletion mbpo/optimizers/policy_optimizers/brax_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
@chex.dataclass
class BraxState(OptimizerState, Generic[DynamicsParams, RewardParams]):
policy_params: PyTree
key: chex.PRNGKey


@chex.dataclass
Expand Down
1 change: 0 additions & 1 deletion mbpo/optimizers/trajectory_optimizers/icem_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ class iCemParams(NamedTuple):
class iCemOptimizerState(OptimizerState, Generic[DynamicsParams, RewardParams]):
best_sequence: chex.Array
best_reward: chex.Array
key: jax.random.PRNGKey

@property
def action(self):
Expand Down
2 changes: 2 additions & 0 deletions mbpo/utils/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from mbpo.systems.base_systems import SystemParams
from brax.training.replay_buffers import ReplayBufferState
import chex
import jax


@chex.dataclass
class OptimizerState(Generic[DynamicsParams, RewardParams]):
true_buffer_state: ReplayBufferState
system_params: SystemParams[DynamicsParams, RewardParams]
key: jax.random.PRNGKey


@chex.dataclass
Expand Down

0 comments on commit 23e8f62

Please sign in to comment.