Skip to content

Commit

Permalink
undoing a previous commit, allowing reverting stepsamplers again
Browse files Browse the repository at this point in the history
Without reverting, a stepsampler may return a result, but it cannot
be used anymore, even when resuming it would probably not be loaded.
With reverting, raising the Lmin should affect only the stepsamper
with probability 1/K, so the sampler has a good chance of not having
to revert back very often.

With the new queue, variation in executation duration should be less
of an issue.
  • Loading branch information
JohannesBuchner committed Jan 26, 2025
1 parent 14ceb6e commit 86441d2
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 74 deletions.
2 changes: 1 addition & 1 deletion tests/test_popstepsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_SimpleSliceSampler_SLOW(seed=4):
# resetting the seed to check the slice axes
np.random.seed(seed)
for i in range(popsize):
u[i],_,L[i],_,_= stepsampler.__next__(region, Lmin, us.copy(), Ls.copy(), transform, loglike_vectorized, test=True)
u[i],_,L[i],_= stepsampler.__next__(region, Lmin, us.copy(), Ls.copy(), transform, loglike_vectorized, test=True)

# Basic check
assert (L>Lmin).all(), (L,Lmin) # Lmin check
Expand Down
12 changes: 5 additions & 7 deletions tests/test_stepsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,16 @@ def test_stepsampler(plot=False):

stepsampler = CubeMHSampler(nsteps=len(paramnames))
while True:
u1, p1, L1, Lmin, nc = stepsampler.__next__(region, -1e100, region.u, Ls, transform, loglike)
u1, p1, L1, nc = stepsampler.__next__(region, -1e100, region.u, Ls, transform, loglike)
if u1 is not None:
break
assert L1 > -1e100
assert Lmin == -1e100
print(u1, L1)
while True:
u2, p2, L2, Lmin, nc = stepsampler.__next__(region, -1e100, region.u, Ls, transform, loglike)
u2, p2, L2, nc = stepsampler.__next__(region, -1e100, region.u, Ls, transform, loglike)
if u2 is not None:
break
assert L2 > -1e100
assert Lmin == -1e100
print(u2, L2)
assert np.all(u1 != u2)
assert np.all(L1 != L2)
Expand All @@ -255,7 +253,7 @@ def test_stepsampler_adapt_when_stuck(plot=False):
for i in range(1000):
if i > 100:
assert False, i
unew, pnew, Lnew, Lnew, nc = stepsampler.__next__(region, Lmin, us, Ls, transform, loglike, ndraw=10)
unew, pnew, Lnew, nc = stepsampler.__next__(region, Lmin, us, Ls, transform, loglike, ndraw=10)
if unew is not None:
break

Expand All @@ -271,7 +269,7 @@ def test_stepsampler_adapt_when_stuck(plot=False):
for i in range(1000):
if i > 100:
assert False, i
unew, pnew, Lnew, Lnew, nc = stepsampler.__next__(region, Lmin, us, Ls, transform, loglike, ndraw=10)
unew, pnew, Lnew, nc = stepsampler.__next__(region, Lmin, us, Ls, transform, loglike, ndraw=10)
if unew is not None:
break

