aboutsummaryrefslogtreecommitdiffstats
path: root/golemflavor/mcmc.py
diff options
context:
space:
mode:
Diffstat (limited to 'golemflavor/mcmc.py')
-rw-r--r--golemflavor/mcmc.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/golemflavor/mcmc.py b/golemflavor/mcmc.py
index a1d3e27..c002620 100644
--- a/golemflavor/mcmc.py
+++ b/golemflavor/mcmc.py
@@ -9,10 +9,14 @@ Useful functions to use an MCMC for the BSM flavor ratio analysis
from __future__ import absolute_import, division, print_function
+import sys
from functools import partial
import emcee
-import tqdm
+if 'ipykernel' in sys.modules:
+ from tqdm import tqdm_notebook as tqdm
+else:
+ from tqdm import tqdm
import numpy as np
@@ -20,21 +24,20 @@ from golemflavor.enums import MCMCSeedType
from golemflavor.misc import enum_parse, make_dir, parse_bool
-def mcmc(p0, ln_prob, ndim, nwalkers, burnin, nsteps, args, threads=1):
+def mcmc(p0, ln_prob, ndim, nwalkers, burnin, nsteps, threads=1):
"""Run the MCMC."""
sampler = emcee.EnsembleSampler(
nwalkers, ndim, ln_prob, threads=threads
)
print("Running burn-in")
- for result in tqdm.tqdm(sampler.sample(p0, iterations=burnin), total=burnin):
+ for result in tqdm(sampler.sample(p0, iterations=burnin), total=burnin):
pos, prob, state = result
sampler.reset()
print("Finished burn-in")
- args.burnin = False
print("Running")
- for _ in tqdm.tqdm(sampler.sample(pos, iterations=nsteps), total=nsteps):
+ for _ in tqdm(sampler.sample(pos, iterations=nsteps), total=nsteps):
pass
print("Finished")
@@ -114,6 +117,10 @@ def save_chains(chains, outfile):
Output file location of chains
"""
+ if outfile[-4:] == '.npy':
+ of = outfile
+ else:
+ of = outfile + '.npy'
make_dir(outfile)
print('Saving chains to location {0}'.format(outfile+'.npy'))
np.save(outfile+'.npy', chains)