Skip to content

Commit

Permalink
Merge pull request #41 from kdolum/master
Browse files Browse the repository at this point in the history
Fix sample-counting and resuming bugs
  • Loading branch information
vhaasteren authored Nov 17, 2023
2 parents 992d66d + 92bbdf6 commit 5b4a764
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 52 deletions.
9 changes: 6 additions & 3 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ max-line-length = 120
max-complexity = 45
ignore =
E203
W503 # line break before binary operator; conflicts with black
E722 # bare except ok
E731 # lambda expressions ok
# line break before binary operator; conflicts with black
W503
# bare except ok
E722
# lambda expressions ok
E731
exclude =
.git
.tox
Expand Down
148 changes: 99 additions & 49 deletions PTMCMCSampler/PTMCMCSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __init__(
resume=False,
seed=None,
):

# MPI initialization
self.comm = comm
self.MPIrank = self.comm.Get_rank()
Expand Down Expand Up @@ -204,11 +203,12 @@ def initialize(
self.neff = neff
self.tstart = 0

N = int(maxIter / thin)
N = int(maxIter / thin) + 1 # first sample + those we generate

self._lnprob = np.zeros(N)
self._lnlike = np.zeros(N)
self._chain = np.zeros((N, self.ndim))
self.ind_next_write = 0 # Next index in these arrays to write out
self.naccepted = 0
self.swapProposed = 0
self.nswap_accepted = 0
Expand Down Expand Up @@ -291,13 +291,27 @@ def initialize(
print("Resuming run from chain file {0}".format(self.fname))
try:
self.resumechain = np.loadtxt(self.fname)
self.resumeLength = self.resumechain.shape[0]
except ValueError:
print("WARNING: Cant read in file. Removing last line.")
os.system("sed -ie '$d' {0}".format(self.fname))
self.resumechain = np.loadtxt(self.fname)
self.resumeLength = self.resumechain.shape[0]
self.resumeLength = self.resumechain.shape[0] # Number of samples read from old chain
except ValueError as error:
print("Reading old chain files failed with error", error)
raise Exception("Couldn't read old chain to resume")
self._chainfile = open(self.fname, "a")
if (
self.isave != self.thin
and self.resumeLength % (self.isave / self.thin) != 1 # This special case is always OK
): # Initial sample plus blocks of isave/thin
raise Exception(
(
"Old chain has {0} rows, which is not the initial sample plus a multiple of isave/thin = {1}"
).format(self.resumeLength, self.isave // self.thin)
)
print(
"Resuming with",
self.resumeLength,
"samples from file representing",
(self.resumeLength - 1) * self.thin + 1,
"original samples",
)
else:
self._chainfile = open(self.fname, "w")
self._chainfile.close()
Expand All @@ -319,18 +333,40 @@ def updateChains(self, p0, lnlike0, lnprob0, iter):
self._lnprob[ind] = lnprob0

# write to file
if iter % self.isave == 0 and iter > 1 and iter > self.resumeLength:
if iter % self.isave == 0:
self.writeOutput(iter)

def writeOutput(self, iter):
"""
Write chains and covariance matrix. Called every isave on samples or at end.
"""
if iter // self.thin >= self.ind_next_write:
if self.writeHotChains or self.MPIrank == 0:
self._writeToFile(iter)

# write output covariance matrix
np.save(self.outDir + "/cov.npy", self.cov)
if self.MPIrank == 0 and self.verbose and iter > 1:
sys.stdout.write("\r")
sys.stdout.write(
"Finished %2.2f percent in %f s Acceptance rate = %g"
% (iter / self.Niter * 100, time.time() - self.tstart, self.naccepted / iter)
)
if iter > 0:
np.save(self.outDir + "/cov.npy", self.cov)

if self.MPIrank == 0 and self.verbose:
if iter > 0:
sys.stdout.write("\r")
percent = iter / self.Niter * 100 # Percent of total work finished
acceptance = self.naccepted / iter if iter > 0 else 0
elapsed = time.time() - self.tstart
if self.resume:
# Percentage of new work done
percentnew = (
(iter - self.resumeLength * self.thin) / (self.Niter - self.resumeLength * self.thin) * 100
)
sys.stdout.write(
"Finished %2.2f percent (%2.2f percent of new work) in %f s Acceptance rate = %g"
% (percent, percentnew, elapsed, acceptance)
)
else:
sys.stdout.write(
"Finished %2.2f percent in %f s Acceptance rate = %g" % (percent, elapsed, acceptance)
)
sys.stdout.flush()

def sample(
Expand Down Expand Up @@ -368,7 +404,7 @@ def sample(
@param Tmin: Minimum temperature in ladder (default=1)
@param Tmax: Maximum temperature in ladder (default=None)
@param Tskip: Number of steps between proposed temperature swaps (default=100)
@param isave: Number of iterations before writing to file (default=1000)
@param isave: Write to file every isave samples (default=1000)
@param covUpdate: Number of iterations between AM covariance updates (default=1000)
@param SCAMweight: Weight of SCAM jumps in overall jump cycle (default=20)
@param AMweight: Weight of AM jumps in overall jump cycle (default=20)
Expand All @@ -381,7 +417,7 @@ def sample(
@param burn: Burn in time (DE jumps added after this iteration) (default=10000)
@param maxIter: Maximum number of iterations for high temperature chains
(default=2*self.Niter)
@param self.thin: Save every self.thin MCMC samples
@param self.thin: MCMC Samples are recorded every self.thin samples
@param i0: Iteration to start MCMC (if i0 !=0, do not re-initialize)
@param neff: Number of effective samples to collect before terminating
Expand All @@ -393,6 +429,15 @@ def sample(
elif maxIter is None and self.MPIrank == 0:
maxIter = Niter

if isave % thin != 0:
raise ValueError("isave = %d is not a multiple of thin = %d" % (isave, thin))

if Niter % thin != 0:
print(
"Niter = %d is not a multiple of thin = %d. The last %d samples will be lost"
% (Niter, thin, Niter % thin)
)

# set up arrays to store lnprob, lnlike and chain
# if picking up from previous run, don't re-initialize
if i0 == 0:
Expand Down Expand Up @@ -426,28 +471,28 @@ def sample(
# if resuming, just start with first point in chain
if self.resume and self.resumeLength > 0:
p0, lnlike0, lnprob0 = self.resumechain[0, :-4], self.resumechain[0, -3], self.resumechain[0, -4]
self.ind_next_write = self.resumeLength
else:
# compute prior
lp = self.logp(p0)

if lp == float(-np.inf):

lnprob0 = -np.inf
lnlike0 = -np.inf

else:

lnlike0 = self.logl(p0)
lnprob0 = 1 / self.temp * lnlike0 + lp

# record first values
self.tstart = time.time()
self.updateChains(p0, lnlike0, lnprob0, i0)

self.comm.barrier()

# start iterations
iter = i0
self.tstart = time.time()

runComplete = False
Neff = 0
while runComplete is False:
Expand All @@ -456,7 +501,7 @@ def sample(
# call PTMCMCOneStep
p0, lnlike0, lnprob0 = self.PTMCMCOneStep(p0, lnlike0, lnprob0, iter)

# compute effective number of samples
# compute effective number of samples in cold chain
if iter % 1000 == 0 and iter > 2 * self.burn and self.MPIrank == 0:
try:
Neff = iter / max(
Expand All @@ -468,19 +513,21 @@ def sample(
Neff = 0
pass

# stop if reached maximum number of iterations
if self.MPIrank == 0 and iter >= self.Niter - 1:
if self.verbose:
print("\nRun Complete")
runComplete = True
# rank 0 decides whether to stop
if self.MPIrank == 0:
if iter >= self.Niter: # stop if reached maximum number of iterations
message = "\nRun Complete"
runComplete = True
elif int(Neff) > self.neff: # stop if reached maximum number of iterations
message = "\nRun Complete with {0} effective samples".format(int(Neff))
runComplete = True

# stop if reached effective number of samples
if self.MPIrank == 0 and int(Neff) > self.neff:
if self.verbose:
print("\nRun Complete with {0} effective samples".format(int(Neff)))
runComplete = True
runComplete = self.comm.bcast(runComplete, root=0) # rank 0 tells others whether to stop

runComplete = self.comm.bcast(runComplete, root=0)
if runComplete:
self.writeOutput(iter) # Possibly write partial block
if self.MPIrank == 0 and self.verbose:
print(message)

def PTMCMCOneStep(self, p0, lnlike0, lnprob0, iter):
"""
Expand Down Expand Up @@ -541,12 +588,17 @@ def PTMCMCOneStep(self, p0, lnlike0, lnprob0, iter):

# jump proposal ###

# if resuming, just use previous chain points
if self.resume and self.resumeLength > 0 and iter < self.resumeLength:
p0, lnlike0, lnprob0 = self.resumechain[iter, :-4], self.resumechain[iter, -3], self.resumechain[iter, -4]
# if resuming, just use previous chain points. Use each one thin times to compensate for
# thinning when they were written out
if self.resume and self.resumeLength > 0 and iter < self.resumeLength * self.thin:
p0, lnlike0, lnprob0 = (
self.resumechain[iter // self.thin, :-4],
self.resumechain[iter // self.thin, -3],
self.resumechain[iter // self.thin, -4],
)

# update acceptance counter
self.naccepted = iter * self.resumechain[iter, -2]
self.naccepted = iter * self.resumechain[iter // self.thin, -2]
else:
y, qxy, jump_name = self._jump(p0, iter)
self.jumpDict[jump_name][0] += 1
Expand All @@ -555,18 +607,15 @@ def PTMCMCOneStep(self, p0, lnlike0, lnprob0, iter):
lp = self.logp(y)

if lp == -np.inf:

newlnprob = -np.inf

else:

newlnlike = self.logl(y)
newlnprob = 1 / self.temp * newlnlike + lp

# hastings step
diff = newlnprob - lnprob0 + qxy
if diff > np.log(self.stream.random()):

# accept jump
p0, lnlike0, lnprob0 = y, newlnlike, newlnprob

Expand Down Expand Up @@ -664,32 +713,35 @@ def temperatureLadder(self, Tmin, Tmax=None, tstep=None):

def _writeToFile(self, iter):
"""
Function to write chain file. File has 3+ndim columns,
the first is log-posterior (unweighted), log-likelihood,
and acceptance probability, followed by parameter values.
Function to write chain file. File has ndim+4 columns,
appended to the parameter values are log-posterior (unnormalized),
log-likelihood, acceptance rate, and PT acceptance rate.
Rates are as of time of writing.
@param iter: Iteration of sampler
"""

self._chainfile = open(self.fname, "a+")
for jj in range((iter - self.isave), iter, self.thin):
ind = int(jj / self.thin)
# index 0 is the initial element. So after 10*thin iterations we need to write elements 1..10
write_end = iter // self.thin + 1 # First element not to write.
for ind in range(self.ind_next_write, write_end):
pt_acc = 1
if self.MPIrank < self.nchain - 1 and self.swapProposed != 0:
pt_acc = self.nswap_accepted / self.swapProposed

self._chainfile.write("\t".join(["%22.22f" % (self._chain[ind, kk]) for kk in range(self.ndim)]))
self._chainfile.write(
"\t%f\t%f\t%f\t%f\n" % (self._lnprob[ind], self._lnlike[ind], self.naccepted / iter, pt_acc)
"\t%f\t%f\t%f\t%f\n"
% (self._lnprob[ind], self._lnlike[ind], self.naccepted / iter if iter > 0 else 0, pt_acc)
)
self._chainfile.close()
self.ind_next_write = write_end # Ready for next write

# write jump statistics files ####

# only for T=1 chain
if self.MPIrank == 0:

# first write file contaning jump names and jump rates
fout = open(self.outDir + "/jumps.txt", "w")
njumps = len(self.propCycle)
Expand Down Expand Up @@ -726,7 +778,6 @@ def _updateRecursive(self, iter, mem):
diff = np.zeros(ndim)
it += 1
for jj in range(ndim):

diff[jj] = self._AMbuffer[ii, jj] - self.mu[jj]
self.mu[jj] += diff[jj] / it

Expand Down Expand Up @@ -917,7 +968,6 @@ def DEJump(self, x, iter, beta):
scale = self.stream.random() * 2.4 / np.sqrt(2 * ndim) * np.sqrt(1 / beta)

for ii in range(ndim):

# jump size
sigma = self._DEbuffer[mm, self.groups[jumpind][ii]] - self._DEbuffer[nn, self.groups[jumpind][ii]]

Expand Down

0 comments on commit 5b4a764

Please sign in to comment.