From e2536a9a9fb49028770c419185dd4dcd9d8c6c90 Mon Sep 17 00:00:00 2001 From: Andrew Nelson Date: Mon, 28 Oct 2024 17:35:37 +1100 Subject: [PATCH 1/4] ENH: use Generator instead of RandomState --- src/emcee/backends/hdf.py | 25 +++++++++++------ src/emcee/ensemble.py | 39 +++++++++++++++++---------- src/emcee/moves/de.py | 2 +- src/emcee/moves/de_snooker.py | 2 +- src/emcee/moves/gaussian.py | 6 ++--- src/emcee/moves/mh.py | 2 +- src/emcee/moves/red_blue.py | 2 +- src/emcee/moves/stretch.py | 4 +-- src/emcee/tests/unit/test_backends.py | 35 +++++++----------------- src/emcee/tests/unit/test_sampler.py | 6 ++--- src/emcee/tests/unit/test_stretch.py | 6 ++--- 11 files changed, 67 insertions(+), 62 deletions(-) diff --git a/src/emcee/backends/hdf.py b/src/emcee/backends/hdf.py index 90bc2ac2..1c575161 100644 --- a/src/emcee/backends/hdf.py +++ b/src/emcee/backends/hdf.py @@ -4,6 +4,7 @@ import os from tempfile import NamedTemporaryFile +import json import numpy as np @@ -19,6 +20,13 @@ h5py = None +class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + def does_hdf5_support_longdouble(): if h5py is None: return False @@ -193,12 +201,11 @@ def accepted(self): @property def random_state(self): with self.open() as f: - elements = [ - v - for k, v in sorted(f[self.name].attrs.items()) - if k.startswith("random_state_") - ] - return elements if len(elements) else None + try: + dct = json.loads(f[self.name].attrs['random_state']) + except KeyError: + return None + return dct def grow(self, ngrow, blobs): """Expand the storage space by some number of samples @@ -261,8 +268,10 @@ def save_step(self, state, accepted): g["blobs"][iteration, :] = state.blobs g["accepted"][:] += accepted - for i, v in enumerate(state.random_state): - g.attrs["random_state_{0}".format(i)] = v + g.attrs["random_state"] = json.dumps( + state.random_state, + cls=NumpyEncoder + ) g.attrs["iteration"] = iteration + 1 diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index c71f6ee5..160cd783 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -89,6 +89,7 @@ def __init__( vectorize=False, blobs_dtype=None, parameter_names: Optional[Union[Dict[str, int], List[str]]] = None, + rng = None, # Deprecated... a=None, postargs=None, @@ -136,11 +137,14 @@ def __init__( self.nwalkers = nwalkers self.backend = Backend() if backend is None else backend + # This is a random number generator that we can easily set the state + # of + self._random = np.random.default_rng(rng) + # Deal with re-used backends if not self.backend.initialized: self._previous_state = None self.reset() - state = np.random.get_state() else: # Check the backend shape if self.backend.shape != (self.nwalkers, self.ndim): @@ -153,19 +157,14 @@ def __init__( # Get the last random state state = self.backend.random_state - if state is None: - state = np.random.get_state() + if state is not None: + self._random.bit_generator.state = state # Grab the last step so that we can restart it = self.backend.iteration if it > 0: self._previous_state = self.get_last_sample() - # This is a random number generator that we can easily set the state - # of without affecting the numpy-wide generator - self._random = np.random.mtrand.RandomState() - self._random.set_state(state) - # Do a little bit of _magic_ to make the likelihood call with # ``args`` and ``kwargs`` pickleable. 
self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs) @@ -216,14 +215,18 @@ def __init__( @property def random_state(self): """ - The state of the internal random number generator. In practice, it's - the result of calling ``get_state()`` on a - ``numpy.random.mtrand.RandomState`` object. You can try to set this + The state of the internal random number generator. You can try to set this property but be warned that if you do this and it fails, it will do so silently. """ - return self._random.get_state() + def rng_dict(rng): + bg_state = rng.bit_generator.state + ss = rng.bit_generator.seed_seq + ss_dict = dict(entropy=ss.entropy, spawn_key=ss.spawn_key, pool_size=ss.pool_size, n_children_spawned=ss.n_children_spawned) + return dict(bg_state=bg_state, seed_seq=ss_dict) + return rng_dict(self._random) + # return self._random.bit_generator.state @random_state.setter # NOQA def random_state(self, state): @@ -232,8 +235,16 @@ def random_state(self, state): if it doesn't work. Don't say I didn't warn you... """ + def _rng_fromdict(d): + bg_state = d['bg_state'] + ss = np.random.SeedSequence(**d['seed_seq']) + bg = getattr(np.random, bg_state['bit_generator'])(ss) + bg.state = bg_state + rng = np.random.Generator(bg) + return rng try: - self._random.set_state(state) + self._random = _rng_fromdict(state) + # self._random.bit_generator = state except: pass @@ -325,7 +336,7 @@ def sample( # Try to set the initial value of the random number generator. This # fails silently if it doesn't work but that's what we want because # we'll just interpret any garbage as letting the generator stay in - # it's current state. + # its current state. if rstate0 is not None: deprecation_warning( "The 'rstate0' argument is deprecated, use a 'State' " diff --git a/src/emcee/moves/de.py b/src/emcee/moves/de.py index 27105e0c..9f3807dd 100644 --- a/src/emcee/moves/de.py +++ b/src/emcee/moves/de.py @@ -53,7 +53,7 @@ def get_proposal(self, s, c, random): diffs = np.diff(c[pairs], axis=1).squeeze(axis=1) # (ns, ndim) # Sample a gamma value for each walker following Nelson et al. (2013) - gamma = self.g0 * (1 + self.sigma * random.randn(ns, 1)) # (ns, 1) + gamma = self.g0 * (1 + self.sigma * random.standard_normal((ns, 1))) # (ns, 1) # In this way, sigma is the standard deviation of the distribution of gamma, # instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006). 
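The ensemble and backend hunks above replace the legacy `RandomState` state tuple with a JSON-serialisable dict built from `bit_generator.state` plus the `SeedSequence` fields. Below is a minimal, self-contained sketch of that round trip, assuming NumPy >= 1.19 (for the public `bit_generator.seed_seq` attribute); the `NumpyEncoder` here simply mirrors the one added to `hdf.py` so the snippet runs on its own:

```python
import json

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    # Mirrors the encoder added to hdf.py: ndarrays (e.g. the MT19937 key
    # array) are stored as plain lists so json.dumps does not choke on them.
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


rng = np.random.default_rng(42)
rng.standard_normal(5)  # advance the stream a little

# Serialise: the same fields that rng_dict() collects and save_step()
# writes to the "random_state" HDF5 attribute.
ss = rng.bit_generator.seed_seq
state = dict(
    bg_state=rng.bit_generator.state,
    seed_seq=dict(
        entropy=ss.entropy,
        spawn_key=ss.spawn_key,
        pool_size=ss.pool_size,
        n_children_spawned=ss.n_children_spawned,
    ),
)
payload = json.dumps(state, cls=NumpyEncoder)

# Deserialise: the same steps as _rng_fromdict() in the random_state setter.
d = json.loads(payload)
bg = getattr(np.random, d["bg_state"]["bit_generator"])(
    np.random.SeedSequence(**d["seed_seq"])
)
bg.state = d["bg_state"]
restored = np.random.Generator(bg)

# The restored generator continues the stream exactly where rng left off.
assert np.allclose(rng.standard_normal(3), restored.standard_normal(3))
```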
diff --git a/src/emcee/moves/de_snooker.py b/src/emcee/moves/de_snooker.py index 00d5c50a..3e60455b 100644 --- a/src/emcee/moves/de_snooker.py +++ b/src/emcee/moves/de_snooker.py @@ -35,7 +35,7 @@ def get_proposal(self, s, c, random): q = np.empty_like(s) metropolis = np.empty(Ns, dtype=np.float64) for i in range(Ns): - w = np.array([c[j][random.randint(Nc[j])] for j in range(3)]) + w = np.array([c[j][random.integers(Nc[j])] for j in range(3)]) random.shuffle(w) z, z1, z2 = w delta = s[i] - z diff --git a/src/emcee/moves/gaussian.py b/src/emcee/moves/gaussian.py index c255fd87..ba9aa9cf 100644 --- a/src/emcee/moves/gaussian.py +++ b/src/emcee/moves/gaussian.py @@ -87,13 +87,13 @@ def get_factor(self, rng): return np.exp(rng.uniform(-self._log_factor, self._log_factor)) def get_updated_vector(self, rng, x0): - return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape)) + return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape)) def __call__(self, x0, rng): nw, nd = x0.shape xnew = self.get_updated_vector(rng, x0) if self.mode == "random": - m = (range(nw), rng.randint(x0.shape[-1], size=nw)) + m = (range(nw), rng.integers(x0.shape[-1], size=nw)) elif self.mode == "sequential": m = (range(nw), self.index % nd + np.zeros(nw, dtype=int)) self.index = (self.index + 1) % nd @@ -106,7 +106,7 @@ def __call__(self, x0, rng): class _diagonal_proposal(_isotropic_proposal): def get_updated_vector(self, rng, x0): - return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape)) + return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape)) class _proposal(_isotropic_proposal): diff --git a/src/emcee/moves/mh.py b/src/emcee/moves/mh.py index 4b190498..ac0fe875 100644 --- a/src/emcee/moves/mh.py +++ b/src/emcee/moves/mh.py @@ -56,7 +56,7 @@ def propose(self, model, state): # Loop over the walkers and update them accordingly. 
lnpdiff = new_log_probs - state.log_prob + factors - accepted = np.log(model.random.rand(nwalkers)) < lnpdiff + accepted = np.log(model.random.random(nwalkers)) < lnpdiff # Update the parameters new_state = State(q, log_prob=new_log_probs, blobs=new_blobs) diff --git a/src/emcee/moves/red_blue.py b/src/emcee/moves/red_blue.py index e6ab59d5..1f7f86d5 100644 --- a/src/emcee/moves/red_blue.py +++ b/src/emcee/moves/red_blue.py @@ -97,7 +97,7 @@ def propose(self, model, state): zip(all_inds[S1], factors, new_log_probs) ): lnpdiff = f + nlp - state.log_prob[j] - if lnpdiff > np.log(model.random.rand()): + if lnpdiff > np.log(model.random.random()): accepted[j] = True new_state = State(q, log_prob=new_log_probs, blobs=new_blobs) diff --git a/src/emcee/moves/stretch.py b/src/emcee/moves/stretch.py index 40d8e537..cf3e7764 100644 --- a/src/emcee/moves/stretch.py +++ b/src/emcee/moves/stretch.py @@ -27,7 +27,7 @@ def get_proposal(self, s, c, random): c = np.concatenate(c, axis=0) Ns, Nc = len(s), len(c) ndim = s.shape[1] - zz = ((self.a - 1.0) * random.rand(Ns) + 1) ** 2.0 / self.a + zz = ((self.a - 1.0) * random.random(Ns) + 1) ** 2.0 / self.a factors = (ndim - 1.0) * np.log(zz) - rint = random.randint(Nc, size=(Ns,)) + rint = random.integers(Nc, size=(Ns,)) return c[rint] - (c[rint] - s) * zz[:, None], factors diff --git a/src/emcee/tests/unit/test_backends.py b/src/emcee/tests/unit/test_backends.py index a55f2c58..b01f3995 100644 --- a/src/emcee/tests/unit/test_backends.py +++ b/src/emcee/tests/unit/test_backends.py @@ -43,11 +43,10 @@ def run_sampler( ): if lp is None: lp = normal_log_prob_blobs if blobs else normal_log_prob - if seed is not None: - np.random.seed(seed) - coords = np.random.randn(nwalkers, ndim) + rng = np.random.default_rng(seed) + coords = rng.standard_normal((nwalkers, ndim)) sampler = EnsembleSampler( - nwalkers, ndim, lp, backend=backend, blobs_dtype=dtype + nwalkers, ndim, lp, rng=rng, backend=backend, blobs_dtype=dtype ) sampler.run_mcmc(coords, nsteps, thin_by=thin_by) return sampler @@ -125,10 +124,7 @@ def test_backend(backend, dtype, blobs): last2 = sampler2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert all( - np.allclose(l1, l2) - for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:]) - ) + assert last1.random_state == last2.random_state if blobs: _custom_allclose(last1.blobs, last2.blobs) else: @@ -141,12 +137,11 @@ def test_backend(backend, dtype, blobs): @pytest.mark.parametrize("backend,dtype", product(other_backends, dtypes)) def test_reload(backend, dtype): - with backend() as backend1: + with (backend() as backend1): run_sampler(backend1, dtype=dtype) # Test the state state = backend1.random_state - np.random.set_state(state) # Load the file using a new backend object. backend2 = backends.HDFBackend( @@ -156,11 +151,7 @@ def test_reload(backend, dtype): with pytest.raises(RuntimeError): backend2.reset(32, 3) - assert state[0] == backend2.random_state[0] - assert all( - np.allclose(a, b) - for a, b in zip(state[1:], backend2.random_state[1:]) - ) + assert state == backend2.random_state # Check all of the components. 
for k in ["chain", "log_prob", "blobs"]: @@ -172,10 +163,7 @@ def test_reload(backend, dtype): last2 = backend2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert all( - np.allclose(l1, l2) - for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:]) - ) + assert last1.random_state == last2.random_state _custom_allclose(last1.blobs, last2.blobs) a = backend1.accepted @@ -188,11 +176,11 @@ def test_restart(backend, dtype): # Run a sampler with the default backend. b = backends.Backend() run_sampler(b, dtype=dtype) - sampler1 = run_sampler(b, seed=None, dtype=dtype) + sampler1 = run_sampler(b, seed=2, dtype=dtype) with backend() as be: run_sampler(be, dtype=dtype) - sampler2 = run_sampler(be, seed=None, dtype=dtype) + sampler2 = run_sampler(be, seed=2, dtype=dtype) # Check all of the components. for k in ["chain", "log_prob", "blobs"]: @@ -204,10 +192,7 @@ def test_restart(backend, dtype): last2 = sampler2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert all( - np.allclose(l1, l2) - for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:]) - ) + assert last1.random_state == last2.random_state _custom_allclose(last1.blobs, last2.blobs) a = sampler1.acceptance_fraction diff --git a/src/emcee/tests/unit/test_sampler.py b/src/emcee/tests/unit/test_sampler.py index f272d4ac..0cac139a 100644 --- a/src/emcee/tests/unit/test_sampler.py +++ b/src/emcee/tests/unit/test_sampler.py @@ -135,9 +135,9 @@ def run_sampler( progress=False, store=True, ): - np.random.seed(seed) - coords = np.random.randn(nwalkers, ndim) - sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, backend=backend) + rng = np.random.default_rng(seed) + coords = rng.standard_normal((nwalkers, ndim)) + sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, rng=rng, backend=backend) sampler.run_mcmc( coords, nsteps, diff --git a/src/emcee/tests/unit/test_stretch.py b/src/emcee/tests/unit/test_stretch.py index a75e0521..f2a6b317 100644 --- a/src/emcee/tests/unit/test_stretch.py +++ b/src/emcee/tests/unit/test_stretch.py @@ -16,12 +16,12 @@ def test_live_dangerously(nwalkers=32, nsteps=3000, seed=1234): warnings.filterwarnings("error") # Set up the random number generator. - np.random.seed(seed) + rng = np.random.default_rng(seed) state = State( - np.random.randn(nwalkers, 2 * nwalkers), + rng.standard_normal((nwalkers, 2 * nwalkers)), log_prob=np.random.randn(nwalkers), ) - model = Model(None, lambda x: (np.zeros(len(x)), None), map, np.random) + model = Model(None, lambda x: (np.zeros(len(x)), None), map, rng) proposal = moves.StretchMove() # Test to make sure that the error is thrown if there aren't enough From 2a0b394ff861a5525ea1421c309c2872d13092ff Mon Sep 17 00:00:00 2001 From: Andrew Nelson Date: Mon, 28 Oct 2024 17:45:24 +1100 Subject: [PATCH 2/4] DOC --- src/emcee/ensemble.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 160cd783..9a9a4216 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -73,6 +73,8 @@ class EnsembleSampler(object): names of individual parameters or groups of parameters. If specified, the ``log_prob_fn`` will recieve a dictionary of parameters, rather than a ``np.ndarray``. + rng (Optional): + int, :class:`np.random.Generator`, used for reproducibility. 
""" @@ -89,7 +91,7 @@ def __init__( vectorize=False, blobs_dtype=None, parameter_names: Optional[Union[Dict[str, int], List[str]]] = None, - rng = None, + rng=None, # Deprecated... a=None, postargs=None, From c5c58423107bb3301e8aae63f9c9cb4b497823ed Mon Sep 17 00:00:00 2001 From: Andrew Nelson Date: Mon, 28 Oct 2024 17:59:29 +1100 Subject: [PATCH 3/4] initial state setting --- src/emcee/ensemble.py | 9 +++++++-- src/emcee/tests/unit/test_backends.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index 9a9a4216..af0ddc06 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -160,7 +160,7 @@ def __init__( # Get the last random state state = self.backend.random_state if state is not None: - self._random.bit_generator.state = state + self.random_state = state # Grab the last step so that we can restart it = self.backend.iteration @@ -225,7 +225,12 @@ def random_state(self): def rng_dict(rng): bg_state = rng.bit_generator.state ss = rng.bit_generator.seed_seq - ss_dict = dict(entropy=ss.entropy, spawn_key=ss.spawn_key, pool_size=ss.pool_size, n_children_spawned=ss.n_children_spawned) + ss_dict = dict( + entropy=ss.entropy, + spawn_key=ss.spawn_key, + pool_size=ss.pool_size, + n_children_spawned=ss.n_children_spawned + ) return dict(bg_state=bg_state, seed_seq=ss_dict) return rng_dict(self._random) # return self._random.bit_generator.state diff --git a/src/emcee/tests/unit/test_backends.py b/src/emcee/tests/unit/test_backends.py index b01f3995..e1ec5484 100644 --- a/src/emcee/tests/unit/test_backends.py +++ b/src/emcee/tests/unit/test_backends.py @@ -124,7 +124,7 @@ def test_backend(backend, dtype, blobs): last2 = sampler2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert last1.random_state == last2.random_state + assert last1.random_state['bg_state'] == last2.random_state['bg_state'] if blobs: _custom_allclose(last1.blobs, last2.blobs) else: @@ -192,7 +192,7 @@ def test_restart(backend, dtype): last2 = sampler2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert last1.random_state == last2.random_state + assert last1.random_state['bg_state'] == last2.random_state['bg_state'] _custom_allclose(last1.blobs, last2.blobs) a = sampler1.acceptance_fraction From 6cafae1b736d81f1386d52a4158c62dc8425da0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 07:01:32 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/emcee/backends/hdf.py | 7 +++---- src/emcee/ensemble.py | 12 ++++++++---- src/emcee/moves/de.py | 4 +++- src/emcee/moves/gaussian.py | 8 ++++++-- src/emcee/tests/unit/test_backends.py | 6 +++--- src/emcee/tests/unit/test_sampler.py | 4 +++- 6 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/emcee/backends/hdf.py b/src/emcee/backends/hdf.py index 1c575161..278c15f1 100644 --- a/src/emcee/backends/hdf.py +++ b/src/emcee/backends/hdf.py @@ -2,9 +2,9 @@ from __future__ import division, print_function +import json import os from tempfile import NamedTemporaryFile -import json import numpy as np @@ -202,7 +202,7 @@ def accepted(self): def random_state(self): with self.open() as f: try: - dct = json.loads(f[self.name].attrs['random_state']) + dct = json.loads(f[self.name].attrs["random_state"]) 
except KeyError: return None return dct @@ -269,8 +269,7 @@ def save_step(self, state, accepted): g["accepted"][:] += accepted g.attrs["random_state"] = json.dumps( - state.random_state, - cls=NumpyEncoder + state.random_state, cls=NumpyEncoder ) g.attrs["iteration"] = iteration + 1 diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index af0ddc06..b539c1f1 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -222,6 +222,7 @@ def random_state(self): so silently. """ + def rng_dict(rng): bg_state = rng.bit_generator.state ss = rng.bit_generator.seed_seq @@ -229,9 +230,10 @@ def rng_dict(rng): entropy=ss.entropy, spawn_key=ss.spawn_key, pool_size=ss.pool_size, - n_children_spawned=ss.n_children_spawned + n_children_spawned=ss.n_children_spawned, ) return dict(bg_state=bg_state, seed_seq=ss_dict) + return rng_dict(self._random) # return self._random.bit_generator.state @@ -242,13 +244,15 @@ def random_state(self, state): if it doesn't work. Don't say I didn't warn you... """ + def _rng_fromdict(d): - bg_state = d['bg_state'] - ss = np.random.SeedSequence(**d['seed_seq']) - bg = getattr(np.random, bg_state['bit_generator'])(ss) + bg_state = d["bg_state"] + ss = np.random.SeedSequence(**d["seed_seq"]) + bg = getattr(np.random, bg_state["bit_generator"])(ss) bg.state = bg_state rng = np.random.Generator(bg) return rng + try: self._random = _rng_fromdict(state) # self._random.bit_generator = state diff --git a/src/emcee/moves/de.py b/src/emcee/moves/de.py index 9f3807dd..c357fac1 100644 --- a/src/emcee/moves/de.py +++ b/src/emcee/moves/de.py @@ -53,7 +53,9 @@ def get_proposal(self, s, c, random): diffs = np.diff(c[pairs], axis=1).squeeze(axis=1) # (ns, ndim) # Sample a gamma value for each walker following Nelson et al. (2013) - gamma = self.g0 * (1 + self.sigma * random.standard_normal((ns, 1))) # (ns, 1) + gamma = self.g0 * ( + 1 + self.sigma * random.standard_normal((ns, 1)) + ) # (ns, 1) # In this way, sigma is the standard deviation of the distribution of gamma, # instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006). 
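The move hunks in this series are mechanical renames from the legacy `RandomState` methods to their `Generator` counterparts. A compact reference (the `rng` name is illustrative, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(1234)

u = rng.random(8)                # was random.rand(8): uniforms on [0, 1)
g = rng.standard_normal((8, 3))  # was random.randn(8, 3): note the shape tuple
i = rng.integers(8, size=(8,))   # was random.randint(8, size=(8,))
rng.shuffle(g)                   # shuffle keeps the same name on Generator
```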
diff --git a/src/emcee/moves/gaussian.py b/src/emcee/moves/gaussian.py index ba9aa9cf..d3f4fa26 100644 --- a/src/emcee/moves/gaussian.py +++ b/src/emcee/moves/gaussian.py @@ -87,7 +87,9 @@ def get_factor(self, rng): return np.exp(rng.uniform(-self._log_factor, self._log_factor)) def get_updated_vector(self, rng, x0): - return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape)) + return x0 + self.get_factor(rng) * self.scale * rng.standard_normal( + (x0.shape) + ) def __call__(self, x0, rng): nw, nd = x0.shape @@ -106,7 +108,9 @@ def __call__(self, x0, rng): class _diagonal_proposal(_isotropic_proposal): def get_updated_vector(self, rng, x0): - return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape)) + return x0 + self.get_factor(rng) * self.scale * rng.standard_normal( + (x0.shape) + ) class _proposal(_isotropic_proposal): diff --git a/src/emcee/tests/unit/test_backends.py b/src/emcee/tests/unit/test_backends.py index e1ec5484..7db16492 100644 --- a/src/emcee/tests/unit/test_backends.py +++ b/src/emcee/tests/unit/test_backends.py @@ -124,7 +124,7 @@ def test_backend(backend, dtype, blobs): last2 = sampler2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert last1.random_state['bg_state'] == last2.random_state['bg_state'] + assert last1.random_state["bg_state"] == last2.random_state["bg_state"] if blobs: _custom_allclose(last1.blobs, last2.blobs) else: @@ -137,7 +137,7 @@ def test_backend(backend, dtype, blobs): @pytest.mark.parametrize("backend,dtype", product(other_backends, dtypes)) def test_reload(backend, dtype): - with (backend() as backend1): + with backend() as backend1: run_sampler(backend1, dtype=dtype) # Test the state @@ -192,7 +192,7 @@ def test_restart(backend, dtype): last2 = sampler2.get_last_sample() assert np.allclose(last1.coords, last2.coords) assert np.allclose(last1.log_prob, last2.log_prob) - assert last1.random_state['bg_state'] == last2.random_state['bg_state'] + assert last1.random_state["bg_state"] == last2.random_state["bg_state"] _custom_allclose(last1.blobs, last2.blobs) a = sampler1.acceptance_fraction diff --git a/src/emcee/tests/unit/test_sampler.py b/src/emcee/tests/unit/test_sampler.py index 0cac139a..00cc3bd0 100644 --- a/src/emcee/tests/unit/test_sampler.py +++ b/src/emcee/tests/unit/test_sampler.py @@ -137,7 +137,9 @@ def run_sampler( ): rng = np.random.default_rng(seed) coords = rng.standard_normal((nwalkers, ndim)) - sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, rng=rng, backend=backend) + sampler = EnsembleSampler( + nwalkers, ndim, normal_log_prob, rng=rng, backend=backend + ) sampler.run_mcmc( coords, nsteps,
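Taken together, the updated test helpers show the usage pattern this series enables: seed a `Generator` once, draw the initial coordinates from it, and hand it to the sampler through the new `rng` keyword. A minimal sketch, assuming the series is applied (the log-probability and problem sizes are placeholders):

```python
import numpy as np

import emcee


def log_prob(x):
    return -0.5 * np.sum(x**2)


ndim, nwalkers, nsteps = 3, 32, 100
rng = np.random.default_rng(1234)
coords = rng.standard_normal((nwalkers, ndim))

# The sampler draws only from the Generator it is handed, so repeating this
# with the same seed should reproduce the chain exactly.
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, rng=rng)
sampler.run_mcmc(coords, nsteps)
```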