diff options
Diffstat (limited to 'golemflavor/mcmc.py')
| -rw-r--r-- | golemflavor/mcmc.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/golemflavor/mcmc.py b/golemflavor/mcmc.py index a1d3e27..c002620 100644 --- a/golemflavor/mcmc.py +++ b/golemflavor/mcmc.py @@ -9,10 +9,14 @@ Useful functions to use an MCMC for the BSM flavor ratio analysis from __future__ import absolute_import, division, print_function +import sys from functools import partial import emcee -import tqdm +if 'ipykernel' in sys.modules: + from tqdm import tqdm_notebook as tqdm +else: + from tqdm import tqdm import numpy as np @@ -20,21 +24,20 @@ from golemflavor.enums import MCMCSeedType from golemflavor.misc import enum_parse, make_dir, parse_bool -def mcmc(p0, ln_prob, ndim, nwalkers, burnin, nsteps, args, threads=1): +def mcmc(p0, ln_prob, ndim, nwalkers, burnin, nsteps, threads=1): """Run the MCMC.""" sampler = emcee.EnsembleSampler( nwalkers, ndim, ln_prob, threads=threads ) print("Running burn-in") - for result in tqdm.tqdm(sampler.sample(p0, iterations=burnin), total=burnin): + for result in tqdm(sampler.sample(p0, iterations=burnin), total=burnin): pos, prob, state = result sampler.reset() print("Finished burn-in") - args.burnin = False print("Running") - for _ in tqdm.tqdm(sampler.sample(pos, iterations=nsteps), total=nsteps): + for _ in tqdm(sampler.sample(pos, iterations=nsteps), total=nsteps): pass print("Finished") @@ -114,6 +117,10 @@ def save_chains(chains, outfile): Output file location of chains """ + if outfile[-4:] == '.npy': + of = outfile + else: + of = outfile + '.npy' make_dir(outfile) print('Saving chains to location {0}'.format(outfile+'.npy')) np.save(outfile+'.npy', chains) |
