diff options
Diffstat (limited to 'utils/misc.py')
| -rw-r--r-- | utils/misc.py | 44 |
1 files changed, 41 insertions, 3 deletions
diff --git a/utils/misc.py b/utils/misc.py index 5c3eb2e..c54d25c 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -16,21 +16,23 @@ import multiprocessing import numpy as np -from utils.enums import Likelihood +from utils.enums import Likelihood, ParamTag class Param(object): """Parameter class to store parameters. """ - def __init__(self, name, value, ranges, std=None, tex=None): + def __init__(self, name, value, ranges, std=None, tex=None, tag=None): self._ranges = None self._tex = None + self._tag = None self.name = name self.value = value self.ranges = ranges self.std = std self.tex = tex + self.tag = tag @property def ranges(self): @@ -42,12 +44,23 @@ class Param(object): @property def tex(self): - return r'{0}'.format(self.tex) + return r'{0}'.format(self._tex) @tex.setter def tex(self, t): self._tex = t if t is not None else r'{\rm %s}' % self.name + @property + def tag(self): + return self._tag + + @tex.setter + def tag(self, t): + if t is None: self._tag = ParamTag.NONE + else: + assert t in ParamTag + self._tag = t + class ParamSet(Sequence): """Container class for a set of parameters. @@ -105,6 +118,10 @@ class ParamSet(Sequence): return tuple([obj.name for obj in self._params]) @property + def labels(self): + return tuple([obj.tex for obj in self._params]) + + @property def values(self): return tuple([obj.value for obj in self._params]) @@ -117,12 +134,23 @@ class ParamSet(Sequence): return tuple([obj.std for obj in self._params]) @property + def tags(self): + return tuple([obj.tag for obj in self._params]) + + @property def params(self): return self._params def to_dict(self): return {obj.name: obj.value for obj in self._params} + def from_tag(self, tag): + return tuple([obj for obj in self._params if obj.tag is tag]) + + def idx_from_tag(self, tag): + return tuple([idx for idx, obj in enumerate(self._params) + if obj.tag is tag]) + def gen_outfile_name(args): """Generate a name for the output file based on the input args. @@ -222,6 +250,16 @@ def enum_parse(s, c): return c[s.upper()] +def make_dir(outfile): + try: + os.makedirs(outfile[:-len(os.path.basename(outfile))]) + except OSError as exc: # Python >2.5 + if exc.errno == errno.EEXIST and os.path.isdir(outfile[:-len(os.path.basename(outfile))]): + pass + else: + raise + + def thread_type(t): if t.lower() == 'max': return multiprocessing.cpu_count() |
