diff options
Diffstat (limited to 'utils/mcmc.py')
| -rw-r--r-- | utils/mcmc.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/utils/mcmc.py b/utils/mcmc.py index f898b83..c712c3a 100644 --- a/utils/mcmc.py +++ b/utils/mcmc.py @@ -21,10 +21,10 @@ from utils.enums import MCMCSeedType from utils.misc import enum_parse, make_dir, parse_bool -def mcmc(p0, triangle_llh, lnprior, ndim, nwalkers, burnin, nsteps, ntemps=1, threads=1): +def mcmc(p0, ln_prob, ndim, nwalkers, burnin, nsteps, threads=1): """Run the MCMC.""" - sampler = emcee.PTSampler( - ntemps, nwalkers, ndim, triangle_llh, lnprior, threads=threads + sampler = emcee.EnsembleSampler( + nwalkers, ndim, ln_prob, threads=threads ) print "Running burn-in" @@ -38,7 +38,7 @@ def mcmc(p0, triangle_llh, lnprior, ndim, nwalkers, burnin, nsteps, ntemps=1, th pass print "Finished" - samples = sampler.chain[0, :, :, :].reshape((-1, ndim)) + samples = sampler.chain.reshape((-1, ndim)) print 'acceptance fraction', sampler.acceptance_fraction print 'sum of acceptance fraction', np.sum(sampler.acceptance_fraction) print 'np.unique(samples[:,0]).shape', np.unique(samples[:,0]).shape @@ -74,22 +74,22 @@ def mcmc_argparse(parser): ) -def flat_seed(paramset, ntemps, nwalkers): +def flat_seed(paramset, nwalkers): """Get gaussian seed values for the MCMC.""" ndim = len(paramset) low = np.array(paramset.seeds).T[0] high = np.array(paramset.seeds).T[1] p0 = np.random.uniform( - low=low, high=high, size=[ntemps, nwalkers, ndim] + low=low, high=high, size=[nwalkers, ndim] ) return p0 -def gaussian_seed(paramset, ntemps, nwalkers): +def gaussian_seed(paramset, nwalkers): """Get gaussian seed values for the MCMC.""" ndim = len(paramset) p0 = np.random.normal( - paramset.values, paramset.stds, size=[ntemps, nwalkers, ndim] + paramset.values, paramset.stds, size=[nwalkers, ndim] ) return p0 |
