regexp now on all parameters

This commit is contained in:
Max Zwiessele 2014-02-19 16:54:25 +00:00
parent fd0dd8df85
commit 1c3fe0c51e
3 changed files with 27 additions and 32 deletions

View file

@ -221,7 +221,9 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable
def _description_str(self): def _description_str(self):
if self.size <= 1: return ["%f" % self] if self.size <= 1: return ["%f" % self]
else: return [str(self.shape)] else: return [str(self.shape)]
def parameter_names(self, add_name=False): def parameter_names(self, add_self=False, adjust_for_printing=False):
if adjust_for_printing:
return [adjust_name_for_printing(self.name)]
return [self.name] return [self.name]
@property @property
def flattened_parameters(self): def flattened_parameters(self):
@ -240,7 +242,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable
return [t._short() for t in self._tied_to_] or [''] return [t._short() for t in self._tied_to_] or ['']
def __repr__(self, *args, **kwargs): def __repr__(self, *args, **kwargs):
name = "\033[1m{x:s}\033[0;0m:\n".format( name = "\033[1m{x:s}\033[0;0m:\n".format(
x=self.name_hirarchical) x=self.hirarchy_name())
return name + super(Param, self).__repr__(*args, **kwargs) return name + super(Param, self).__repr__(*args, **kwargs)
def _ties_for(self, rav_index): def _ties_for(self, rav_index):
# size = sum(p.size for p in self._tied_to_) # size = sum(p.size for p in self._tied_to_)
@ -274,7 +276,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable
gen = map(lambda x: " ".join(map(str, x)), gen) gen = map(lambda x: " ".join(map(str, x)), gen)
return reduce(lambda a, b:max(a, len(b)), gen, len(header)) return reduce(lambda a, b:max(a, len(b)), gen, len(header))
def _max_len_values(self): def _max_len_values(self):
return reduce(lambda a, b:max(a, len("{x:=.{0}g}".format(__precision__, x=b))), self.flat, len(self.name_hirarchical)) return reduce(lambda a, b:max(a, len("{x:=.{0}g}".format(__precision__, x=b))), self.flat, len(self.hirarchy_name()))
def _max_len_index(self, ind): def _max_len_index(self, ind):
return reduce(lambda a, b:max(a, len(str(b))), ind, len(__index_name__)) return reduce(lambda a, b:max(a, len(str(b))), ind, len(__index_name__))
def _short(self): def _short(self):
@ -302,8 +304,8 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable
if lp is None: lp = self._max_len_names(prirs, __tie_name__) if lp is None: lp = self._max_len_names(prirs, __tie_name__)
sep = '-' sep = '-'
header_format = " {i:{5}^{2}s} | \033[1m{x:{5}^{1}s}\033[0;0m | {c:{5}^{0}s} | {p:{5}^{4}s} | {t:{5}^{3}s}" header_format = " {i:{5}^{2}s} | \033[1m{x:{5}^{1}s}\033[0;0m | {c:{5}^{0}s} | {p:{5}^{4}s} | {t:{5}^{3}s}"
if only_name: header = header_format.format(lc, lx, li, lt, lp, ' ', x=self.name_hirarchical, c=sep*lc, i=sep*li, t=sep*lt, p=sep*lp) # nice header for printing if only_name: header = header_format.format(lc, lx, li, lt, lp, ' ', x=self.hirarchy_name(), c=sep*lc, i=sep*li, t=sep*lt, p=sep*lp) # nice header for printing
else: header = header_format.format(lc, lx, li, lt, lp, ' ', x=self.name_hirarchical, c=__constraints_name__, i=__index_name__, t=__tie_name__, p=__priors_name__) # nice header for printing else: header = header_format.format(lc, lx, li, lt, lp, ' ', x=self.hirarchy_name(), c=__constraints_name__, i=__index_name__, t=__tie_name__, p=__priors_name__) # nice header for printing
if not ties: ties = itertools.cycle(['']) if not ties: ties = itertools.cycle([''])
return "\n".join([header] + [" {i!s:^{3}s} | {x: >{1}.{2}g} | {c:^{0}s} | {p:^{5}s} | {t:^{4}s} ".format(lc, lx, __precision__, li, lt, lp, x=x, c=" ".join(map(str, c)), p=" ".join(map(str, p)), t=(t or ''), i=i) for i, x, c, t, p in itertools.izip(indices, vals, constr_matrix, ties, prirs)]) # return all the constraints with right indices return "\n".join([header] + [" {i!s:^{3}s} | {x: >{1}.{2}g} | {c:^{0}s} | {p:^{5}s} | {t:^{4}s} ".format(lc, lx, __precision__, li, lt, lp, x=x, c=" ".join(map(str, c)), p=" ".join(map(str, p)), t=(t or ''), i=i) for i, x, c, t, p in itertools.izip(indices, vals, constr_matrix, ties, prirs)]) # return all the constraints with right indices
# except: return super(Param, self).__str__() # except: return super(Param, self).__str__()

