diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 9a9a4216..af0ddc06 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -160,7 +160,7 @@ def __init__( # Get the last random state state = self.backend.random_state if state is not None: - self._random.bit_generator.state = state + self.random_state = state # Grab the last step so that we can restart it = self.backend.iteration @@ -225,7 +225,12 @@ def random_state(self): def rng_dict(rng): bg_state = rng.bit_generator.state ss = rng.bit_generator.seed_seq - ss_dict = dict(entropy=ss.entropy, spawn_key=ss.spawn_key, pool_size=ss.pool_size, n_children_spawned=ss.n_children_spawned) + ss_dict = dict( + entropy=ss.entropy, + spawn_key=ss.spawn_key, + pool_size=ss.pool_size, + n_children_spawned=ss.n_children_spawned + ) return dict(bg_state=bg_state, seed_seq=ss_dict) return rng_dict(self._random) # return self._random.bit_generator.state diff --git a/src/emcee/tests/unit/test_backends.py b/src/emcee/tests/unit/test_backends.py index b01f3995..e1ec5484 100644 --- a/src/emcee/tests/unit/test_backends.py +++ b/src/emcee/tests/unit/test_backends.py @@ -124,7 +124,7 @@ def test_backend(backend, dtype, blobs): last2 = sampler2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert last1.random_state == last2.random_state + assert last1.random_state['bg_state'] == last2.random_state['bg_state'] if blobs: _custom_allclose(last1.blobs, last2.blobs) else: @@ -192,7 +192,7 @@ def test_restart(backend, dtype): last2 = sampler2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert last1.random_state == last2.random_state + assert last1.random_state['bg_state'] == last2.random_state['bg_state'] _custom_allclose(last1.blobs, last2.blobs) a = sampler1.acceptance_fraction