aboutsummaryrefslogtreecommitdiffstats
path: root/utils/mcmc.py
diff options
context:
space:
mode:
authorshivesh <s.p.mandalia@qmul.ac.uk>2018-04-09 17:15:52 -0500
committershivesh <s.p.mandalia@qmul.ac.uk>2018-04-09 17:15:52 -0500
commitb2a022cd77c2f068d5530d3c04407f716094da66 (patch)
tree323d2f6640974d54bad24f45eb3ced97a596f37b /utils/mcmc.py
parent5e4ed5a6f8935d71049a521d5efcc2c09a633e3e (diff)
downloadGolemFlavor-b2a022cd77c2f068d5530d3c04407f716094da66.tar.gz
GolemFlavor-b2a022cd77c2f068d5530d3c04407f716094da66.zip
Mon Apr 9 17:15:52 CDT 2018
Diffstat (limited to 'utils/mcmc.py')
-rw-r--r--utils/mcmc.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/utils/mcmc.py b/utils/mcmc.py
index f898b83..c712c3a 100644
--- a/utils/mcmc.py
+++ b/utils/mcmc.py
@@ -21,10 +21,10 @@ from utils.enums import MCMCSeedType
from utils.misc import enum_parse, make_dir, parse_bool
-def mcmc(p0, triangle_llh, lnprior, ndim, nwalkers, burnin, nsteps, ntemps=1, threads=1):
+def mcmc(p0, ln_prob, ndim, nwalkers, burnin, nsteps, threads=1):
"""Run the MCMC."""
- sampler = emcee.PTSampler(
- ntemps, nwalkers, ndim, triangle_llh, lnprior, threads=threads
+ sampler = emcee.EnsembleSampler(
+ nwalkers, ndim, ln_prob, threads=threads
)
print "Running burn-in"
@@ -38,7 +38,7 @@ def mcmc(p0, triangle_llh, lnprior, ndim, nwalkers, burnin, nsteps, ntemps=1, th
pass
print "Finished"
- samples = sampler.chain[0, :, :, :].reshape((-1, ndim))
+ samples = sampler.chain.reshape((-1, ndim))
print 'acceptance fraction', sampler.acceptance_fraction
print 'sum of acceptance fraction', np.sum(sampler.acceptance_fraction)
print 'np.unique(samples[:,0]).shape', np.unique(samples[:,0]).shape
@@ -74,22 +74,22 @@ def mcmc_argparse(parser):
)
-def flat_seed(paramset, ntemps, nwalkers):
+def flat_seed(paramset, nwalkers):
"""Get gaussian seed values for the MCMC."""
ndim = len(paramset)
low = np.array(paramset.seeds).T[0]
high = np.array(paramset.seeds).T[1]
p0 = np.random.uniform(
- low=low, high=high, size=[ntemps, nwalkers, ndim]
+ low=low, high=high, size=[nwalkers, ndim]
)
return p0
-def gaussian_seed(paramset, ntemps, nwalkers):
+def gaussian_seed(paramset, nwalkers):
"""Get gaussian seed values for the MCMC."""
ndim = len(paramset)
p0 = np.random.normal(
- paramset.values, paramset.stds, size=[ntemps, nwalkers, ndim]
+ paramset.values, paramset.stds, size=[nwalkers, ndim]
)
return p0