View file

@ -88,10 +88,12 @@ class Nameable(Parentable):
self._name = name self._name = name
if self.has_parent(): if self.has_parent():
self._direct_parent_._name_changed(self, from_name) self._direct_parent_._name_changed(self, from_name)
def hirarchy_name(self): def hirarchy_name(self, adjust_for_printing=True):
if adjust_for_printing: adjust = lambda x: adjust_name_for_printing(x)
else: adjust = lambda x: x
if self.has_parent(): if self.has_parent():
return self._direct_parent_.hirarchy_name() + "." + adjust_name_for_printing(self.name) return self._direct_parent_.hirarchy_name() + "." + adjust(self.name)
return adjust_name_for_printing(self.name) return adjust(self.name)
class Parameterizable(Parentable): class Parameterizable(Parentable):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -100,10 +102,13 @@ class Parameterizable(Parentable):
_parameters_ = ParamList() _parameters_ = ParamList()
self._added_names_ = set() self._added_names_ = set()
def parameter_names(self, add_name=False): def parameter_names(self, add_self=False, adjust_for_printing=False, recursive=True):
if add_name: if adjust_for_printing: adjust = lambda x: adjust_name_for_printing(x)
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)] else: adjust = lambda x: x
return [xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)] if recursive: names = [xi for x in self._parameters_ for xi in x.parameter_names(add_self=True, adjust_for_printing=adjust_for_printing)]
else: names = [adjust(x.name) for x in self._parameters_]
if add_self: names = map(lambda x: adjust(self.name) + "." + x, names)
return names
def _add_parameter_name(self, param): def _add_parameter_name(self, param):
pname = adjust_name_for_printing(param.name) pname = adjust_name_for_printing(param.name)

View file

@ -305,13 +305,11 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable, Parame
""" """
if not isinstance(regexp, _pattern_type): regexp = compile(regexp) if not isinstance(regexp, _pattern_type): regexp = compile(regexp)
found_params = [] found_params = []
for p in self.flattened_parameters: for n, p in itertools.izip(self.parameter_names(False, False, True), self.flattened_parameters):
if regexp.match(p.hirarchy_name()) is not None: if regexp.match(n) is not None:
found_params.append(p) found_params.append(p)
if isinstance(p, Parameterized):
found_params.extend(p.grep_param_names(regexp))
return found_params return found_params
return [param for param in self._parameters_ if regexp.match(param.name) is not None]
def __getitem__(self, name, paramlist=None): def __getitem__(self, name, paramlist=None):
if paramlist is None: if paramlist is None:
paramlist = self.grep_param_names(name) paramlist = self.grep_param_names(name)
@ -323,26 +321,16 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable, Parame
return ParamConcatenation(paramlist) return ParamConcatenation(paramlist)
return paramlist[-1] return paramlist[-1]
return ParamConcatenation(paramlist) return ParamConcatenation(paramlist)
def __setitem__(self, name, value, paramlist=None): def __setitem__(self, name, value, paramlist=None):
try: param = self.__getitem__(name, paramlist) try: param = self.__getitem__(name, paramlist)
except AttributeError as a: raise a except AttributeError as a: raise a
param[:] = value param[:] = value
# def __getattr__(self, name):
# return self.__getitem__(name)
# def __getattribute__(self, name):
# #try:
# return object.__getattribute__(self, name)
# except AttributeError:
# _, a, tb = sys.exc_info()
# try:
# return self.__getitem__(name)
# except AttributeError:
# raise AttributeError, a.message, tb
def __setattr__(self, name, val): def __setattr__(self, name, val):
# override the default behaviour, if setting a param, so broadcasting can by used # override the default behaviour, if setting a param, so broadcasting can by used
if hasattr(self, "_parameters_"): if hasattr(self, '_parameters_'):
paramlist = self.grep_param_names(name) pnames = self.parameter_names(False, adjust_for_printing=True, recursive=False)
if len(paramlist) == 1: self.__setitem__(name, val, paramlist); return if name in pnames: self._parameters_[pnames.index(name)][:] = val; return
object.__setattr__(self, name, val); object.__setattr__(self, name, val);
#=========================================================================== #===========================================================================
# Printing: # Printing: