gradcheck fixes are not easy

This commit is contained in:
Max Zwiessele 2014-02-13 21:14:08 +00:00
parent c78ddde6de
commit 9af4c34f90
4 changed files with 42 additions and 29 deletions

View file

@ -348,7 +348,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
def _description_str(self):
if self.size <= 1: return ["%f" % self]
else: return [str(self.shape)]
def _parameter_names(self, add_name):
def parameter_names(self, add_name=False):
return [self.name]
@property
def flattened_parameters(self):

View file

@ -25,8 +25,10 @@ class Parameterizable(object):
from GPy.core.parameterization.array_core import ParamList
_parameters_ = ParamList()
def parameter_names(self):
return [p.name for p in self._parameters_]
def parameter_names(self, add_name=False):
if add_name:
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
return [xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
def parameters_changed(self):
"""
@ -209,7 +211,7 @@ class Constrainable(Nameable, Indexable, Parameterizable):
reconstrained = self.unconstrain()
self.constraints.add(transform, self._raveled_index())
if warning and reconstrained.size > 0:
print "WARNING: reconstraining parameters {}".format(self.parameter_names)
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
if update:
self._highest_parent_.parameters_changed()

View file

@ -434,11 +434,7 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name)
else:
return adjust_name_for_printing(self.name)
def _parameter_names(self, add_name=False):
if add_name:
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x._parameter_names(add_name=True)]
return [xi for x in self._parameters_ for xi in x._parameter_names(add_name=True)]
parameter_names = property(_parameter_names, doc="Names for all parameters handled by this parameterization object -- will add hirarchy name entries for printing")
#parameter_names = property(parameter_names, doc="Names for all parameters handled by this parameterization object -- will add hirarchy name entries for printing")
def _collect_gradient(self, target):
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
@property
@ -468,7 +464,7 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
name = adjust_name_for_printing(self.name) + "."
constrs = self._constraints_str; ts = self._ties_str
desc = self._description_str; names = self.parameter_names
desc = self._description_str; names = self.parameter_names()
nl = max([len(str(x)) for x in names + [name]])
sl = max([len(str(x)) for x in desc + ["Value"]])
cl = max([len(str(x)) if x else 0 for x in constrs + ["Constraint"]])