Skip to content

Commit

Permalink
Merge pull request #1596 from devbhakt/sampler_chains
Browse files Browse the repository at this point in the history
Sampler.chain deprecated fix
  • Loading branch information
paulray authored Jun 22, 2023
2 parents 93f7d9c + d5ea5e8 commit e261b35
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 12 deletions.
5 changes: 3 additions & 2 deletions docs/examples/MCMC_walkthrough.broken
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ To make this run relatively fast for demonstration purposes, nsteps was purposef

```python
fitter.phaseogram()
samples = sampler.sampler.chain[:, 10:, :].reshape((-1, fitter.n_fit_params))
samples = np.transpose(sampler.sampler.get_chain(discard=10), (1, 0, 2)).reshape(
(-1, fitter.n_fit_params))
ranges = map(
lambda v: (v[1], v[2] - v[1], v[1] - v[0]),
zip(*np.percentile(samples, [16, 50, 84], axis=0)),
Expand Down Expand Up @@ -192,7 +193,7 @@ fitter2.fit_toas(maxiter=nsteps2, pos=None)
```

```python
samples2 = sampler2.sampler.chain[:, :, :].reshape((-1, fitter2.n_fit_params))
samples2 = np.transpose(sampler2.sampler.get_chain(), (1, 0, 2)).reshape((-1, fitter2.n_fit_params))
ranges2 = map(
lambda v: (v[1], v[2] - v[1], v[1] - v[0]),
zip(*np.percentile(samples2, [16, 50, 84], axis=0)),
Expand Down
5 changes: 3 additions & 2 deletions docs/examples/fit_NGC6440E_MCMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def plot_chains(chain_dict, file=False):
plot_chains(chains, file=f"{f.model.PSR.value}_chains.png")

# triangle plot
# this doesn't include burn-in because we're not using it here, otherwise would have middle ':' --> 'burnin:'
samples = sampler.sampler.chain[:, :, :].reshape((-1, f.n_fit_params))
# this doesn't include burn-in because we're not using it here, otherwise set get_chain(discard=burnin)
# samples = sampler.sampler.chain[:, :, :].reshape((-1, f.n_fit_params))
samples = np.transpose(sampler.get_chain(), (1, 0, 2)).reshape((-1, ndim))
with contextlib.suppress(ImportError):
import corner

Expand Down
5 changes: 3 additions & 2 deletions src/pint/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,16 @@ def get_chain(self):
"""
if self.sampler is None:
raise ValueError("MCMCSampler object has not called initialize_sampler()")
return self.sampler.chain
return self.sampler.get_chain()

def chains_to_dict(self, names):
"""
Convert the sampler chains to a dictionary
"""
if self.sampler is None:
raise ValueError("MCMCSampler object has not called initialize_sampler()")
chains = [self.sampler.chain[:, :, ii].T for ii in range(len(names))]
samples = np.transpose(self.sampler.get_chain(), (1, 0, 2))
chains = [samples[:, :, ii].T for ii in range(len(names))]
return dict(zip(names, chains))

def run_mcmc(self, pos, nsteps):
Expand Down
7 changes: 5 additions & 2 deletions src/pint/scripts/event_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,8 @@ def unwrapped_lnpost(theta):
sampler.run_mcmc(pos, nsteps)

def chains_to_dict(names, sampler):
chains = [sampler.chain[:, :, ii].T for ii in range(len(names))]
samples = np.transpose(sampler.get_chain(), (1, 0, 2))
chains = [samples[:, :, ii].T for ii in range(len(names))]
return dict(zip(names, chains))

def plot_chains(chain_dict, file=False):
Expand All @@ -859,7 +860,9 @@ def plot_chains(chain_dict, file=False):
plot_chains(chains, file=filename + "_chains.png")

# Make the triangle plot.
samples = sampler.chain[:, burnin:, :].reshape((-1, ndim))
samples = np.transpose(sampler.get_chain(discard=burnin), (1, 0, 2)).reshape(
(-1, ndim)
)

blobs = sampler.get_blobs()
lnprior_samps = blobs["lnprior"]
Expand Down
5 changes: 4 additions & 1 deletion src/pint/scripts/event_optimize_MCMCFitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,10 @@ def plot_chains(chain_dict, file=False):
plot_chains(chains, file=ftr.model.PSR.value + "_chains.png")

# Make the triangle plot.
samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
# samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
samples = np.transpose(
sampler.sampler.get_chain(discard=burnin), (1, 0, 2)
).reshape((-1, ftr.n_fit_params))
try:
import corner

Expand Down
5 changes: 4 additions & 1 deletion src/pint/scripts/event_optimize_multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,10 @@ def plot_chains(chain_dict, file=False):
plot_chains(chains, file=f"{ftr.model.PSR.value}_chains.png")

# Make the triangle plot.
samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
# samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
samples = np.transpose(
sampler.sampler.get_chain(discard=burnin), (1, 0, 2)
).reshape((-1, ftr.n_fit_params))
with contextlib.suppress(ImportError):
import corner

Expand Down
7 changes: 5 additions & 2 deletions tests/test_determinism.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ def test_sampler():

# fitter.phaseogram()
# samples = sampler.sampler.chain[:, 10:, :].reshape((-1, fitter.n_fit_params))
samples = np.transpose(sampler.get_chain(), (1, 0, 2))

# r.append(np.random.randn())
r.append(sampler.sampler.chain[0])
# r.append(sampler.sampler.chain[0])
r.append(samples[0])
assert_array_equal(r[0], r[1])


Expand All @@ -109,6 +111,7 @@ def log_prob(x, ivar):
sampler.random_state = s
sampler.run_mcmc(p0, 100)

samples = sampler.chain.reshape((-1, ndim))
# samples = sampler.chain.reshape((-1, ndim))
samples = np.transpose(sampler.get_chain(), (1, 0, 2)).reshape((-1, ndim))
r.append(samples[0, 0])
assert r[0] == r[1]

0 comments on commit e261b35

Please sign in to comment.