Expand Down Expand Up @@ -303,7 +301,7 @@ def test_stepsampler_adapt(plot=True):
old_scale = stepsampler.scale
for i in range(5):
while True:
unew, pnew, Lnew, Lnew, nc = stepsampler.__next__(region, -1e100, region.u, Ls, transform, loglike)
unew, pnew, Lnew, nc = stepsampler.__next__(region, -1e100, region.u, Ls, transform, loglike)
if unew is not None:
break
new_scale = stepsampler.scale
Expand Down
2 changes: 1 addition & 1 deletion ultranest/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1899,7 +1899,7 @@ def _create_point(self, iteration, Lmin, ndraw, active_u, active_values):
while not self.pointqueue.has(rank_to_fetch):
# clear and reset cache, then refill by sampling
if use_stepsampler:
u, v, logl, Lmin_sampled, nc = self.stepsampler.__next__(
u, v, logl, nc = self.stepsampler.__next__(
self.region,
transform=self.transform, loglike=self.loglike,
Lmin=Lmin, us=active_u, Ls=active_values,
Expand Down
39 changes: 13 additions & 26 deletions ultranest/popstepsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def __next__(
nc = 0

u, p, L = self.prepared_samples.pop(0)
return u, p, L, Lmin, nc
return u, p, L, nc


class PopulationSliceSampler(GenericPopulationSampler):
Expand Down Expand Up @@ -395,7 +395,6 @@ def __init__(
self.scale_adapt_factor = scale_adapt_factor
self.allu = []
self.allL = []
self.allLmin = []
self.currentt = []
self.currentv = []
self.currentp = []
Expand Down Expand Up @@ -429,7 +428,6 @@ def _setup(self, ndim):
"""Allocate arrays."""
self.allu = np.zeros((self.popsize, self.nsteps + 1, ndim)) + np.nan
self.allL = np.zeros((self.popsize, self.nsteps + 1)) + np.nan
self.allLmin = np.zeros((self.popsize, self.nsteps + 1)) + np.nan
self.currentt = np.zeros(self.popsize) + np.nan
self.currentv = np.zeros((self.popsize, ndim)) + np.nan
self.generation = np.zeros(self.popsize, dtype=int_dtype) - 1
Expand All @@ -438,7 +436,7 @@ def _setup(self, ndim):
self.searching_left = np.zeros(self.popsize, dtype=bool)
self.searching_right = np.zeros(self.popsize, dtype=bool)

def setup_start(self, us, Ls, Lmin, starting):
def setup_start(self, us, Ls, starting):
"""Initialize walker starting points.
For iteration zero, randomly selects a live point as starting point.
Expand All @@ -449,8 +447,6 @@ def setup_start(self, us, Ls, Lmin, starting):
live points
Ls: np.array(nlive)
loglikelihoods live points
Lmin: float
loglikelihood threshold
starting: np.array(nwalkers, dtype=bool)
which walkers to initialize.
Expand All @@ -470,7 +466,6 @@ def setup_start(self, us, Ls, Lmin, starting):

self.allu[starting,0] = us[i]
self.allL[starting,0] = Ls[i]
self.allLmin[starting,:] = Lmin
self.generation[starting] = 0

@property
Expand Down Expand Up @@ -511,7 +506,7 @@ def _setup_currentp(self, nparams):
print("setting currentp")
self.currentp = np.zeros((self.popsize, nparams)) + np.nan

def advance(self, transform, loglike, region):
def advance(self, transform, loglike, Lmin, region):
"""Advance the walker population.
Parameters
Expand All @@ -520,6 +515,8 @@ def advance(self, transform, loglike, region):
prior transform function
loglike: function
loglikelihood function
Lmin: float
current log-likelihood threshold
region: MLFriends object
Region
Expand All @@ -536,7 +533,6 @@ def advance(self, transform, loglike, region):
args = [
self.allu[i, self.generation],
self.allL[i, self.generation],
self.allLmin[i, self.generation],
# pass values directly
self.currentt,
self.currentv,
Expand All @@ -550,7 +546,6 @@ def advance(self, transform, loglike, region):
args = [
self.allu[movable, self.generation[movable]],
self.allL[movable, self.generation[movable]],
self.allLmin[movable, self.generation[movable]],
# this makes copies
self.currentt[movable],
self.currentv[movable],
Expand All @@ -567,12 +562,9 @@ def advance(self, transform, loglike, region):
currentt, currentv,
current_left, current_right, searching_left, searching_right
),
(success, unew, pnew, Lnew, Lminnew),
(success, unew, pnew, Lnew),
nc
) = evolve(transform, loglike, *args)

if self.log:
print("evolve moved:", args[1], args[2], Lnew, Lminnew)
) = evolve(transform, loglike, Lmin, *args)

if success.any():
far_enough, (move_distance, reference_distance) = diagnose_move_distances(region, uorig[success,:], unew)
Expand Down Expand Up @@ -668,14 +660,12 @@ def __next__(
if len(self.allu) == 0:
self._setup(ndim)

# step_back(Lmin, self.allL, self.generation, self.currentt)
step_back(Lmin, self.allL, self.generation, self.currentt)

starting = self.generation < 0
assert np.isfinite(Lmin), Lmin
if starting.any():
self.setup_start(us[Ls > Lmin], Ls[Ls > Lmin], Lmin, starting)
self.setup_start(us[Ls > Lmin], Ls[Ls > Lmin], starting)
assert (self.generation >= 0).all(), self.generation
assert np.isfinite(self.allLmin).all(), self.allLmin

# find those where bracket is undefined:
mask_starting = ~np.isfinite(self.currentt)
Expand All @@ -684,10 +674,9 @@ def __next__(

if self.log:
print(str(self), "(before)")
nc = self.advance(transform, loglike, region)
nc = self.advance(transform, loglike, Lmin, region)
if self.log:
print(str(self), "(after)")
assert np.isfinite(self.allLmin).all(), self.allLmin

# harvest top individual if possible
if self.generation[self.ringindex] == self.nsteps:
Expand All @@ -696,23 +685,21 @@ def __next__(
u = self.allu[self.ringindex, self.nsteps, :].copy()
p = self.currentp[self.ringindex, :].copy()
L = self.allL[self.ringindex, self.nsteps].copy()
Lmin_start = self.allLmin[self.ringindex, self.nsteps].copy()
assert np.isfinite(u).all(), u
assert np.isfinite(p).all(), p
self.generation[self.ringindex] = -1
self.currentt[self.ringindex] = np.nan
self.allu[self.ringindex,:,:] = np.nan
self.allL[self.ringindex,:] = np.nan
self.allLmin[self.ringindex,:] = np.nan

# adjust guess length
newscale = (self.current_right[self.ringindex] - self.current_left[self.ringindex]) / 2
self.scale = self.scale * 0.9 + 0.1 * newscale

self.shift()
return u, p, L, Lmin_start, nc
return u, p, L, nc
else:
return None, None, None, None, nc
return None, None, None, nc


def slice_limit_to_unitcube(tleft, tright):
Expand Down Expand Up @@ -1011,7 +998,7 @@ def __next__(
nc = 0

u, p, L = self.prepared_samples.pop(0)
return u, p, L, Lmin, nc
return u, p, L, nc


__all__ = [
Expand Down
22 changes: 9 additions & 13 deletions ultranest/stepfuncs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def evolve_prepare(searching_left, searching_right):
cpdef evolve_update(
np.ndarray[np.uint8_t, ndim=1] acceptable,
np.ndarray[np.float_t, ndim=1] Lnew,
np.ndarray[np.float_t, ndim=1] Lmin,
np.float_t Lmin,
np.ndarray[np.uint8_t, ndim=1] search_right,
np.ndarray[np.uint8_t, ndim=1] bisecting,
np.float_t[:] currentt,
Expand All @@ -120,7 +120,7 @@ cpdef evolve_update(
whether a likelihood evaluation was made. If false, rejected because out of contour.
Lnew: np.array(acceptable.sum(), dtype=bool)
likelihood value of proposed point
Lmin: np.array(acceptable.sum(), dtype=bool)
Lmin: float
current log-likelihood threshold
search_right: np.array(nwalkers, dtype=bool)
whether stepping out in the positive direction
Expand Down Expand Up @@ -151,7 +151,7 @@ cpdef evolve_update(

for k in range(popsize):
if acceptable[k]:
if Lnew[j] > Lmin[j]:
if Lnew[j] > Lmin:
success[k] = 1
j += 1

Expand Down Expand Up @@ -187,8 +187,8 @@ pnew_empty = np.empty((0,1))
Lnew_empty = np.empty(0)

def evolve(
transform, loglike,
currentu, currentL, Lmin, currentt, currentv,
transform, loglike, Lmin,
currentu, currentL, currentt, currentv,
current_left, current_right, searching_left, searching_right
):
"""Evolve each slice sampling walker.
Expand All @@ -199,12 +199,12 @@ def evolve(
prior transform function
loglike: function
loglikelihood function
Lmin: float
current log-likelihood threshold
currentu: np.array((nwalkers, ndim))
slice starting point (where currentt=0)
currentL: np.array(nwalkers)
current loglikelihood
Lmin: np.array(nwalkers)
current log-likelihood threshold
currentt: np.array(nwalkers)
proposed coordinate on the slice
currentv: np.array((nwalkers, ndim))
Expand Down Expand Up @@ -261,12 +261,10 @@ def evolve(
if acceptable.any():
pnew = transform(unew[acceptable,:])
Lnew = loglike(pnew)
Lminnew = Lmin[acceptable]
nc += len(pnew)
else:
pnew = pnew_empty
Lnew = Lnew_empty
Lminnew = Lnew_empty

success = np.zeros_like(searching_left)
evolve_update(
Expand All @@ -279,10 +277,8 @@ def evolve(
(
currentt, currentv,
current_left, current_right, searching_left, searching_right),
(
success, unew[success,:], pnew[success[acceptable],:],
Lnew[success[acceptable]], Lminnew[success[acceptable]]
), nc
(success, unew[success,:], pnew[success[acceptable],:], Lnew[success[acceptable]]),
nc
)


Expand Down
35 changes: 9 additions & 26 deletions ultranest/stepsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ class StepSampler:
def __init__(
self, nsteps, generate_direction,
scale=1.0, check_nsteps='move-distance', adaptive_nsteps=False, max_nsteps=1000,
region_filter=False, log=False, revert=True,
region_filter=False, log=False,
starting_point_selector=select_random_livepoint,
):
"""Initialise sampler.
Expand Down Expand Up @@ -662,15 +662,8 @@ def __init__(
always been the default behaviour,
or an instance of :py:class:`IslandPopulationRandomLivepointSelector`.
revert: bool
if true, and Lmin increases, the step sampler reverts to the last valid state.
if false, the step sampler resumes, but may return a result that is not
usable.
"""
self.revert = revert
self.history = []
self.Lmin = None
self.nsteps = nsteps
self.nrejects = 0
self.scale = scale
Expand Down Expand Up @@ -707,11 +700,7 @@ def __init__(
)
self.starting_point_selector = starting_point_selector
self.mean_pair_distance = np.nan
if self.revert:
self.region_filter = region_filter
else:
self.region_filter = False
raise ValueError("stepsampler cannot use region_filter=True with revert=False")
self.region_filter = region_filter
if log:
assert hasattr(log, 'write'), 'log argument should be a file, use log=open(filename, "w") or similar'
self.log = log
Expand Down Expand Up @@ -975,7 +964,6 @@ def finalize_chain(self, region=None, Lmin=None, Ls=None):
self.scale = self.next_scale
self.history = []
self.nrejects = 0
self.Lmin = None

def new_chain(self, region=None):
"""Start a new path, reset statistics."""
Expand Down Expand Up @@ -1032,17 +1020,13 @@ def __next__(self, region, Lmin, us, Ls, transform, loglike, ndraw=10, plot=Fals
number of likelihood function calls
"""
# find most recent point in history conforming to current Lmin
if self.revert:
for j, (_uj, Lj) in enumerate(self.history):
if not Lj > Lmin:
self.history = self.history[:j]
# print("wandered out of L constraint; reverting", ui[0])
break
for j, (_uj, Lj) in enumerate(self.history):
if not Lj > Lmin:
self.history = self.history[:j]
# print("wandered out of L constraint; reverting", ui[0])
break
if len(self.history) > 0:
ui, Li = self.history[-1]
if not self.revert:
# stick with original Lmin
Lmin = self.Lmin
else:
# select starting point
self.new_chain(region)
Expand All @@ -1057,7 +1041,6 @@ def __next__(self, region, Lmin, us, Ls, transform, loglike, ndraw=10, plot=Fals
# assert np.logical_and(ui > 0, ui < 1).all(), ui
Li = Ls[i]
self.history.append((ui.copy(), Li.copy()))
self.Lmin = Lmin
del i

while True:
Expand Down Expand Up @@ -1108,10 +1091,10 @@ def __next__(self, region, Lmin, us, Ls, transform, loglike, ndraw=10, plot=Fals
u, L = self.history[-1]
p = transform(u.reshape((1, -1)))[0]
self.finalize_chain(region=region, Lmin=Lmin, Ls=Ls)
return u, p, L, Lmin, nc
return u, p, L, nc

# do not have a independent sample yet
return None, None, None, Lmin, nc
return None, None, None, nc


class MHSampler(StepSampler):
Expand Down

0 comments on commit 86441d2

Please sign in to comment.