mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 03:22:38 +02:00
observable pattern through and thorugh
This commit is contained in:
parent
26aeb5e1db
commit
65fd6dd24e
11 changed files with 64 additions and 80 deletions
|
|
@ -59,11 +59,10 @@ class ObservableArray(np.ndarray, Observable):
|
||||||
else:
|
else:
|
||||||
return s.size != 0
|
return s.size != 0
|
||||||
|
|
||||||
def __setitem__(self, s, val, update=True):
|
def __setitem__(self, s, val):
|
||||||
if self._s_not_empty(s):
|
if self._s_not_empty(s):
|
||||||
super(ObservableArray, self).__setitem__(s, val)
|
super(ObservableArray, self).__setitem__(s, val)
|
||||||
if update:
|
self._notify_observers()
|
||||||
self._notify_observers()
|
|
||||||
|
|
||||||
def __getslice__(self, start, stop):
|
def __getslice__(self, start, stop):
|
||||||
return self.__getitem__(slice(start, stop))
|
return self.__getitem__(slice(start, stop))
|
||||||
|
|
|
||||||
|
|
@ -137,8 +137,6 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def _set_params(self, param, update=True):
|
def _set_params(self, param, update=True):
|
||||||
self.flat = param
|
self.flat = param
|
||||||
#self._notify_tied_parameters()
|
|
||||||
self._notify_observers()
|
|
||||||
|
|
||||||
def _get_params(self):
|
def _get_params(self):
|
||||||
return self.flat
|
return self.flat
|
||||||
|
|
@ -161,12 +159,10 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||||
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
|
try: new_arr._current_slice_ = s; new_arr._original_ = self.base is new_arr.base
|
||||||
except AttributeError: pass # returning 0d array or float, double etc
|
except AttributeError: pass # returning 0d array or float, double etc
|
||||||
return new_arr
|
return new_arr
|
||||||
def __setitem__(self, s, val, update=True):
|
def __setitem__(self, s, val):
|
||||||
super(Param, self).__setitem__(s, val, update=update)
|
super(Param, self).__setitem__(s, val)
|
||||||
#self._notify_tied_parameters()
|
#self._notify_observers()
|
||||||
if update and self._s_not_empty(s):
|
|
||||||
self._notify_parameters_changed()
|
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Index Operations:
|
# Index Operations:
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
@ -185,6 +181,7 @@ class Param(Constrainable, ObservableArray, Gradcheckable):
|
||||||
a = self._realshape_[i] + a
|
a = self._realshape_[i] + a
|
||||||
internal_offset += a * extended_realshape[i]
|
internal_offset += a * extended_realshape[i]
|
||||||
return internal_offset
|
return internal_offset
|
||||||
|
|
||||||
def _raveled_index(self, slice_index=None):
|
def _raveled_index(self, slice_index=None):
|
||||||
# return an index array on the raveled array, which is formed by the current_slice
|
# return an index array on the raveled array, which is formed by the current_slice
|
||||||
# of this object
|
# of this object
|
||||||
|
|
@ -354,7 +351,7 @@ class ParamConcatenation(object):
|
||||||
val = val._vals()
|
val = val._vals()
|
||||||
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
ind = numpy.zeros(sum(self._param_sizes), dtype=bool); ind[s] = True;
|
||||||
vals = self._vals(); vals[s] = val; del val
|
vals = self._vals(); vals[s] = val; del val
|
||||||
[numpy.place(p, ind[ps], vals[ps]) and update and p._notify_parameters_changed()
|
[numpy.place(p, ind[ps], vals[ps]) and update and p._notify_observers()
|
||||||
for p, ps in zip(self.params, self._param_slices_)]
|
for p, ps in zip(self.params, self._param_slices_)]
|
||||||
def _vals(self):
|
def _vals(self):
|
||||||
return numpy.hstack([p._get_params() for p in self.params])
|
return numpy.hstack([p._get_params() for p in self.params])
|
||||||
|
|
@ -363,7 +360,7 @@ class ParamConcatenation(object):
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
def update_all_params(self):
|
def update_all_params(self):
|
||||||
for p in self.params:
|
for p in self.params:
|
||||||
p._notify_parameters_changed()
|
p._notify_observers()
|
||||||
|
|
||||||
def constrain(self, constraint, warning=True):
|
def constrain(self, constraint, warning=True):
|
||||||
[param.constrain(constraint, update=False) for param in self.params]
|
[param.constrain(constraint, update=False) for param in self.params]
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,13 @@ class Observable(object):
|
||||||
def add_observer(self, observer, callble):
|
def add_observer(self, observer, callble):
|
||||||
self._observer_callables_[observer].append(callble)
|
self._observer_callables_[observer].append(callble)
|
||||||
|
|
||||||
def remove_observer(self, observer, callble):
|
def remove_observer(self, observer, callble=None):
|
||||||
del self._observer_callables_[observer][callble]
|
if callble is None:
|
||||||
|
del self._observer_callables_[observer]
|
||||||
|
else:
|
||||||
|
self._observer_callables_[observer].remove(callble)
|
||||||
|
if len(self._observer_callables_[observer]) == 0:
|
||||||
|
self.remove_observer(observer)
|
||||||
|
|
||||||
def _notify_observers(self):
|
def _notify_observers(self):
|
||||||
[[callble(self) for callble in callables]
|
[[callble(self) for callble in callables]
|
||||||
|
|
@ -72,9 +77,8 @@ class Parentable(object):
|
||||||
return self._direct_parent_._highest_parent_
|
return self._direct_parent_._highest_parent_
|
||||||
|
|
||||||
def _notify_parameters_changed(self):
|
def _notify_parameters_changed(self):
|
||||||
if self.has_parent():
|
raise NotImplementedError, "shouldnt happen, abstract superclass"
|
||||||
self._direct_parent_._notify_parameters_changed()
|
|
||||||
|
|
||||||
class Nameable(Parentable):
|
class Nameable(Parentable):
|
||||||
def __init__(self, name, *a, **kw):
|
def __init__(self, name, *a, **kw):
|
||||||
super(Nameable, self).__init__(*a, **kw)
|
super(Nameable, self).__init__(*a, **kw)
|
||||||
|
|
@ -309,7 +313,7 @@ class Constrainable(Nameable, Indexable):
|
||||||
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
|
print "WARNING: reconstraining parameters {}".format(self.parameter_names() or self.name)
|
||||||
which.add(transform, self._raveled_index())
|
which.add(transform, self._raveled_index())
|
||||||
if update:
|
if update:
|
||||||
self._notify_parameters_changed()
|
self._notify_observers()
|
||||||
|
|
||||||
def _remove_from_index_operations(self, which, transforms):
|
def _remove_from_index_operations(self, which, transforms):
|
||||||
if len(transforms) == 0:
|
if len(transforms) == 0:
|
||||||
|
|
@ -325,7 +329,7 @@ class Constrainable(Nameable, Indexable):
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
|
|
||||||
class Parameterizable(Constrainable):
|
class Parameterizable(Constrainable, Observable):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Parameterizable, self).__init__(*args, **kwargs)
|
super(Parameterizable, self).__init__(*args, **kwargs)
|
||||||
from GPy.core.parameterization.array_core import ParamList
|
from GPy.core.parameterization.array_core import ParamList
|
||||||
|
|
@ -386,7 +390,7 @@ class Parameterizable(Constrainable):
|
||||||
def _set_params(self, params, update=True):
|
def _set_params(self, params, update=True):
|
||||||
# don't overwrite this anymore!
|
# don't overwrite this anymore!
|
||||||
import itertools
|
import itertools
|
||||||
[p._set_params(params[s], update=update) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
[p._set_params(params[s]) for p, s in itertools.izip(self._parameters_, self._param_slices_)]
|
||||||
self.parameters_changed()
|
self.parameters_changed()
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
|
|
@ -420,11 +424,10 @@ class Parameterizable(Constrainable):
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def _notify_parameters_changed(self):
|
def _notify_parameters_changed(self, which):
|
||||||
self.parameters_changed()
|
self.parameters_changed()
|
||||||
if self.has_parent():
|
self._notify_observers()
|
||||||
self._direct_parent_._notify_parameters_changed()
|
|
||||||
|
|
||||||
def parameters_changed(self):
|
def parameters_changed(self):
|
||||||
"""
|
"""
|
||||||
This method gets called when parameters have changed.
|
This method gets called when parameters have changed.
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from parameter_core import Constrainable, Pickleable, Parentable, Observable, Pa
|
||||||
from transformations import __fixed__
|
from transformations import __fixed__
|
||||||
from array_core import ParamList
|
from array_core import ParamList
|
||||||
|
|
||||||
class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
|
class Parameterized(Parameterizable, Pickleable, Gradcheckable):
|
||||||
"""
|
"""
|
||||||
Parameterized class
|
Parameterized class
|
||||||
|
|
||||||
|
|
@ -92,6 +92,7 @@ class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
|
||||||
self.constraints.update(param.constraints, start)
|
self.constraints.update(param.constraints, start)
|
||||||
self.priors.update(param.priors, start)
|
self.priors.update(param.priors, start)
|
||||||
self._parameters_.insert(index, param)
|
self._parameters_.insert(index, param)
|
||||||
|
param.add_observer(self, self._notify_parameters_changed)
|
||||||
self.size += param.size
|
self.size += param.size
|
||||||
else:
|
else:
|
||||||
raise RuntimeError, """Parameter exists already added and no copy made"""
|
raise RuntimeError, """Parameter exists already added and no copy made"""
|
||||||
|
|
@ -120,6 +121,7 @@ class Parameterized(Parameterizable, Pickleable, Observable, Gradcheckable):
|
||||||
del self._parameters_[param._parent_index_]
|
del self._parameters_[param._parent_index_]
|
||||||
|
|
||||||
param._disconnect_parent()
|
param._disconnect_parent()
|
||||||
|
param.remove_observer(self, self._notify_parameters_changed)
|
||||||
self.constraints.shift_left(start, param.size)
|
self.constraints.shift_left(start, param.size)
|
||||||
self._connect_fixes()
|
self._connect_fixes()
|
||||||
self._connect_parameters()
|
self._connect_parameters()
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ from ...util.linalg import tdot
|
||||||
from ...util.misc import fast_array_equal, param_to_array
|
from ...util.misc import fast_array_equal, param_to_array
|
||||||
from ...core.parameterization import Param
|
from ...core.parameterization import Param
|
||||||
from ...core.parameterization.transformations import Logexp
|
from ...core.parameterization.transformations import Logexp
|
||||||
from ...util.caching import cache_this
|
|
||||||
|
|
||||||
class Linear(Kern):
|
class Linear(Kern):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import numpy as np
|
||||||
from scipy import weave
|
from scipy import weave
|
||||||
from ...util.misc import param_to_array
|
from ...util.misc import param_to_array
|
||||||
from stationary import Stationary
|
from stationary import Stationary
|
||||||
|
from GPy.util.caching import Cache_this
|
||||||
|
|
||||||
class RBF(Stationary):
|
class RBF(Stationary):
|
||||||
"""
|
"""
|
||||||
|
|
@ -166,7 +167,7 @@ class RBF(Stationary):
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
#@cache_this TODO
|
@Cache_this(limit=1)
|
||||||
def _psi1computations(self, Z, vp):
|
def _psi1computations(self, Z, vp):
|
||||||
mu, S = vp.mean, vp.variance
|
mu, S = vp.mean, vp.variance
|
||||||
l2 = self.lengthscale **2
|
l2 = self.lengthscale **2
|
||||||
|
|
@ -179,7 +180,7 @@ class RBF(Stationary):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#@cache_this TODO
|
@Cache_this(limit=1)
|
||||||
def _psi2computations(self, Z, vp):
|
def _psi2computations(self, Z, vp):
|
||||||
mu, S = vp.mean, vp.variance
|
mu, S = vp.mean, vp.variance
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from kern import Kern
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ...util.linalg import tdot
|
from ...util.linalg import tdot
|
||||||
from ...util.config import *
|
from ...util.config import *
|
||||||
from ...util.caching import cache_this
|
|
||||||
from stationary import Stationary
|
from stationary import Stationary
|
||||||
|
|
||||||
class SSRBF(Stationary):
|
class SSRBF(Stationary):
|
||||||
|
|
@ -155,7 +154,7 @@ class SSRBF(Stationary):
|
||||||
# Precomputations #
|
# Precomputations #
|
||||||
#---------------------------------------#
|
#---------------------------------------#
|
||||||
|
|
||||||
@cache_this(1)
|
#@cache_this(1)
|
||||||
def _K_computations(self, X, X2):
|
def _K_computations(self, X, X2):
|
||||||
"""
|
"""
|
||||||
K(X,X2) - X is NxQ
|
K(X,X2) - X is NxQ
|
||||||
|
|
@ -175,7 +174,7 @@ class SSRBF(Stationary):
|
||||||
self._K_dist2 = -2.*np.dot(X, X2.T) + (np.sum(np.square(X), axis=1)[:, None] + np.sum(np.square(X2), axis=1)[None, :])
|
self._K_dist2 = -2.*np.dot(X, X2.T) + (np.sum(np.square(X), axis=1)[:, None] + np.sum(np.square(X2), axis=1)[None, :])
|
||||||
self._K_dvar = np.exp(-0.5 * self._K_dist2)
|
self._K_dvar = np.exp(-0.5 * self._K_dist2)
|
||||||
|
|
||||||
@cache_this(1)
|
#@cache_this(1)
|
||||||
def _psi_computations(self, Z, mu, S, gamma):
|
def _psi_computations(self, Z, mu, S, gamma):
|
||||||
"""
|
"""
|
||||||
Z - MxQ
|
Z - MxQ
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from ...util.linalg import tdot
|
||||||
from ... import util
|
from ... import util
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import integrate
|
from scipy import integrate
|
||||||
|
from ...util.caching import Cache_this
|
||||||
|
|
||||||
class Stationary(Kern):
|
class Stationary(Kern):
|
||||||
def __init__(self, input_dim, variance, lengthscale, ARD, name):
|
def __init__(self, input_dim, variance, lengthscale, ARD, name):
|
||||||
|
|
@ -39,15 +40,18 @@ class Stationary(Kern):
|
||||||
def dK_dr(self, r):
|
def dK_dr(self, r):
|
||||||
raise NotImplementedError, "implement the covaraiance function as a fn of r to use this class"
|
raise NotImplementedError, "implement the covaraiance function as a fn of r to use this class"
|
||||||
|
|
||||||
|
@Cache_this(limit=5, ignore_args=())
|
||||||
def K(self, X, X2=None):
|
def K(self, X, X2=None):
|
||||||
r = self._scaled_dist(X, X2)
|
r = self._scaled_dist(X, X2)
|
||||||
return self.K_of_r(r)
|
return self.K_of_r(r)
|
||||||
|
|
||||||
|
@Cache_this(limit=5, ignore_args=(0,))
|
||||||
def _dist(self, X, X2):
|
def _dist(self, X, X2):
|
||||||
if X2 is None:
|
if X2 is None:
|
||||||
X2 = X
|
X2 = X
|
||||||
return X[:, None, :] - X2[None, :, :]
|
return X[:, None, :] - X2[None, :, :]
|
||||||
|
|
||||||
|
@Cache_this(limit=5, ignore_args=(0,))
|
||||||
def _unscaled_dist(self, X, X2=None):
|
def _unscaled_dist(self, X, X2=None):
|
||||||
"""
|
"""
|
||||||
Compute the square distance between each row of X and X2, or between
|
Compute the square distance between each row of X and X2, or between
|
||||||
|
|
@ -61,6 +65,7 @@ class Stationary(Kern):
|
||||||
X2sq = np.sum(np.square(X2),1)
|
X2sq = np.sum(np.square(X2),1)
|
||||||
return np.sqrt(-2.*np.dot(X, X2.T) + (X1sq[:,None] + X2sq[None,:]))
|
return np.sqrt(-2.*np.dot(X, X2.T) + (X1sq[:,None] + X2sq[None,:]))
|
||||||
|
|
||||||
|
@Cache_this(limit=5, ignore_args=())
|
||||||
def _scaled_dist(self, X, X2=None):
|
def _scaled_dist(self, X, X2=None):
|
||||||
"""
|
"""
|
||||||
Efficiently compute the scaled distance, r.
|
Efficiently compute the scaled distance, r.
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class GPLVM(GP):
|
||||||
|
|
||||||
def parameters_changed(self):
|
def parameters_changed(self):
|
||||||
super(GPLVM, self).parameters_changed()
|
super(GPLVM, self).parameters_changed()
|
||||||
self.X.gradient = self.kern.gradients_X(self._dL_dK, self.X, None)
|
self.X.gradient = self.kern.gradients_X(self.dL_dK, self.X, None)
|
||||||
|
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
return GP._getstate(self)
|
return GP._getstate(self)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from .. import likelihoods
|
||||||
from .. import kern
|
from .. import kern
|
||||||
from ..inference.latent_function_inference import VarDTC
|
from ..inference.latent_function_inference import VarDTC
|
||||||
from ..util.misc import param_to_array
|
from ..util.misc import param_to_array
|
||||||
|
from ..core.parameterization.variational import VariationalPosterior
|
||||||
|
|
||||||
class SparseGPRegression(SparseGP):
|
class SparseGPRegression(SparseGP):
|
||||||
"""
|
"""
|
||||||
|
|
@ -44,7 +45,10 @@ class SparseGPRegression(SparseGP):
|
||||||
assert Z.shape[1] == input_dim
|
assert Z.shape[1] == input_dim
|
||||||
|
|
||||||
likelihood = likelihoods.Gaussian()
|
likelihood = likelihoods.Gaussian()
|
||||||
|
|
||||||
|
if not (X_variance is None):
|
||||||
|
X = VariationalPosterior(X,X_variance)
|
||||||
|
|
||||||
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method=VarDTC())
|
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method=VarDTC())
|
||||||
|
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,23 @@
|
||||||
from ..core.parameterization.parameter_core import Observable
|
from ..core.parameterization.parameter_core import Observable
|
||||||
|
|
||||||
class Cacher(object):
|
class Cacher(object):
|
||||||
def __init__(self, operation, limit=5, reset_on_first=False):
|
def __init__(self, operation, limit=5, ignore_args=()):
|
||||||
self.limit = int(limit)
|
self.limit = int(limit)
|
||||||
self._reset_on_first = reset_on_first
|
self.ignore_args = ignore_args
|
||||||
self.operation=operation
|
self.operation=operation
|
||||||
self.cached_inputs = []
|
self.cached_inputs = []
|
||||||
self.cached_outputs = []
|
self.cached_outputs = []
|
||||||
self.inputs_changed = []
|
self.inputs_changed = []
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
if self._reset_on_first:
|
if len(self.ignore_args) != 0:
|
||||||
assert isinstance(args[0], Observable)
|
ca = [a for i,a in enumerate(args) if i not in self.ignore_args]
|
||||||
args[0].add_observer(self, self.reset)
|
cached_args = []
|
||||||
cached_args = args
|
for a in ca:
|
||||||
|
if not any(a is ai for ai in cached_args):
|
||||||
|
cached_args.append(a)
|
||||||
else:
|
else:
|
||||||
cached_args = args[1:]
|
cached_args = args
|
||||||
|
|
||||||
|
|
||||||
if not all([isinstance(arg, Observable) for arg in cached_args]):
|
if not all([isinstance(arg, Observable) for arg in cached_args]):
|
||||||
|
|
@ -36,7 +38,7 @@ class Cacher(object):
|
||||||
self.cached_inputs.append(cached_args)
|
self.cached_inputs.append(cached_args)
|
||||||
self.cached_outputs.append(self.operation(*args))
|
self.cached_outputs.append(self.operation(*args))
|
||||||
self.inputs_changed.append(False)
|
self.inputs_changed.append(False)
|
||||||
[a.add_observer(self, self.on_cache_changed) for a in args]
|
[a.add_observer(self, self.on_cache_changed) for a in cached_args]
|
||||||
return self.cached_outputs[-1]
|
return self.cached_outputs[-1]
|
||||||
|
|
||||||
def on_cache_changed(self, arg):
|
def on_cache_changed(self, arg):
|
||||||
|
|
@ -48,42 +50,15 @@ class Cacher(object):
|
||||||
self.cached_outputs = []
|
self.cached_outputs = []
|
||||||
self.inputs_changed = []
|
self.inputs_changed = []
|
||||||
|
|
||||||
|
class Cache_this(object):
|
||||||
|
def __init__(self, limit=5, ignore_args=()):
|
||||||
|
self.limit = limit
|
||||||
def cache_this(limit=5, reset_on_self=False):
|
self.ignore_args = ignore_args
|
||||||
def limited_cache(f):
|
self.c = None
|
||||||
c = Cacher(f, limit, reset_on_first=reset_on_self)
|
def __call__(self, f):
|
||||||
def f_wrap(*args):
|
def f_wrap(*args):
|
||||||
return c(*args)
|
if self.c is None:
|
||||||
f_wrap._cacher = c
|
self.c = Cacher(f, self.limit, ignore_args=self.ignore_args)
|
||||||
return f_wrap
|
return self.c(*args)
|
||||||
return limited_cache
|
f_wrap._cacher = self
|
||||||
|
return f_wrap
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#Xbase = X
|
|
||||||
#while Xbase is not None:
|
|
||||||
#try:
|
|
||||||
#i = self.cached_inputs.index(X)
|
|
||||||
#break
|
|
||||||
#except ValueError:
|
|
||||||
#Xbase = X.base
|
|
||||||
#continue
|
|
||||||
#self.inputs_changed[i] = True
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue