mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 03:52:39 +02:00
parameters changed more structured now, parameters changed goes from bottom to top, when calling _notify_parameters_changed()
This commit is contained in:
parent
f21e8e2394
commit
6419401d60
2 changed files with 48 additions and 41 deletions
|
|
@ -147,7 +147,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
|
||||||
super(Param, self).__setitem__(s, val, update=update)
|
super(Param, self).__setitem__(s, val, update=update)
|
||||||
#self._notify_tied_parameters()
|
#self._notify_tied_parameters()
|
||||||
if update:
|
if update:
|
||||||
self._highest_parent_.parameters_changed()
|
self._notify_parameters_changed()
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Index Operations:
|
# Index Operations:
|
||||||
|
|
@ -340,14 +340,14 @@ class ParamConcatenation(object):
|
||||||
[numpy.place(p, ind[ps], vals[ps])# and p._notify_tied_parameters()
|
[numpy.place(p, ind[ps], vals[ps])# and p._notify_tied_parameters()
|
||||||
for p, ps in zip(self.params, self._param_slices_)]
|
for p, ps in zip(self.params, self._param_slices_)]
|
||||||
if update:
|
if update:
|
||||||
self.params[0]._highest_parent_.parameters_changed()
|
self.params[0]._notify_parameters_changed()
|
||||||
def _vals(self):
|
def _vals(self):
|
||||||
return numpy.hstack([p._get_params() for p in self.params])
|
return numpy.hstack([p._get_params() for p in self.params])
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# parameter operations:
|
# parameter operations:
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def update_all_params(self):
|
def update_all_params(self):
|
||||||
self.params[0]._highest_parent_.parameters_changed()
|
self.params[0]._notify_parameters_changed()
|
||||||
|
|
||||||
def constrain(self, constraint, warning=True):
|
def constrain(self, constraint, warning=True):
|
||||||
[param.constrain(constraint, update=False) for param in self.params]
|
[param.constrain(constraint, update=False) for param in self.params]
|
||||||
|
|
|
||||||
|
|
@ -20,43 +20,6 @@ class Observable(object):
|
||||||
def _notify_observers(self):
|
def _notify_observers(self):
|
||||||
[callble(self) for callble in self._observers_.itervalues()]
|
[callble(self) for callble in self._observers_.itervalues()]
|
||||||
|
|
||||||
class Parameterizable(object):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
from GPy.core.parameterization.array_core import ParamList
|
|
||||||
_parameters_ = ParamList()
|
|
||||||
|
|
||||||
def parameter_names(self, add_name=False):
|
|
||||||
if add_name:
|
|
||||||
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
|
|
||||||
return [xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
|
|
||||||
|
|
||||||
def _collect_gradient(self, target):
|
|
||||||
import itertools
|
|
||||||
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
|
||||||
|
|
||||||
def _get_params(self):
|
|
||||||
import numpy as np
|
|
||||||
# don't overwrite this anymore!
|
|
||||||
if not self.size:
|
|
||||||
return np.empty(shape=(0,), dtype=np.float64)
|
|
||||||
return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
|
|
||||||
|
|
||||||
def _set_params(self, params, update=True):
|
|
||||||
# don't overwrite this anymore!
|
|
||||||
import itertools
|
|
||||||
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
|
||||||
self.parameters_changed()
|
|
||||||
|
|
||||||
|
|
||||||
def parameters_changed(self):
|
|
||||||
"""
|
|
||||||
This method gets called when parameters have changed.
|
|
||||||
Another way of listening to param changes is to
|
|
||||||
add self as a listener to the param, such that
|
|
||||||
updates get passed through. See :py:function:``GPy.core.param.Observable.add_observer``
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Pickleable(object):
|
class Pickleable(object):
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -121,6 +84,50 @@ class Nameable(Parentable):
|
||||||
if self.has_parent():
|
if self.has_parent():
|
||||||
self._direct_parent_._name_changed(self, from_name)
|
self._direct_parent_._name_changed(self, from_name)
|
||||||
|
|
||||||
|
|
||||||
|
class Parameterizable(Parentable):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(Parameterizable, self).__init__(*args, **kwargs)
|
||||||
|
from GPy.core.parameterization.array_core import ParamList
|
||||||
|
_parameters_ = ParamList()
|
||||||
|
|
||||||
|
def parameter_names(self, add_name=False):
|
||||||
|
if add_name:
|
||||||
|
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
|
||||||
|
return [xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
|
||||||
|
|
||||||
|
def _collect_gradient(self, target):
|
||||||
|
import itertools
|
||||||
|
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
|
|
||||||
|
def _get_params(self):
|
||||||
|
import numpy as np
|
||||||
|
# don't overwrite this anymore!
|
||||||
|
if not self.size:
|
||||||
|
return np.empty(shape=(0,), dtype=np.float64)
|
||||||
|
return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
|
||||||
|
|
||||||
|
def _set_params(self, params, update=True):
|
||||||
|
# don't overwrite this anymore!
|
||||||
|
import itertools
|
||||||
|
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
|
self.parameters_changed()
|
||||||
|
|
||||||
|
def parameters_changed(self):
|
||||||
|
"""
|
||||||
|
This method gets called when parameters have changed.
|
||||||
|
Another way of listening to param changes is to
|
||||||
|
add self as a listener to the param, such that
|
||||||
|
updates get passed through. See :py:function:``GPy.core.param.Observable.add_observer``
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _notify_parameters_changed(self):
|
||||||
|
self.parameters_changed()
|
||||||
|
if self.has_parent():
|
||||||
|
self._direct_parent_._notify_parameters_changed()
|
||||||
|
|
||||||
|
|
||||||
class Gradcheckable(Parentable):
|
class Gradcheckable(Parentable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Gradchecking
|
# Gradchecking
|
||||||
|
|
@ -322,7 +329,7 @@ class Constrainable(Nameable, Indexable, Parameterizable):
|
||||||
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
|
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
|
||||||
which.add(transform, self._raveled_index())
|
which.add(transform, self._raveled_index())
|
||||||
if update:
|
if update:
|
||||||
self._highest_parent_.parameters_changed()
|
self._notify_parameters_changed()
|
||||||
|
|
||||||
def _remove_from_index_operations(self, which, transforms):
|
def _remove_from_index_operations(self, which, transforms):
|
||||||
if len(transforms) == 0:
|
if len(transforms) == 0:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue