aboutsummaryrefslogtreecommitdiffstats
path: root/golemflavor/mcmc.py
blob: a1d3e27da1fce35b029bf6150c0a63a66e520fe3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# author : S. Mandalia
#          s.p.mandalia@qmul.ac.uk
#
# date   : March 17, 2018

"""
Useful functions to use an MCMC for the BSM flavor ratio analysis
"""

from __future__ import absolute_import, division, print_function

from functools import partial

import emcee
import tqdm

import numpy as np

from golemflavor.enums import MCMCSeedType
from golemflavor.misc import enum_parse, make_dir, parse_bool


def mcmc(p0, ln_prob, ndim, nwalkers, burnin, nsteps, args, threads=1):
    """Run the MCMC.

    Runs ``burnin`` steps of an ``emcee.EnsembleSampler``, discards them via
    ``sampler.reset()``, then runs ``nsteps`` production steps starting from
    the final burn-in walker positions.

    Parameters
    ----------
    p0 : numpy ndarray
        Initial walker positions, passed straight to ``sampler.sample``
        (presumably shaped ``(nwalkers, ndim)`` -- confirm with callers).

    ln_prob : callable
        Log-probability function handed to ``emcee.EnsembleSampler``.

    ndim : int
        Number of parameters being sampled.

    nwalkers : int
        Number of walkers.

    burnin : int
        Number of burn-in steps. ``0`` skips burn-in and starts the
        production run directly from ``p0``.

    nsteps : int
        Number of production steps.

    args : object
        Namespace whose ``burnin`` attribute is set to ``False`` once burn-in
        has finished (presumably consulted by ``ln_prob``; TODO confirm).

    threads : int
        Forwarded to ``emcee.EnsembleSampler``.

    Returns
    -------
    samples : numpy ndarray
        Production chain flattened to shape ``(-1, ndim)``.

    """
    sampler = emcee.EnsembleSampler(
        nwalkers, ndim, ln_prob, threads=threads
    )

    print("Running burn-in")
    # Default to the initial positions so burnin == 0 does not leave ``pos``
    # undefined (previously a NameError).
    pos = p0
    for result in tqdm.tqdm(sampler.sample(p0, iterations=burnin), total=burnin):
        # Element 0 of each step is the walker positions. Index instead of
        # unpacking so an extra element (e.g. blobs) does not break this.
        pos = result[0]
    sampler.reset()
    print("Finished burn-in")
    args.burnin = False

    print("Running")
    for _ in tqdm.tqdm(sampler.sample(pos, iterations=nsteps), total=nsteps):
        pass
    print("Finished")

    samples = sampler.chain.reshape((-1, ndim))
    print('acceptance fraction', sampler.acceptance_fraction)
    print('sum of acceptance fraction', np.sum(sampler.acceptance_fraction))
    print('np.unique(samples[:,0]).shape', np.unique(samples[:,0]).shape)
    try:
        print('autocorrelation', sampler.acor)
    except Exception:
        # The autocorrelation estimate can fail, e.g. when the chain is too
        # short. Keep this best-effort, but do not swallow
        # KeyboardInterrupt/SystemExit as the former bare ``except:`` did.
        print('WARNING : NEED TO RUN MORE SAMPLES')

    return samples


def mcmc_argparse(parser):
    """Register the MCMC command-line options on ``parser``.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to which the MCMC options are added (modified in place).

    """
    # String defaults below are run through ``type`` by argparse, so e.g.
    # default='True' reaches the namespace as parse_bool('True').
    parser.add_argument(
        '--run-mcmc', type=parse_bool, default='True',
        help='Run the MCMC'
    )
    parser.add_argument(
        '--burnin', type=int, default=100,
        help='Amount to burnin'
    )
    parser.add_argument(
        '--nwalkers', type=int, default=60,
        help='Number of walkers'
    )
    parser.add_argument(
        '--nsteps', type=int, default=2000,
        help='Number of steps to run'
    )
    parser.add_argument(
        '--mcmc-seed-type', default='uniform',
        type=partial(enum_parse, c=MCMCSeedType), choices=MCMCSeedType,
        help='Type of distribution to make the initial MCMC seed'
    )
    parser.add_argument(
        '--plot-angles', type=parse_bool, default='False',
        help='Plot MCMC triangle in the angles space'
    )
    parser.add_argument(
        '--plot-elements', type=parse_bool, default='False',
        help='Plot MCMC triangle in the mixing elements space'
    )


def flat_seed(paramset, nwalkers):
    """Get uniformly distributed seed values for the MCMC.

    Parameters
    ----------
    paramset : ParamSet-like
        Must support ``len()`` (number of parameters) and expose ``seeds``,
        a sequence of ``(low, high)`` pairs, one per parameter.

    nwalkers : int
        Number of walkers to seed.

    Returns
    -------
    p0 : numpy ndarray
        Array of shape ``(nwalkers, len(paramset))``. Column ``i`` is drawn
        uniformly from ``[low_i, high_i)``.

    """
    ndim = len(paramset)
    # Transposing puts every lower bound in row 0 and every upper bound in
    # row 1. Build the array once rather than once per bound.
    bounds = np.array(paramset.seeds).T
    low, high = bounds[0], bounds[1]
    p0 = np.random.uniform(
        low=low, high=high, size=[nwalkers, ndim]
    )
    return p0


def gaussian_seed(paramset, nwalkers):
    """Draw normally distributed starting positions for the MCMC walkers.

    Parameters
    ----------
    paramset : ParamSet-like
        Must support ``len()`` and expose ``values`` (per-parameter means)
        and ``stds`` (per-parameter standard deviations).

    nwalkers : int
        Number of walkers to seed.

    Returns
    -------
    numpy ndarray
        Array of shape ``(nwalkers, len(paramset))``.

    """
    walker_shape = [nwalkers, len(paramset)]
    return np.random.normal(paramset.values, paramset.stds, size=walker_shape)


def save_chains(chains, outfile):
    """Write the MCMC chains to ``outfile + '.npy'``.

    ``make_dir`` is called on ``outfile`` before writing (presumably to
    create missing parent directories -- confirm against golemflavor.misc).

    Parameters
    ----------
    chains : numpy ndarray
        MCMC chains to save.

    outfile : str
        Output path without extension; ``'.npy'`` is appended here.

    """
    make_dir(outfile)
    npy_path = outfile + '.npy'
    print('Saving chains to location {0}'.format(npy_path))
    np.save(npy_path, chains)