mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-13 05:52:38 +02:00
observable pattern through and thorugh
This commit is contained in:
parent
26aeb5e1db
commit
65fd6dd24e
11 changed files with 64 additions and 80 deletions
|
|
@ -59,11 +59,10 @@ class ObservableArray(np.ndarray, Observable):
|
|||
else:
|
||||
return s.size != 0
|
||||
|
||||
def __setitem__(self, s, val, update=True):
|
||||
def __setitem__(self, s, val):
|
||||
if self._s_not_empty(s):
|
||||
super(ObservableArray, self).__setitem__(s, val)
|
||||
if update:
|
||||
self._notify_observers()
|
||||
self._notify_observers()
|
||||
|
||||
def __getslice__(self, start, stop):
|
||||
return self.__getitem__(slice(start, stop))
|
||||
|
|
|
|||
|
|
@ -137,8 +137,6 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
|||
#===========================================================================
|
||||
def _set_params(self, param, update=True):
|
||||
self.flat = param
|
||||
#self._notify_tied_parameters()
|
||||
self._notify_observers()
|
||||
|
||||
def _get_params(self):
|
||||
return self.flat
|
||||
|
|
@ -161,12 +159,10 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
|||
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
|
||||
except AttributeError: pass # returning 0d array or float, double etc
|
||||
return new_arr
|
||||
def __setitem__(self, s, val, update=True):
|
||||
super(Param, self).__setitem__(s, val, update=update)
|
||||
#self._notify_tied_parameters()
|
||||
if update and self._s_not_empty(s):
|
||||
self._notify_parameters_changed()
|
||||
|
||||
def __setitem__(self, s, val):
|
||||
super(Param, self).__setitem__(s, val)
|
||||
#self._notify_observers()
|
||||
|
||||
#===========================================================================
|
||||
# Index Operations:
|
||||
#===========================================================================
|
||||
|
|
@ -185,6 +181,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
|||
a = self._realshape_[i] + a
|
||||
internal_offset += a * extended_realshape[i]
|
||||
return internal_offset
|
||||
|
||||
def _raveled_index(self, slice_index=None):
|
||||
# return an index array on the raveled array, which is formed by the current_slice
|
||||
# of this object
|
||||
|
|
@ -354,7 +351,7 @@ class ParamConcatenation(object):
|
|||
val = val._vals()
|
||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||
vals = self._vals(); vals[s] = val; del val
|
||||
[numpy.place(p, ind[ps], vals[ps]) and update and p._notify_parameters_changed()
|
||||
[numpy.place(p, ind[ps], vals[ps]) and update and p._notify_observers()
|
||||
for p, ps in zip(self.params, self._param_slices_)]
|
||||
def _vals(self):
|
||||
return numpy.hstack([p._get_params() for p in self.params])
|
||||
|
|
@ -363,7 +360,7 @@ class ParamConcatenation(object):
|
|||
#===========================================================================
|
||||
def update_all_params(self):
|
||||
for p in self.params:
|
||||
p._notify_parameters_changed()
|
||||
p._notify_observers()
|
||||
|
||||
def constrain(self, constraint, warning=True):
|
||||
[param.constrain(constraint, update=False) for param in self.params]
|
||||
|
|
|
|||
|
|
@ -18,8 +18,13 @@ class Observable(object):
|
|||
def add_observer(self, observer, callble):
|
||||
self._observer_callables_[observer].append(callble)
|
||||
|
||||
def remove_observer(self, observer, callble):
|
||||
del self._observer_callables_[observer][callble]
|
||||
def remove_observer(self, observer, callble=None):
|
||||
if callble is None:
|
||||
del self._observer_callables_[observer]
|
||||
else:
|
||||
self._observer_callables_[observer].remove(callble)
|
||||
if len(self._observer_callables_[observer]) == 0:
|
||||
self.remove_observer(observer)
|
||||
|
||||
def _notify_observers(self):
|
||||
[[callble(self) for callble in callables]
|
||||
|
|
@ -72,9 +77,8 @@ class Parentable(object):
|
|||
return self._direct_parent_._highest_parent_
|
||||
|
||||
def _notify_parameters_changed(self):
|
||||
if self.has_parent():
|
||||
self._direct_parent_._notify_parameters_changed()
|
||||
|
||||
raise NotImplementedError, "shouldnt happen, abstract superclass"
|
||||
|
||||
class Nameable(Parentable):
|
||||
def __init__(self, name, *a, **kw):
|
||||
super(Nameable, self).__init__(*a, **kw)
|
||||
|
|
@ -309,7 +313,7 @@ class Constrainable(Nameable, Indexable):
|
|||
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
|
||||
which.add(transform, self._raveled_index())
|
||||
if update:
|
||||
self._notify_parameters_changed()
|
||||
self._notify_observers()
|
||||
|
||||
def _remove_from_index_operations(self, which, transforms):
|
||||
if len(transforms) == 0:
|
||||
|
|
@ -325,7 +329,7 @@ class Constrainable(Nameable, Indexable):
|
|||
return removed
|
||||
|
||||
|
||||
class Parameterizable(Constrainable):
|
||||
class Parameterizable(Constrainable, Observable):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Parameterizable, self).__init__(*args, **kwargs)
|
||||
from GPy.core.parameterization.array_core import ParamList
|
||||
|
|
@ -386,7 +390,7 @@ class Parameterizable(Constrainable):
|
|||
def _set_params(self, params, update=True):
|
||||
# don't overwrite this anymore!
|
||||
import itertools
|
||||
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
[p._set_params(params[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||
self.parameters_changed()
|
||||
|
||||
def copy(self):
|
||||
|
|
@ -420,11 +424,10 @@ class Parameterizable(Constrainable):
|
|||
|
||||
return s
|
||||
|
||||
def _notify_parameters_changed(self):
|
||||
def _notify_parameters_changed(self, which):
|
||||
self.parameters_changed()
|
||||
if self.has_parent():
|
||||
self._direct_parent_._notify_parameters_changed()
|
||||
|
||||
self._notify_observers()
|
||||
|
||||
def parameters_changed(self):
|
||||
"""
|
||||
This method gets called when parameters have changed.
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from parameter_core import Constrainable, Pickleable, Parentable, Observable, Pa
|
|||
from transformations import __fixed__
|
||||
from array_core import ParamList
|
||||
|
||||
class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
|
||||
class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||
"""
|
||||
Parameterized class
|
||||
|
||||
|
|
@ -92,6 +92,7 @@ class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
|
|||
self.constraints.update(param.constraints, start)
|
||||
self.priors.update(param.priors, start)
|
||||
self._parameters_.insert(index, param)
|
||||
param.add_observer(self, self._notify_parameters_changed)
|
||||
self.size += param.size
|
||||
else:
|
||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||
|
|
@ -120,6 +121,7 @@ class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
|
|||
del self._parameters_[param._parent_index_]
|
||||
|
||||
param._disconnect_parent()
|
||||
param.remove_observer(self, self._notify_parameters_changed)
|
||||
self.constraints.shift_left(start, param.size)
|
||||
self._connect_fixes()
|
||||
self._connect_parameters()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue