diff --git a/docs/examples/MCMC_walkthrough.broken b/docs/examples/MCMC_walkthrough.broken index 4bc55a9ec..bb910e1b6 100644 --- a/docs/examples/MCMC_walkthrough.broken +++ b/docs/examples/MCMC_walkthrough.broken @@ -127,7 +127,8 @@ To make this run relatively fast for demonstration purposes, nsteps was purposef ```python fitter.phaseogram() -samples = sampler.sampler.chain[:, 10:, :].reshape((-1, fitter.n_fit_params)) +samples = np.transpose(sampler.sampler.get_chain(discard=10), (1, 0, 2)).reshape( + (-1, fitter.n_fit_params)) ranges = map( lambda v: (v[1], v[2] - v[1], v[1] - v[0]), zip(*np.percentile(samples, [16, 50, 84], axis=0)), @@ -192,7 +193,7 @@ fitter2.fit_toas(maxiter=nsteps2, pos=None) ``` ```python -samples2 = sampler2.sampler.chain[:, :, :].reshape((-1, fitter2.n_fit_params)) +samples2 = np.transpose(sampler2.sampler.get_chain(), (1, 0, 2)).reshape((-1, fitter2.n_fit_params)) ranges2 = map( lambda v: (v[1], v[2] - v[1], v[1] - v[0]), zip(*np.percentile(samples2, [16, 50, 84], axis=0)), diff --git a/docs/examples/fit_NGC6440E_MCMC.py b/docs/examples/fit_NGC6440E_MCMC.py index 9615c2a14..ea4fa718c 100644 --- a/docs/examples/fit_NGC6440E_MCMC.py +++ b/docs/examples/fit_NGC6440E_MCMC.py @@ -87,8 +87,9 @@ def plot_chains(chain_dict, file=False): plot_chains(chains, file=f"{f.model.PSR.value}_chains.png") # triangle plot -# this doesn't include burn-in because we're not using it here, otherwise would have middle ':' --> 'burnin:' -samples = sampler.sampler.chain[:, :, :].reshape((-1, f.n_fit_params)) +# this doesn't include burn-in because we're not using it here, otherwise set get_chain(discard=burnin) +# samples = sampler.sampler.chain[:, :, :].reshape((-1, f.n_fit_params)) +samples = np.transpose(sampler.get_chain(), (1, 0, 2)).reshape((-1, ndim)) with contextlib.suppress(ImportError): import corner diff --git a/src/pint/sampler.py b/src/pint/sampler.py index e4f49c6a1..53d09bddf 100644 --- a/src/pint/sampler.py +++ b/src/pint/sampler.py @@ -150,7 +150,7 @@ def get_chain(self): """ if self.sampler is None: raise ValueError("MCMCSampler object has not called initialize_sampler()") - return self.sampler.chain + return self.sampler.get_chain() def chains_to_dict(self, names): """ @@ -158,7 +158,8 @@ def chains_to_dict(self, names): """ if self.sampler is None: raise ValueError("MCMCSampler object has not called initialize_sampler()") - chains = [self.sampler.chain[:, :, ii].T for ii in range(len(names))] + samples = np.transpose(self.sampler.get_chain(), (1, 0, 2)) + chains = [samples[:, :, ii].T for ii in range(len(names))] return dict(zip(names, chains)) def run_mcmc(self, pos, nsteps): diff --git a/src/pint/scripts/event_optimize.py b/src/pint/scripts/event_optimize.py index 23b05fc06..b7560d320 100755 --- a/src/pint/scripts/event_optimize.py +++ b/src/pint/scripts/event_optimize.py @@ -837,7 +837,8 @@ def unwrapped_lnpost(theta): sampler.run_mcmc(pos, nsteps) def chains_to_dict(names, sampler): - chains = [sampler.chain[:, :, ii].T for ii in range(len(names))] + samples = np.transpose(sampler.get_chain(), (1, 0, 2)) + chains = [samples[:, :, ii].T for ii in range(len(names))] return dict(zip(names, chains)) def plot_chains(chain_dict, file=False): @@ -859,7 +860,9 @@ def plot_chains(chain_dict, file=False): plot_chains(chains, file=filename + "_chains.png") # Make the triangle plot. - samples = sampler.chain[:, burnin:, :].reshape((-1, ndim)) + samples = np.transpose(sampler.get_chain(discard=burnin), (1, 0, 2)).reshape( + (-1, ndim) + ) blobs = sampler.get_blobs() lnprior_samps = blobs["lnprior"] diff --git a/src/pint/scripts/event_optimize_MCMCFitter.py b/src/pint/scripts/event_optimize_MCMCFitter.py index bde72ed93..bbab8ccee 100755 --- a/src/pint/scripts/event_optimize_MCMCFitter.py +++ b/src/pint/scripts/event_optimize_MCMCFitter.py @@ -358,7 +358,10 @@ def plot_chains(chain_dict, file=False): plot_chains(chains, file=ftr.model.PSR.value + "_chains.png") # Make the triangle plot. - samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params)) + # samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params)) + samples = np.transpose( + sampler.sampler.get_chain(discard=burnin), (1, 0, 2) + ).reshape((-1, ftr.n_fit_params)) try: import corner diff --git a/src/pint/scripts/event_optimize_multiple.py b/src/pint/scripts/event_optimize_multiple.py index eda745c4e..41316a000 100755 --- a/src/pint/scripts/event_optimize_multiple.py +++ b/src/pint/scripts/event_optimize_multiple.py @@ -426,7 +426,10 @@ def plot_chains(chain_dict, file=False): plot_chains(chains, file=f"{ftr.model.PSR.value}_chains.png") # Make the triangle plot. - samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params)) + # samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params)) + samples = np.transpose( + sampler.sampler.get_chain(discard=burnin), (1, 0, 2) + ).reshape((-1, ftr.n_fit_params)) with contextlib.suppress(ImportError): import corner diff --git a/tests/test_determinism.py b/tests/test_determinism.py index 733c09c4b..a7ea7820b 100644 --- a/tests/test_determinism.py +++ b/tests/test_determinism.py @@ -83,9 +83,11 @@ def test_sampler(): # fitter.phaseogram() # samples = sampler.sampler.chain[:, 10:, :].reshape((-1, fitter.n_fit_params)) + samples = np.transpose(sampler.get_chain(), (1, 0, 2)) # r.append(np.random.randn()) - r.append(sampler.sampler.chain[0]) + # r.append(sampler.sampler.chain[0]) + r.append(samples[0]) assert_array_equal(r[0], r[1]) @@ -109,6 +111,7 @@ def log_prob(x, ivar): sampler.random_state = s sampler.run_mcmc(p0, 100) - samples = sampler.chain.reshape((-1, ndim)) + # samples = sampler.chain.reshape((-1, ndim)) + samples = np.transpose(sampler.get_chain(), (1, 0, 2)).reshape((-1, ndim)) r.append(samples[0, 0]) assert r[0] == r[1]