aboutsummaryrefslogtreecommitdiffstats
path: root/golemflavor/param.py
diff options
context:
space:
mode:
authorShivesh Mandalia <shivesh.mandalia@outlook.com>2020-02-28 18:39:45 +0000
committerShivesh Mandalia <shivesh.mandalia@outlook.com>2020-02-28 18:39:45 +0000
commit402f8b53dd892b8fd44ae5ad45eac91b5f6b3750 (patch)
treeb619c6efb0eb303e164bbd27691cdd9f8fce36a2 /golemflavor/param.py
parent3a5a6c658e45402d413970e8d273a656ed74dcf5 (diff)
downloadGolemFlavor-402f8b53dd892b8fd44ae5ad45eac91b5f6b3750.tar.gz
GolemFlavor-402f8b53dd892b8fd44ae5ad45eac91b5f6b3750.zip
reogranise into a python package
Diffstat (limited to 'golemflavor/param.py')
-rw-r--r--golemflavor/param.py213
1 files changed, 213 insertions, 0 deletions
diff --git a/golemflavor/param.py b/golemflavor/param.py
new file mode 100644
index 0000000..2378758
--- /dev/null
+++ b/golemflavor/param.py
@@ -0,0 +1,213 @@
+# author : S. Mandalia
+# s.p.mandalia@qmul.ac.uk
+#
+# date : April 19, 2018
+
+"""
+Param class and functions for the BSM flavour ratio analysis
+"""
+
+from __future__ import absolute_import, division
+
+import sys
+
+from collections import Sequence
+from copy import deepcopy
+
+import numpy as np
+
+from utils.fr import fr_to_angles
+from utils.enums import DataType, Likelihood, ParamTag, PriorsCateg
+
+
+class Param(object):
+ """Parameter class to store parameters."""
+ def __init__(self, name, value, ranges, prior=None, seed=None, std=None,
+ tex=None, tag=None):
+ self._prior = None
+ self._seed = None
+ self._ranges = None
+ self._tex = None
+ self._tag = None
+
+ self.name = name
+ self.value = value
+ self.nominal_value = deepcopy(value)
+ self.prior = prior
+ self.ranges = ranges
+ self.seed = seed
+ self.std = std
+ self.tex = tex
+ self.tag = tag
+
+ @property
+ def ranges(self):
+ return tuple(self._ranges)
+
+ @ranges.setter
+ def ranges(self, values):
+ self._ranges = [val for val in values]
+
+ @property
+ def prior(self):
+ return self._prior
+
+ @prior.setter
+ def prior(self, value):
+ if value is None:
+ self._prior = PriorsCateg.UNIFORM
+ else:
+ assert value in PriorsCateg
+ self._prior = value
+
+ @property
+ def seed(self):
+ if self._seed is None: return self.ranges
+ return tuple(self._seed)
+
+ @seed.setter
+ def seed(self, values):
+ if values is None: return
+ self._seed = [val for val in values]
+
+ @property
+ def tex(self):
+ return r'{0}'.format(self._tex)
+
+ @tex.setter
+ def tex(self, t):
+ self._tex = t if t is not None else r'{\rm %s}' % self.name
+
+ @property
+ def tag(self):
+ return self._tag
+
+ @tag.setter
+ def tag(self, t):
+ if t is None: self._tag = ParamTag.NONE
+ else:
+ assert t in ParamTag
+ self._tag = t
+
+
+class ParamSet(Sequence):
+ """Container class for a set of parameters."""
+ def __init__(self, *args):
+ param_sequence = []
+ for arg in args:
+ try:
+ param_sequence.extend(arg)
+ except TypeError:
+ param_sequence.append(arg)
+
+ if len(param_sequence) != 0:
+ # Disallow duplicated params
+ all_names = [p.name for p in param_sequence]
+ unique_names = set(all_names)
+ if len(unique_names) != len(all_names):
+ duplicates = set([x for x in all_names if all_names.count(x) > 1])
+ raise ValueError('Duplicate definitions found for param(s): ' +
+ ', '.join(str(e) for e in duplicates))
+
+ # Elements of list must be Param type
+ assert all([isinstance(x, Param) for x in param_sequence]), \
+ 'All params must be of type "Param"'
+
+ self._params = param_sequence
+
+ def __len__(self):
+ return len(self._params)
+
+ def __getitem__(self, i):
+ if isinstance(i, int):
+ return self._params[i]
+ elif isinstance(i, basestring):
+ return self._by_name[i]
+
+ def __getattr__(self, attr):
+ return super(ParamSet, self).__getattribute__(attr)
+
+ def __iter__(self):
+ return iter(self._params)
+
+ def __str__(self):
+ o = '\n'
+ for obj in self._params:
+ o += '== {0:<15} = {1:<15}, tag={2:<15}\n'.format(
+ obj.name, obj.value, obj.tag
+ )
+ return o
+
+ @property
+ def _by_name(self):
+ return {obj.name: obj for obj in self._params}
+
+ @property
+ def names(self):
+ return tuple([obj.name for obj in self._params])
+
+ @property
+ def labels(self):
+ return tuple([obj.tex for obj in self._params])
+
+ @property
+ def values(self):
+ return tuple([obj.value for obj in self._params])
+
+ @property
+ def nominal_values(self):
+ return tuple([obj.nominal_value for obj in self._params])
+
+ @property
+ def seeds(self):
+ return tuple([obj.seed for obj in self._params])
+
+ @property
+ def ranges(self):
+ return tuple([obj.ranges for obj in self._params])
+
+ @property
+ def stds(self):
+ return tuple([obj.std for obj in self._params])
+
+ @property
+ def tags(self):
+ return tuple([obj.tag for obj in self._params])
+
+ @property
+ def params(self):
+ return self._params
+
+ def to_dict(self):
+ return {obj.name: obj.value for obj in self._params}
+
+ def from_tag(self, tag, values=False, index=False, invert=False):
+ if values and index: assert 0
+ tag = np.atleast_1d(tag)
+ if not invert:
+ ps = [(idx, obj) for idx, obj in enumerate(self._params)
+ if obj.tag in tag]
+ else:
+ ps = [(idx, obj) for idx, obj in enumerate(self._params)
+ if obj.tag not in tag]
+ if values:
+ return tuple([io[1].value for io in ps])
+ elif index:
+ return tuple([io[0] for io in ps])
+ else:
+ return ParamSet([io[1] for io in ps])
+
+ def remove_params(self, params):
+ rm_paramset = []
+ for parm in self.params:
+ if parm.name not in params.names:
+ rm_paramset.append(parm)
+ return ParamSet(rm_paramset)
+
+ def extend(self, p):
+ param_sequence = self.params
+ if isinstance(p, Param):
+ param_sequence.append(p)
+ elif isinstance(p, ParamSet):
+ param_sequence.extend(p.params)
+ return ParamSet(param_sequence)