new gradient handling way nicer

This commit is contained in:
Max Zwiessele 2014-01-24 15:07:28 +00:00
parent f0ac290eb3
commit e128059377
6 changed files with 15 additions and 26 deletions

View file

@ -26,9 +26,6 @@ class Param(ObservableArray, Constrainable):
:param name: name of the parameter to be printed
:param input_array: array which this parameter handles
:param gradient: callable with one argument, which is the model of this parameter
:param args: additional arguments to gradient
:param kwargs: additional keyword arguments to gradient
You can add/remove constraints by calling constrain on the parameter itself, e.g:
@ -156,6 +153,8 @@ class Param(ObservableArray, Constrainable):
@property
def _parameters_(self):
return []
def _collect_gradient(self, target):
target[:] = self.gradient
#===========================================================================
# Fixing Parameters:
#===========================================================================

View file

@ -75,7 +75,6 @@ class Parameterized(Constrainable, Pickleable, Observable):
if not self._has_fixes():
self._fixes_ = None
self._connect_parameters()
self.gradient_mapping = {}
self._added_names_ = set()
del self._in_init_
@ -118,12 +117,10 @@ class Parameterized(Constrainable, Pickleable, Observable):
def _has_fixes(self):
return hasattr(self, "_fixes_") and self._fixes_ is not None
def add_parameter(self, param, gradient=None, index=None):
def add_parameter(self, param, index=None):
"""
:param parameters: the parameters to add
:type parameters: list of or one :py:class:`GPy.core.param.Param`
:param [gradients]: gradients for each param,
one gradient per param
:param [index]: index of where to put parameters
@ -167,8 +164,6 @@ class Parameterized(Constrainable, Pickleable, Observable):
self._fixes_ = np.ones(self.size+param.size, dtype=bool)
self._fixes_[ins:ins+param.size] = fixes_param
self.size += param.size
if gradient:
self.gradient_mapping[param] = gradient
self._connect_parameters()
# make sure the constraints are pulled over:
if hasattr(param, "_constraints_") and param._constraints_ is not None:
@ -206,7 +201,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
add self as a listener to the param, such that
updates get passed through. See :py:function:``GPy.core.param.Observable.add_observer``
"""
# will be called as soon as paramters have changed
# will be called as soon as parameters have changed
pass
def _connect_parameters(self):
@ -282,13 +277,11 @@ class Parameterized(Constrainable, Pickleable, Observable):
self._constraints_,
self._parameters_,
self._name,
#self.gradient_mapping,
self._added_names_,
]
def setstate(self, state):
self._added_names_ = state.pop()
#self.gradient_mapping = state.pop()
self._name = state.pop()
self._parameters_ = state.pop()
self._connect_parameters()
@ -547,6 +540,8 @@ class Parameterized(Constrainable, Pickleable, Observable):
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x._parameter_names(add_name=True)]
return [xi for x in self._parameters_ for xi in x._parameter_names(add_name=True)]
parameter_names = property(_parameter_names, doc="Names for all parameters handled by this parameterization object -- will add hirarchy name entries for printing")
def _collect_gradient(self, target):
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
@property
def flattened_parameters(self):
return [xi for x in self._parameters_ for xi in x.flattened_parameters]