diff options
| author | Shivesh Mandalia <shivesh.mandalia@outlook.com> | 2020-02-28 18:39:45 +0000 |
|---|---|---|
| committer | Shivesh Mandalia <shivesh.mandalia@outlook.com> | 2020-02-28 18:39:45 +0000 |
| commit | 402f8b53dd892b8fd44ae5ad45eac91b5f6b3750 (patch) | |
| tree | b619c6efb0eb303e164bbd27691cdd9f8fce36a2 /golemflavor/param.py | |
| parent | 3a5a6c658e45402d413970e8d273a656ed74dcf5 (diff) | |
| download | GolemFlavor-402f8b53dd892b8fd44ae5ad45eac91b5f6b3750.tar.gz GolemFlavor-402f8b53dd892b8fd44ae5ad45eac91b5f6b3750.zip | |
reogranise into a python package
Diffstat (limited to 'golemflavor/param.py')
| -rw-r--r-- | golemflavor/param.py | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/golemflavor/param.py b/golemflavor/param.py new file mode 100644 index 0000000..2378758 --- /dev/null +++ b/golemflavor/param.py @@ -0,0 +1,213 @@ +# author : S. Mandalia +# s.p.mandalia@qmul.ac.uk +# +# date : April 19, 2018 + +""" +Param class and functions for the BSM flavour ratio analysis +""" + +from __future__ import absolute_import, division + +import sys + +from collections import Sequence +from copy import deepcopy + +import numpy as np + +from utils.fr import fr_to_angles +from utils.enums import DataType, Likelihood, ParamTag, PriorsCateg + + +class Param(object): + """Parameter class to store parameters.""" + def __init__(self, name, value, ranges, prior=None, seed=None, std=None, + tex=None, tag=None): + self._prior = None + self._seed = None + self._ranges = None + self._tex = None + self._tag = None + + self.name = name + self.value = value + self.nominal_value = deepcopy(value) + self.prior = prior + self.ranges = ranges + self.seed = seed + self.std = std + self.tex = tex + self.tag = tag + + @property + def ranges(self): + return tuple(self._ranges) + + @ranges.setter + def ranges(self, values): + self._ranges = [val for val in values] + + @property + def prior(self): + return self._prior + + @prior.setter + def prior(self, value): + if value is None: + self._prior = PriorsCateg.UNIFORM + else: + assert value in PriorsCateg + self._prior = value + + @property + def seed(self): + if self._seed is None: return self.ranges + return tuple(self._seed) + + @seed.setter + def seed(self, values): + if values is None: return + self._seed = [val for val in values] + + @property + def tex(self): + return r'{0}'.format(self._tex) + + @tex.setter + def tex(self, t): + self._tex = t if t is not None else r'{\rm %s}' % self.name + + @property + def tag(self): + return self._tag + + @tag.setter + def tag(self, t): + if t is None: self._tag = ParamTag.NONE + else: + assert t in ParamTag + self._tag = t + + +class ParamSet(Sequence): + """Container class for a set of parameters.""" + def __init__(self, *args): + param_sequence = [] + for arg in args: + try: + param_sequence.extend(arg) + except TypeError: + param_sequence.append(arg) + + if len(param_sequence) != 0: + # Disallow duplicated params + all_names = [p.name for p in param_sequence] + unique_names = set(all_names) + if len(unique_names) != len(all_names): + duplicates = set([x for x in all_names if all_names.count(x) > 1]) + raise ValueError('Duplicate definitions found for param(s): ' + + ', '.join(str(e) for e in duplicates)) + + # Elements of list must be Param type + assert all([isinstance(x, Param) for x in param_sequence]), \ + 'All params must be of type "Param"' + + self._params = param_sequence + + def __len__(self): + return len(self._params) + + def __getitem__(self, i): + if isinstance(i, int): + return self._params[i] + elif isinstance(i, basestring): + return self._by_name[i] + + def __getattr__(self, attr): + return super(ParamSet, self).__getattribute__(attr) + + def __iter__(self): + return iter(self._params) + + def __str__(self): + o = '\n' + for obj in self._params: + o += '== {0:<15} = {1:<15}, tag={2:<15}\n'.format( + obj.name, obj.value, obj.tag + ) + return o + + @property + def _by_name(self): + return {obj.name: obj for obj in self._params} + + @property + def names(self): + return tuple([obj.name for obj in self._params]) + + @property + def labels(self): + return tuple([obj.tex for obj in self._params]) + + @property + def values(self): + return tuple([obj.value for obj in self._params]) + + @property + def nominal_values(self): + return tuple([obj.nominal_value for obj in self._params]) + + @property + def seeds(self): + return tuple([obj.seed for obj in self._params]) + + @property + def ranges(self): + return tuple([obj.ranges for obj in self._params]) + + @property + def stds(self): + return tuple([obj.std for obj in self._params]) + + @property + def tags(self): + return tuple([obj.tag for obj in self._params]) + + @property + def params(self): + return self._params + + def to_dict(self): + return {obj.name: obj.value for obj in self._params} + + def from_tag(self, tag, values=False, index=False, invert=False): + if values and index: assert 0 + tag = np.atleast_1d(tag) + if not invert: + ps = [(idx, obj) for idx, obj in enumerate(self._params) + if obj.tag in tag] + else: + ps = [(idx, obj) for idx, obj in enumerate(self._params) + if obj.tag not in tag] + if values: + return tuple([io[1].value for io in ps]) + elif index: + return tuple([io[0] for io in ps]) + else: + return ParamSet([io[1] for io in ps]) + + def remove_params(self, params): + rm_paramset = [] + for parm in self.params: + if parm.name not in params.names: + rm_paramset.append(parm) + return ParamSet(rm_paramset) + + def extend(self, p): + param_sequence = self.params + if isinstance(p, Param): + param_sequence.append(p) + elif isinstance(p, ParamSet): + param_sequence.extend(p.params) + return ParamSet(param_sequence) |
