observable pattern through and through

Max Zwiessele 2014-02-26 15:46:14 +00:00
parent 26aeb5e1db
commit 65fd6dd24e
11 changed files with 64 additions and 80 deletions

View file

@@ -59,10 +59,9 @@ class ObservableArray(np.ndarray, Observable):
         else:
             return s.size != 0
-    def __setitem__(self, s, val, update=True):
+    def __setitem__(self, s, val):
         if self._s_not_empty(s):
             super(ObservableArray, self).__setitem__(s, val)
-            if update:
-                self._notify_observers()
+            self._notify_observers()
     def __getslice__(self, start, stop):
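The change above makes every non-empty write to an ObservableArray notify its
observers unconditionally; the old update flag is gone. A minimal standalone
sketch of the notify-on-write pattern (illustrative only, not GPy's actual
classes; the observer registry here is a plain dict):

    import numpy as np

    class ObservableArray(np.ndarray):
        # ndarray subclasses initialise through __array_finalize__, not __init__
        def __array_finalize__(self, obj):
            self._observer_callables_ = getattr(obj, '_observer_callables_', {})

        def add_observer(self, observer, callble):
            self._observer_callables_.setdefault(observer, []).append(callble)

        def _notify_observers(self):
            for callables in self._observer_callables_.values():
                for callble in callables:
                    callble(self)

        def __setitem__(self, s, val):
            super(ObservableArray, self).__setitem__(s, val)
            self._notify_observers()  # unconditional, as in the hunk above

    def on_change(arr):
        print('array changed')

    a = np.zeros(3).view(ObservableArray)
    a.add_observer('watcher', on_change)
    a[0] = 1.0  # prints: array changed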

View file

@@ -137,8 +137,6 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
     #===========================================================================
     def _set_params(self, param, update=True):
         self.flat = param
-        #self._notify_tied_parameters()
-        self._notify_observers()
     def _get_params(self):
         return self.flat
@@ -161,11 +159,9 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
         try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
         except AttributeError: pass # returning 0d array or float, double etc
         return new_arr
-    def __setitem__(self, s, val, update=True):
-        super(Param, self).__setitem__(s, val, update=update)
-        #self._notify_tied_parameters()
-        if update and self._s_not_empty(s):
-            self._notify_parameters_changed()
+    def __setitem__(self, s, val):
+        super(Param, self).__setitem__(s, val)
+        #self._notify_observers()
     #===========================================================================
     # Index Operations:
@@ -185,6 +181,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
             a = self._realshape_[i] + a
             internal_offset += a * extended_realshape[i]
         return internal_offset
+
     def _raveled_index(self, slice_index=None):
         # return an index array on the raveled array, which is formed by the current_slice
         # of this object
@@ -354,7 +351,7 @@ class ParamConcatenation(object):
             val = val._vals()
         ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
         vals = self._vals(); vals[s] = val; del val
-        [numpy.place(p, ind[ps], vals[ps]) and update and p._notify_parameters_changed()
+        [numpy.place(p, ind[ps], vals[ps]) and update and p._notify_observers()
          for p, ps in zip(self.params, self._param_slices_)]
     def _vals(self):
         return numpy.hstack([p._get_params() for p in self.params])
@@ -363,7 +360,7 @@ class ParamConcatenation(object):
     #===========================================================================
     def update_all_params(self):
         for p in self.params:
-            p._notify_parameters_changed()
+            p._notify_observers()
     def constrain(self, constraint, warning=True):
         [param.constrain(constraint, update=False) for param in self.params]

View file

@@ -18,8 +18,13 @@ class Observable(object):
     def add_observer(self, observer, callble):
         self._observer_callables_[observer].append(callble)
-    def remove_observer(self, observer, callble):
-        del self._observer_callables_[observer][callble]
+    def remove_observer(self, observer, callble=None):
+        if callble is None:
+            del self._observer_callables_[observer]
+        else:
+            self._observer_callables_[observer].remove(callble)
+            if len(self._observer_callables_[observer]) == 0:
+                self.remove_observer(observer)
     def _notify_observers(self):
         [[callble(self) for callble in callables]
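The old remove_observer tried to delete from the callback list using the
callable as an index, which raises a TypeError; the new version removes the
callback by value, drops the whole observer when no callable is given, and
drops the observer once its last callback is removed. A self-contained sketch
of those semantics (a simplified registry, not GPy's actual module):

    from collections import defaultdict

    class Observable(object):
        def __init__(self):
            self._observer_callables_ = defaultdict(list)

        def add_observer(self, observer, callble):
            self._observer_callables_[observer].append(callble)

        def remove_observer(self, observer, callble=None):
            if callble is None:
                del self._observer_callables_[observer]
            else:
                self._observer_callables_[observer].remove(callble)
                if len(self._observer_callables_[observer]) == 0:
                    self.remove_observer(observer)

        def _notify_observers(self):
            for callables in list(self._observer_callables_.values()):
                for callble in callables:
                    callble(self)

    def log_change(obj):
        print('changed')

    o = Observable()
    o.add_observer('logger', log_change)
    o._notify_observers()                     # prints: changed
    o.remove_observer('logger', log_change)   # list empties, observer dropped
    o._notify_observers()                     # prints nothing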
@@ -72,8 +77,7 @@ class Parentable(object):
         return self._direct_parent_._highest_parent_
     def _notify_parameters_changed(self):
-        if self.has_parent():
-            self._direct_parent_._notify_parameters_changed()
+        raise NotImplementedError, "shouldnt happen, abstract superclass"
 class Nameable(Parentable):
     def __init__(self, name, *a, **kw):
@@ -309,7 +313,7 @@ class Constrainable(Nameable, Indexable):
             print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
         which.add(transform, self._raveled_index())
         if update:
-            self._notify_parameters_changed()
+            self._notify_observers()
     def _remove_from_index_operations(self, which, transforms):
         if len(transforms) == 0:
@@ -325,7 +329,7 @@ class Constrainable(Nameable, Indexable):
         return removed
-class Parameterizable(Constrainable):
+class Parameterizable(Constrainable, Observable):
     def __init__(self, *args, **kwargs):
         super(Parameterizable, self).__init__(*args, **kwargs)
         from GPy.core.parameterization.array_core import ParamList
@@ -386,7 +390,7 @@ class Parameterizable(Constrainable):
     def _set_params(self, params, update=True):
         # don't overwrite this anymore!
         import itertools
-        [p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
+        [p._set_params(params[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
         self.parameters_changed()
     def copy(self):
@@ -420,10 +424,9 @@ class Parameterizable(Constrainable):
         return s
-    def _notify_parameters_changed(self):
+    def _notify_parameters_changed(self, which):
         self.parameters_changed()
-        if self.has_parent():
-            self._direct_parent_._notify_parameters_changed()
+        self._notify_observers()
     def parameters_changed(self):
         """

View file

@@ -11,7 +11,7 @@ from parameter_core import Constrainable, Pickleable, Parentable, Observable, Pa
 from transformations import __fixed__
 from array_core import ParamList
-class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
+class Parameterized(Parameterizable, Pickleable, Gradcheckable):
     """
     Parameterized class
@@ -92,6 +92,7 @@ class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
             self.constraints.update(param.constraints, start)
             self.priors.update(param.priors, start)
             self._parameters_.insert(index, param)
+            param.add_observer(self, self._notify_parameters_changed)
             self.size += param.size
         else:
             raise RuntimeError, """Parameter exists already added and no copy made"""
@@ -120,6 +121,7 @@ class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
         del self._parameters_[param._parent_index_]
         param._disconnect_parent()
+        param.remove_observer(self, self._notify_parameters_changed)
         self.constraints.shift_left(start, param.size)
         self._connect_fixes()
         self._connect_parameters()
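These two hunks wire the container into the observable pattern: it registers
itself as an observer of each parameter it adopts and deregisters when the
parameter is removed. A compact self-contained sketch of that wiring
(illustrative names, not GPy's API):

    class Param(object):
        def __init__(self):
            self._observers = []
        def add_observer(self, observer, callble):
            self._observers.append((observer, callble))
        def remove_observer(self, observer, callble):
            self._observers.remove((observer, callble))
        def set(self, value):
            self.value = value
            for _, callble in self._observers:
                callble(self)  # the changed child is passed as 'which'

    class Parameterized(object):
        def __init__(self):
            self._parameters_ = []
        def add_parameter(self, param):
            self._parameters_.append(param)
            param.add_observer(self, self._notify_parameters_changed)
        def remove_parameter(self, param):
            self._parameters_.remove(param)
            param.remove_observer(self, self._notify_parameters_changed)
        def _notify_parameters_changed(self, which):
            print('parameters_changed: child updated')

    m = Parameterized()
    p = Param()
    m.add_parameter(p)
    p.set(3.0)  # prints: parameters_changed: child updated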

View file

@@ -9,7 +9,6 @@ from ...util.linalg import tdot
 from ...util.misc import fast_array_equal, param_to_array
 from ...core.parameterization import Param
 from ...core.parameterization.transformations import Logexp
-from ...util.caching import cache_this
 class Linear(Kern):
     """

View file

@@ -6,6 +6,7 @@ import numpy as np
 from scipy import weave
 from ...util.misc import param_to_array
 from stationary import Stationary
+from GPy.util.caching import Cache_this
 class RBF(Stationary):
     """
@@ -166,7 +167,7 @@ class RBF(Stationary):
         return target
-    #@cache_this TODO
+    @Cache_this(limit=1)
     def _psi1computations(self, Z, vp):
         mu, S = vp.mean, vp.variance
         l2 = self.lengthscale **2
@@ -179,7 +180,7 @@ class RBF(Stationary):
-    #@cache_this TODO
+    @Cache_this(limit=1)
     def _psi2computations(self, Z, vp):
         mu, S = vp.mean, vp.variance

View file

@@ -6,7 +6,6 @@ from kern import Kern
 import numpy as np
 from ...util.linalg import tdot
 from ...util.config import *
-from ...util.caching import cache_this
 from stationary import Stationary
 class SSRBF(Stationary):
@@ -155,7 +154,7 @@ class SSRBF(Stationary):
     # Precomputations                       #
     #---------------------------------------#
-    @cache_this(1)
+    #@cache_this(1)
     def _K_computations(self, X, X2):
         """
         K(X,X2) - X is NxQ
@@ -175,7 +174,7 @@ class SSRBF(Stationary):
         self._K_dist2 = -2.*np.dot(X, X2.T) + (np.sum(np.square(X), axis=1)[:, None] + np.sum(np.square(X2), axis=1)[None, :])
         self._K_dvar = np.exp(-0.5 * self._K_dist2)
-    @cache_this(1)
+    #@cache_this(1)
     def _psi_computations(self, Z, mu, S, gamma):
         """
         Z - MxQ

View file

@@ -9,6 +9,7 @@ from ...util.linalg import tdot
 from ... import util
 import numpy as np
 from scipy import integrate
+from ...util.caching import Cache_this
 class Stationary(Kern):
     def __init__(self, input_dim, variance, lengthscale, ARD, name):
@@ -39,15 +40,18 @@ class Stationary(Kern):
     def dK_dr(self, r):
         raise NotImplementedError, "implement the covaraiance function as a fn of r to use this class"
+    @Cache_this(limit=5, ignore_args=())
     def K(self, X, X2=None):
         r = self._scaled_dist(X, X2)
         return self.K_of_r(r)
+    @Cache_this(limit=5, ignore_args=(0,))
     def _dist(self, X, X2):
         if X2 is None:
             X2 = X
         return X[:, None, :] - X2[None, :, :]
+    @Cache_this(limit=5, ignore_args=(0,))
     def _unscaled_dist(self, X, X2=None):
         """
         Compute the square distance between each row of X and X2, or between
@@ -61,6 +65,7 @@ class Stationary(Kern):
         X2sq = np.sum(np.square(X2),1)
         return np.sqrt(-2.*np.dot(X, X2.T) + (X1sq[:,None] + X2sq[None,:]))
+    @Cache_this(limit=5, ignore_args=())
     def _scaled_dist(self, X, X2=None):
         """
         Efficiently compute the scaled distance, r.
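A reading of the ignore_args choices above (inferred from the code, not stated
in the commit): _dist and _unscaled_dist depend only on X and X2, so argument
0 (self, the kernel object) is excluded from cache tracking, while K and
_scaled_dist also depend on the kernel's own lengthscale and variance, so self
stays among the tracked arguments and a parameter change can invalidate those
caches through the observer mechanism.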

View file

@@ -41,7 +41,7 @@ class GPLVM(GP):
     def parameters_changed(self):
         super(GPLVM, self).parameters_changed()
-        self.X.gradient = self.kern.gradients_X(self._dL_dK, self.X, None)
+        self.X.gradient = self.kern.gradients_X(self.dL_dK, self.X, None)
     def _getstate(self):
         return GP._getstate(self)

View file

@@ -8,6 +8,7 @@ from .. import likelihoods
 from .. import kern
 from ..inference.latent_function_inference import VarDTC
 from ..util.misc import param_to_array
+from ..core.parameterization.variational import VariationalPosterior
 class SparseGPRegression(SparseGP):
     """
@@ -45,6 +46,9 @@ class SparseGPRegression(SparseGP):
             likelihood = likelihoods.Gaussian()
+        if not (X_variance is None):
+            X = VariationalPosterior(X,X_variance)
+
         SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method=VarDTC())
     def _getstate(self):
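The effect of these two hunks: when a caller supplies X_variance, the inputs
are wrapped in a VariationalPosterior so that downstream inference receives
the input mean and variance as a single variational object rather than as two
separate arrays.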

View file

@@ -1,21 +1,23 @@
 from ..core.parameterization.parameter_core import Observable
 class Cacher(object):
-    def __init__(self, operation, limit=5, reset_on_first=False):
+    def __init__(self, operation, limit=5, ignore_args=()):
         self.limit = int(limit)
-        self._reset_on_first = reset_on_first
+        self.ignore_args = ignore_args
         self.operation=operation
         self.cached_inputs = []
         self.cached_outputs = []
         self.inputs_changed = []
     def __call__(self, *args):
-        if self._reset_on_first:
-            assert isinstance(args[0], Observable)
-            args[0].add_observer(self, self.reset)
-            cached_args = args
+        if len(self.ignore_args) != 0:
+            ca = [a for i,a in enumerate(args) if i not in self.ignore_args]
+            cached_args = []
+            for a in ca:
+                if not any(a is ai for ai in cached_args):
+                    cached_args.append(a)
         else:
-            cached_args = args[1:]
+            cached_args = args
         if not all([isinstance(arg, Observable) for arg in cached_args]):
@@ -36,7 +38,7 @@ class Cacher(object):
         self.cached_inputs.append(cached_args)
         self.cached_outputs.append(self.operation(*args))
         self.inputs_changed.append(False)
-        [a.add_observer(self, self.on_cache_changed) for a in args]
+        [a.add_observer(self, self.on_cache_changed) for a in cached_args]
         return self.cached_outputs[-1]
     def on_cache_changed(self, arg):
@@ -48,42 +50,15 @@ class Cacher(object):
         self.cached_outputs = []
         self.inputs_changed = []
 
-def cache_this(limit=5, reset_on_self=False):
-    def limited_cache(f):
-        c = Cacher(f, limit, reset_on_first=reset_on_self)
+class Cache_this(object):
+    def __init__(self, limit=5, ignore_args=()):
+        self.limit = limit
+        self.ignore_args = ignore_args
+        self.c = None
+    def __call__(self, f):
         def f_wrap(*args):
-            return c(*args)
-        f_wrap._cacher = c
+            if self.c is None:
+                self.c = Cacher(f, self.limit, ignore_args=self.ignore_args)
+            return self.c(*args)
+        f_wrap._cacher = self
         return f_wrap
-    return limited_cache
-
-#Xbase = X
-#while Xbase is not None:
-    #try:
-        #i = self.cached_inputs.index(X)
-        #break
-    #except ValueError:
-        #Xbase = X.base
-        #continue
-#self.inputs_changed[i] = True
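A minimal sketch of how the new Cache_this decorator is used: the decorator
object carries the configuration and builds its cache lazily on the first
call. This toy version keys a plain dict on object identity instead of using
GPy's Observable-driven invalidation, so it only illustrates the wrapping and
lazy-construction mechanics from the hunk above:

    class Cache_this(object):
        def __init__(self, limit=5, ignore_args=()):
            self.limit = limit          # stored but unused in this toy
            self.ignore_args = ignore_args
            self.c = None

        def __call__(self, f):
            def f_wrap(*args):
                if self.c is None:
                    self.c = {}  # GPy builds a Cacher(f, ...) here instead
                # id()-based keys are fragile (ids can be reused); demo only
                key = tuple(id(a) for i, a in enumerate(args)
                            if i not in self.ignore_args)
                if key not in self.c:
                    self.c[key] = f(*args)
                return self.c[key]
            f_wrap._cacher = self
            return f_wrap

    class Kern(object):
        @Cache_this(limit=1, ignore_args=(0,))  # exclude self from the key
        def K(self, X):
            print('computing K')
            return [x * x for x in X]

    k = Kern()
    X = (1, 2, 3)
    k.K(X)  # prints: computing K
    k.K(X)  # same object: served from the cache, no print

One visible design consequence: the decorator object, and hence its Cacher, is
created once per decorated function and shared by every instance of the class;
with ignore_args=(0,) the instance is excluded from the cache identity as well.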