mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
new gradient handling way nicer
This commit is contained in:
parent
f0ac290eb3
commit
e128059377
6 changed files with 15 additions and 26 deletions
|
|
@ -26,9 +26,6 @@ class Param(ObservableArray, Constrainable):
|
|||
|
||||
:param name: name of the parameter to be printed
|
||||
:param input_array: array which this parameter handles
|
||||
:param gradient: callable with one argument, which is the model of this parameter
|
||||
:param args: additional arguments to gradient
|
||||
:param kwargs: additional keyword arguments to gradient
|
||||
|
||||
You can add/remove constraints by calling constrain on the parameter itself, e.g:
|
||||
|
||||
|
|
@ -156,6 +153,8 @@ class Param(ObservableArray, Constrainable):
|
|||
@property
|
||||
def _parameters_(self):
|
||||
return []
|
||||
def _collect_gradient(self, target):
|
||||
target[:] = self.gradient
|
||||
#===========================================================================
|
||||
# Fixing Parameters:
|
||||
#===========================================================================
|
||||
|
|
|
|||
|
|
@ -75,7 +75,6 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
if not self._has_fixes():
|
||||
self._fixes_ = None
|
||||
self._connect_parameters()
|
||||
self.gradient_mapping = {}
|
||||
self._added_names_ = set()
|
||||
del self._in_init_
|
||||
|
||||
|
|
@ -118,12 +117,10 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
def _has_fixes(self):
|
||||
return hasattr(self, "_fixes_") and self._fixes_ is not None
|
||||
|
||||
def add_parameter(self, param, gradient=None, index=None):
|
||||
def add_parameter(self, param, index=None):
|
||||
"""
|
||||
:param parameters: the parameters to add
|
||||
:type parameters: list of or one :py:class:`GPy.core.param.Param`
|
||||
:param [gradients]: gradients for each param,
|
||||
one gradient per param
|
||||
:param [index]: index of where to put parameters
|
||||
|
||||
|
||||
|
|
@ -167,8 +164,6 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
self._fixes_ = np.ones(self.size+param.size, dtype=bool)
|
||||
self._fixes_[ins:ins+param.size] = fixes_param
|
||||
self.size += param.size
|
||||
if gradient:
|
||||
self.gradient_mapping[param] = gradient
|
||||
self._connect_parameters()
|
||||
# make sure the constraints are pulled over:
|
||||
if hasattr(param, "_constraints_") and param._constraints_ is not None:
|
||||
|
|
@ -206,7 +201,7 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
add self as a listener to the param, such that
|
||||
updates get passed through. See :py:function:``GPy.core.param.Observable.add_observer``
|
||||
"""
|
||||
# will be called as soon as paramters have changed
|
||||
# will be called as soon as parameters have changed
|
||||
pass
|
||||
|
||||
def _connect_parameters(self):
|
||||
|
|
@ -282,13 +277,11 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
self._constraints_,
|
||||
self._parameters_,
|
||||
self._name,
|
||||
#self.gradient_mapping,
|
||||
self._added_names_,
|
||||
]
|
||||
|
||||
def setstate(self, state):
|
||||
self._added_names_ = state.pop()
|
||||
#self.gradient_mapping = state.pop()
|
||||
self._name = state.pop()
|
||||
self._parameters_ = state.pop()
|
||||
self._connect_parameters()
|
||||
|
|
@ -547,6 +540,8 @@ class Parameterized(Constrainable, Pickleable, Observable):
|
|||
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x._parameter_names(add_name=True)]
|
||||
return [xi for x in self._parameters_ for xi in x._parameter_names(add_name=True)]
|
||||
parameter_names = property(_parameter_names, doc="Names for all parameters handled by this parameterization object -- will add hirarchy name entries for printing")
|
||||
def _collect_gradient(self, target):
|
||||
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
@property
|
||||
def flattened_parameters(self):
|
||||
return [xi for x in self._parameters_ for xi in x.flattened_parameters]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue