parameters changed more structured now, parameters changed goes from bottom to top, when calling _notify_parameters_changed()

This commit is contained in:
Max Zwiessele 2014-02-14 14:40:32 +00:00
parent f21e8e2394
commit 6419401d60
2 changed files with 48 additions and 41 deletions

View file

@ -147,7 +147,7 @@ class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parameteri
super(Param, self).__setitem__(s, val, update=update) super(Param, self).__setitem__(s, val, update=update)
#self._notify_tied_parameters() #self._notify_tied_parameters()
if update: if update:
self._highest_parent_.parameters_changed() self._notify_parameters_changed()
#=========================================================================== #===========================================================================
# Index Operations: # Index Operations:
@ -340,14 +340,14 @@ class ParamConcatenation(object):
[numpy.place(p, ind[ps], vals[ps])# and p._notify_tied_parameters() [numpy.place(p, ind[ps], vals[ps])# and p._notify_tied_parameters()
for p, ps in zip(self.params, self._param_slices_)] for p, ps in zip(self.params, self._param_slices_)]
if update: if update:
self.params[0]._highest_parent_.parameters_changed() self.params[0]._notify_parameters_changed()
def _vals(self): def _vals(self):
return numpy.hstack([p._get_params() for p in self.params]) return numpy.hstack([p._get_params() for p in self.params])
#=========================================================================== #===========================================================================
# parameter operations: # parameter operations:
#=========================================================================== #===========================================================================
def update_all_params(self): def update_all_params(self):
self.params[0]._highest_parent_.parameters_changed() self.params[0]._notify_parameters_changed()
def constrain(self, constraint, warning=True): def constrain(self, constraint, warning=True):
[param.constrain(constraint, update=False) for param in self.params] [param.constrain(constraint, update=False) for param in self.params]

View file

@ -19,43 +19,6 @@ class Observable(object):
del self._observers_[observer] del self._observers_[observer]
def _notify_observers(self): def _notify_observers(self):
[callble(self) for callble in self._observers_.itervalues()] [callble(self) for callble in self._observers_.itervalues()]
class Parameterizable(object):
def __init__(self, *args, **kwargs):
from GPy.core.parameterization.array_core import ParamList
_parameters_ = ParamList()
def parameter_names(self, add_name=False):
if add_name:
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
return [xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
def _collect_gradient(self, target):
import itertools
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
def _get_params(self):
import numpy as np
# don't overwrite this anymore!
if not self.size:
return np.empty(shape=(0,), dtype=np.float64)
return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
def _set_params(self, params, update=True):
# don't overwrite this anymore!
import itertools
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
self.parameters_changed()
def parameters_changed(self):
"""
This method gets called when parameters have changed.
Another way of listening to param changes is to
add self as a listener to the param, such that
updates get passed through. See :py:function:``GPy.core.param.Observable.add_observer``
"""
pass
class Pickleable(object): class Pickleable(object):
def _getstate(self): def _getstate(self):
@ -121,6 +84,50 @@ class Nameable(Parentable):
if self.has_parent(): if self.has_parent():
self._direct_parent_._name_changed(self, from_name) self._direct_parent_._name_changed(self, from_name)
class Parameterizable(Parentable):
def __init__(self, *args, **kwargs):
super(Parameterizable, self).__init__(*args, **kwargs)
from GPy.core.parameterization.array_core import ParamList
_parameters_ = ParamList()
def parameter_names(self, add_name=False):
if add_name:
return [adjust_name_for_printing(self.name) + "." + xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
return [xi for x in self._parameters_ for xi in x.parameter_names(add_name=True)]
def _collect_gradient(self, target):
import itertools
[p._collect_gradient(target[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
def _get_params(self):
import numpy as np
# don't overwrite this anymore!
if not self.size:
return np.empty(shape=(0,), dtype=np.float64)
return np.hstack([x._get_params() for x in self._parameters_ if x.size > 0])
def _set_params(self, params, update=True):
# don't overwrite this anymore!
import itertools
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
self.parameters_changed()
def parameters_changed(self):
"""
This method gets called when parameters have changed.
Another way of listening to param changes is to
add self as a listener to the param, such that
updates get passed through. See :py:function:``GPy.core.param.Observable.add_observer``
"""
pass
def _notify_parameters_changed(self):
self.parameters_changed()
if self.has_parent():
self._direct_parent_._notify_parameters_changed()
class Gradcheckable(Parentable): class Gradcheckable(Parentable):
#=========================================================================== #===========================================================================
# Gradchecking # Gradchecking
@ -322,7 +329,7 @@ class Constrainable(Nameable, Indexable, Parameterizable):
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name) print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
which.add(transform, self._raveled_index()) which.add(transform, self._raveled_index())
if update: if update:
self._highest_parent_.parameters_changed() self._notify_parameters_changed()
def _remove_from_index_operations(self, which, transforms): def _remove_from_index_operations(self, which, transforms):
if len(transforms) == 0: if len(transforms) == 0: