From 23e8f624e562dc9d0d364d8f75ab25bb37c99a86 Mon Sep 17 00:00:00 2001 From: sukhijab Date: Fri, 3 Nov 2023 22:00:23 +0100 Subject: [PATCH] added key to baseoptimizer state --- mbpo/optimizers/policy_optimizers/bptt_optimizer.py | 1 - mbpo/optimizers/policy_optimizers/brax_optimizers.py | 1 - mbpo/optimizers/trajectory_optimizers/icem_optimizer.py | 1 - mbpo/utils/type_aliases.py | 2 ++ 4 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mbpo/optimizers/policy_optimizers/bptt_optimizer.py b/mbpo/optimizers/policy_optimizers/bptt_optimizer.py index ff4f0ec..e34d1d4 100644 --- a/mbpo/optimizers/policy_optimizers/bptt_optimizer.py +++ b/mbpo/optimizers/policy_optimizers/bptt_optimizer.py @@ -86,7 +86,6 @@ class BPTTState(OptimizerState, Generic[DynamicsParams, RewardParams]): target_critic_params: Any state_normalizer_state: NormalizerState reward_normalizer_state: NormalizerState - key: jax.random.PRNGKeyArray @chex.dataclass diff --git a/mbpo/optimizers/policy_optimizers/brax_optimizers.py b/mbpo/optimizers/policy_optimizers/brax_optimizers.py index 5df4163..7751147 100644 --- a/mbpo/optimizers/policy_optimizers/brax_optimizers.py +++ b/mbpo/optimizers/policy_optimizers/brax_optimizers.py @@ -21,7 +21,6 @@ @chex.dataclass class BraxState(OptimizerState, Generic[DynamicsParams, RewardParams]): policy_params: PyTree - key: chex.PRNGKey @chex.dataclass diff --git a/mbpo/optimizers/trajectory_optimizers/icem_optimizer.py b/mbpo/optimizers/trajectory_optimizers/icem_optimizer.py index da1a9c1..c3dda14 100755 --- a/mbpo/optimizers/trajectory_optimizers/icem_optimizer.py +++ b/mbpo/optimizers/trajectory_optimizers/icem_optimizer.py @@ -169,7 +169,6 @@ class iCemParams(NamedTuple): class iCemOptimizerState(OptimizerState, Generic[DynamicsParams, RewardParams]): best_sequence: chex.Array best_reward: chex.Array - key: jax.random.PRNGKey @property def action(self): diff --git a/mbpo/utils/type_aliases.py b/mbpo/utils/type_aliases.py index f2ff3ae..02b5dd9 100644 --- a/mbpo/utils/type_aliases.py +++ b/mbpo/utils/type_aliases.py @@ -4,12 +4,14 @@ from mbpo.systems.base_systems import SystemParams from brax.training.replay_buffers import ReplayBufferState import chex +import jax @chex.dataclass class OptimizerState(Generic[DynamicsParams, RewardParams]): true_buffer_state: ReplayBufferState system_params: SystemParams[DynamicsParams, RewardParams] + key: jax.random.PRNGKey @chex.dataclass