mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
gradcheck fixes are not easy
This commit is contained in:
parent
c78ddde6de
commit
9af4c34f90
4 changed files with 42 additions and 29 deletions
|
|
@ -348,7 +348,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
|
|||
def _description_str(self):
|
||||
if self.size <= 1: return ["%f" % self]
|
||||
else: return [str(self.shape)]
|
||||
def _parameter_names(self, add_name):
|
||||
def parameter_names(self, add_name=False):
|
||||
return [self.name]
|
||||
@property
|
||||
def flattened_parameters(self):
|
||||
|
|
|
|||
|
|
@ -25,8 +25,10 @@ class Parameterizable(object):
|
|||
from GPy.core.parameterization.array_core import ParamList
|
||||
_parameters_ = ParamList()
|
||||
|
||||
def parameter_names(self):
|
||||
return [p.name for p in self._parameters_]
|
||||
def parameter_names(self, add_name=False):
|
||||
if add_name:
|
||||
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
|
||||
return [xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
|
||||
|
||||
def parameters_changed(self):
|
||||
"""
|
||||
|
|
@ -209,7 +211,7 @@ class Constrainable(Nameable, Indexable, Parameterizable):
|
|||
reconstrained = self.unconstrain()
|
||||
self.constraints.add(transform, self._raveled_index())
|
||||
if warning and reconstrained.size > 0:
|
||||
print "WARNING: reconstraining parameters {}".format(self.parameter_names)
|
||||
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
|
||||
if update:
|
||||
self._highest_parent_.parameters_changed()
|
||||
|
||||
|
|
|
|||
|
|
@ -434,11 +434,7 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
return self._direct_parent_.hirarchy_name() + adjust_name_for_printing(self.name)
|
||||
else:
|
||||
return adjust_name_for_printing(self.name)
|
||||
def _parameter_names(self, add_name=False):
|
||||
if add_name:
|
||||
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x._parameter_names(add_name=True)]
|
||||
return [xi for x in self._parameters_ for xi in x._parameter_names(add_name=True)]
|
||||
parameter_names = property(_parameter_names, doc="Names for all parameters handled by this parameterization object -- will add hirarchy name entries for printing")
|
||||
#parameter_names = property(parameter_names, doc="Names for all parameters handled by this parameterization object -- will add hirarchy name entries for printing")
|
||||
def _collect_gradient(self, target):
|
||||
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
@property
|
||||
|
|
@ -468,7 +464,7 @@ class Parameterized(Constrainable, Pickleable, Observable, Gradcheckable):
|
|||
|
||||
name = adjust_name_for_printing(self.name) + "."
|
||||
constrs = self._constraints_str; ts = self._ties_str
|
||||
desc = self._description_str; names = self.parameter_names
|
||||
desc = self._description_str; names = self.parameter_names()
|
||||
nl = max([len(str(x)) for x in names + [name]])
|
||||
sl = max([len(str(x)) for x in desc + ["Value"]])
|
||||
cl = max([len(str(x)) if x else 0 for x in constrs + ["Constraint"]])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue