aboutsummaryrefslogtreecommitdiffstats
path: root/golemflavor/mcmc.py
diff options
context:
space:
mode:
authorShivesh Mandalia <shivesh.mandalia@outlook.com>2020-02-28 18:39:45 +0000
committerShivesh Mandalia <shivesh.mandalia@outlook.com>2020-02-28 18:39:45 +0000
commit402f8b53dd892b8fd44ae5ad45eac91b5f6b3750 (patch)
treeb619c6efb0eb303e164bbd27691cdd9f8fce36a2 /golemflavor/mcmc.py
parent3a5a6c658e45402d413970e8d273a656ed74dcf5 (diff)
downloadGolemFlavor-402f8b53dd892b8fd44ae5ad45eac91b5f6b3750.tar.gz
GolemFlavor-402f8b53dd892b8fd44ae5ad45eac91b5f6b3750.zip
reogranise into a python package
Diffstat (limited to 'golemflavor/mcmc.py')
-rw-r--r--golemflavor/mcmc.py120
1 files changed, 120 insertions, 0 deletions
diff --git a/golemflavor/mcmc.py b/golemflavor/mcmc.py
new file mode 100644
index 0000000..49e5022
--- /dev/null
+++ b/golemflavor/mcmc.py
@@ -0,0 +1,120 @@
+# author : S. Mandalia
+# s.p.mandalia@qmul.ac.uk
+#
+# date : March 17, 2018
+
+"""
+Useful functions to use an MCMC for the BSM flavour ratio analysis
+"""
+
+from __future__ import absolute_import, division
+
+from functools import partial
+
+import emcee
+import tqdm
+
+import numpy as np
+
+from utils.enums import MCMCSeedType
+from utils.misc import enum_parse, make_dir, parse_bool
+
+
+def mcmc(p0, ln_prob, ndim, nwalkers, burnin, nsteps, args, threads=1):
+ """Run the MCMC."""
+ sampler = emcee.EnsembleSampler(
+ nwalkers, ndim, ln_prob, threads=threads
+ )
+
+ print "Running burn-in"
+ for result in tqdm.tqdm(sampler.sample(p0, iterations=burnin), total=burnin):
+ pos, prob, state = result
+ sampler.reset()
+ print "Finished burn-in"
+ args.burnin = False
+
+ print "Running"
+ for _ in tqdm.tqdm(sampler.sample(pos, iterations=nsteps), total=nsteps):
+ pass
+ print "Finished"
+
+ samples = sampler.chain.reshape((-1, ndim))
+ print 'acceptance fraction', sampler.acceptance_fraction
+ print 'sum of acceptance fraction', np.sum(sampler.acceptance_fraction)
+ print 'np.unique(samples[:,0]).shape', np.unique(samples[:,0]).shape
+ try:
+ print 'autocorrelation', sampler.acor
+ except:
+ print 'WARNING : NEED TO RUN MORE SAMPLES'
+
+ return samples
+
+
+def mcmc_argparse(parser):
+ parser.add_argument(
+ '--run-mcmc', type=parse_bool, default='True',
+ help='Run the MCMC'
+ )
+ parser.add_argument(
+ '--burnin', type=int, default=100,
+ help='Amount to burnin'
+ )
+ parser.add_argument(
+ '--nwalkers', type=int, default=60,
+ help='Number of walkers'
+ )
+ parser.add_argument(
+ '--nsteps', type=int, default=2000,
+ help='Number of steps to run'
+ )
+ parser.add_argument(
+ '--mcmc-seed-type', default='uniform',
+ type=partial(enum_parse, c=MCMCSeedType), choices=MCMCSeedType,
+ help='Type of distrbution to make the initial MCMC seed'
+ )
+ parser.add_argument(
+ '--plot-angles', type=parse_bool, default='False',
+ help='Plot MCMC triangle in the angles space'
+ )
+ parser.add_argument(
+ '--plot-elements', type=parse_bool, default='False',
+ help='Plot MCMC triangle in the mixing elements space'
+ )
+
+
+def flat_seed(paramset, nwalkers):
+ """Get gaussian seed values for the MCMC."""
+ ndim = len(paramset)
+ low = np.array(paramset.seeds).T[0]
+ high = np.array(paramset.seeds).T[1]
+ p0 = np.random.uniform(
+ low=low, high=high, size=[nwalkers, ndim]
+ )
+ return p0
+
+
+def gaussian_seed(paramset, nwalkers):
+ """Get gaussian seed values for the MCMC."""
+ ndim = len(paramset)
+ p0 = np.random.normal(
+ paramset.values, paramset.stds, size=[nwalkers, ndim]
+ )
+ return p0
+
+
+def save_chains(chains, outfile):
+ """Save the chains.
+
+ Parameters
+ ----------
+ chains : numpy ndarray
+ MCMC chains to save
+
+ outfile : str
+ Output file location of chains
+
+ """
+ make_dir(outfile)
+ print 'Saving chains to location {0}'.format(outfile+'.npy')
+ np.save(outfile+'.npy', chains)
+