linear without caching, derivatives done

Max Zwiessele 2014-02-21 09:14:31 +00:00
parent 1d722c4f28
commit 0c92fca31a
7 changed files with 71 additions and 54 deletions

View file

@@ -30,12 +30,12 @@ class ObservableArray(np.ndarray, Observable):
     def __new__(cls, input_array):
         obj = np.atleast_1d(input_array).view(cls)
         cls.__name__ = "ObservableArray"
-        obj._observers_ = {}
+        obj._observer_callables_ = []
         return obj
     def __array_finalize__(self, obj):
         # see InfoArray.__array_finalize__ for comments
         if obj is None: return
         self._observers_ = getattr(obj, '_observers_', None)
         self._observer_callables_ = getattr(obj, '_observer_callables_', None)
     def __array_wrap__(self, out_arr, context=None):
         return out_arr.view(np.ndarray)
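A quick illustration of the subclass mechanics above (a sketch written against the ObservableArray class as shown; only numpy is assumed): __array_finalize__ is what carries the observer registry onto views and slices, while __array_wrap__ demotes ufunc results back to plain arrays.

    import numpy as np

    a = ObservableArray(np.arange(3.))
    v = a[1:]                          # slicing runs __array_finalize__,
    assert v._observer_callables_ is a._observer_callables_  # so views share the registry
    assert type(a + 1) is np.ndarray   # __array_wrap__ strips the subclass from ufunc results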

View file

@@ -11,14 +11,14 @@ def adjust_name_for_printing(name):
     return ''

 class Observable(object):
-    _observers_ = {}
-    def add_observer(self, observer, callble):
-        self._observers_[observer] = callble
+    _observer_callables_ = []
+    def add_observer(self, callble):
+        self._observer_callables_.append(callble)
         #callble(self)
-    def remove_observer(self, observer):
-        del self._observers_[observer]
+    def remove_observer(self, callble):
+        self._observer_callables_.remove(callble)
     def _notify_observers(self):
-        [callble(self) for callble in self._observers_.itervalues()]
+        [callble(self) for callble in self._observer_callables_]

 class Pickleable(object):
     def _getstate(self):
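The observer registry thus changes from a {observer: callable} dict to a flat list of callables, registered and removed by the callable itself. A minimal usage sketch (Value is an illustrative stand-in, not a GPy class):

    class Value(Observable):
        def __init__(self, x):
            self._observer_callables_ = []   # per-instance registry
            self.x = x
        def set(self, x):
            self.x = x
            self._notify_observers()         # each registered callable receives self

    events = []
    def on_change(obj):
        events.append(obj.x)

    v = Value(1.)
    v.add_observer(on_change)    # no observer key any more, just the callable
    v.set(2.)                    # fires on_change(v)
    v.remove_observer(on_change)
    assert events == [2.0]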

View file

@@ -44,26 +44,26 @@ class SparseGP(GP):
         self.Z = Param('inducing inputs', Z)
         self.num_inducing = Z.shape[0]

-        if not (X_variance is None):
-            assert X_variance.shape == X.shape
+        self.X_variance = X_variance
+        if self.has_uncertain_inputs():
+            assert X_variance.shape == X.shape

         GP.__init__(self, X, Y, kernel, likelihood, inference_method=inference_method, name=name)
         self.add_parameter(self.Z, index=0)
         self.parameters_changed()

-    def update_gradients_Z(self):
-        #The derivative of the bound wrt the inducing inputs Z ( unless they're all fixed)
-        if not self.Z.is_fixed:
-            if self.X_variance is None:
-                self.Z.gradient = self.kern.gradients_Z_sparse(X=self.X, Z=self.Z, **self.grad_dict)
-            else:
-                self.Z.gradient = self.kern.gradients_Z_variational(mu=self.X, S=self.X_variance, Z=self.Z, **self.grad_dict)

     def has_uncertain_inputs(self):
         return not (self.X_variance is None)

     def parameters_changed(self):
         self.posterior, self._log_marginal_likelihood, self.grad_dict = self.inference_method.inference(self.kern, self.X, self.X_variance, self.Z, self.likelihood, self.Y)
-        self.update_gradients_Z()
+        if self.has_uncertain_inputs():
+            self.kern.update_gradients_variational(mu=self.X, S=self.X_variance, Z=self.Z, **self.grad_dict)
+            self.Z.gradient = self.kern.gradients_Z_variational(mu=self.X, S=self.X_variance, Z=self.Z, **self.grad_dict)
+        else:
+            self.kern.update_gradients_sparse(X=self.X, Z=self.Z, **self.grad_dict)
+            self.Z.gradient = self.kern.gradients_Z_sparse(X=self.X, Z=self.Z, **self.grad_dict)

     def _raw_predict(self, Xnew, X_variance_new=None, full_cov=False):
         """
@@ -97,12 +97,10 @@ class SparseGP(GP):
         """
         return GP._getstate(self) + [self.Z,
                 self.num_inducing,
-                self.has_uncertain_inputs,
                 self.X_variance]

     def _setstate(self, state):
         self.X_variance = state.pop()
-        self.has_uncertain_inputs = state.pop()
         self.num_inducing = state.pop()
         self.Z = state.pop()
         GP._setstate(self, state)

View file

@@ -70,10 +70,8 @@ class VarDTC(object):
         if uncertain_inputs:
             grad_dict = {'dL_dKmm': dL_dKmm, 'dL_dpsi0':dL_dpsi0, 'dL_dpsi1':dL_dpsi1, 'dL_dpsi2':dL_dpsi2}
-            kern.update_gradients_variational(mu=X, S=X_variance, Z=Z, **grad_dict)
         else:
             grad_dict = {'dL_dKmm': dL_dKmm, 'dL_dKdiag':dL_dpsi0, 'dL_dKnm':dL_dpsi1}
-            kern.update_gradients_sparse(X=X, Z=Z, **grad_dict)

         #get sufficient things for posterior prediction
         #TODO: do we really want to do this in the loop?
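With the kern.update_gradients_* calls gone from the inference loop, VarDTC only assembles grad_dict and SparseGP.parameters_changed (above) applies it. The hand-off contract, with illustrative shapes for N data points and M inducing points (the shapes are my reading of the surrounding code, not stated in the diff):

    import numpy as np
    N, M = 10, 3  # illustrative sizes
    grad_dict_variational = {'dL_dKmm': np.zeros((M, M)), 'dL_dpsi0': np.zeros(N),
                             'dL_dpsi1': np.zeros((N, M)), 'dL_dpsi2': np.zeros((N, M, M))}
    grad_dict_sparse = {'dL_dKmm': np.zeros((M, M)), 'dL_dKdiag': np.zeros(N),
                        'dL_dKnm': np.zeros((N, M))}
    # parameters_changed then calls exactly one of:
    #   kern.update_gradients_variational(mu=X, S=X_variance, Z=Z, **grad_dict_variational)
    #   kern.update_gradients_sparse(X=X, Z=Z, **grad_dict_sparse)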

View file

@@ -22,15 +22,15 @@ class Kern(Parameterized):
         super(Kern, self).__init__(name)
         self.input_dim = input_dim

-    def K(self, X, X2, target):
+    def K(self, X, X2):
         raise NotImplementedError
-    def Kdiag(self, Xa, target):
+    def Kdiag(self, Xa):
         raise NotImplementedError
-    def psi0(self, Z, mu, S, target):
+    def psi0(self, Z, mu, S):
         raise NotImplementedError
-    def psi1(self, Z, mu, S, target):
+    def psi1(self, Z, mu, S):
         raise NotImplementedError
-    def psi2(self, Z, mu, S, target):
+    def psi2(self, Z, mu, S):
         raise NotImplementedError
     def gradients_X(self, dL_dK, X, X2):
         raise NotImplementedError
@@ -49,7 +49,11 @@ class Kern(Parameterized):
         grad = self.gradients_X(dL_dKmm, Z)
         grad += self.gradients_X(dL_dKnm.T, Z, X)
         return grad
+    def gradients_Z_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
+        raise NotImplementedError
+    def gradients_muS_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
+        raise NotImplementedError
     def plot_ARD(self, *args):
         """If an ARD kernel is present, plot a bar representation using matplotlib

View file

@@ -119,34 +119,55 @@ class Linear(Kern):
     def gradients_X_diag(self, dL_dKdiag, X):
         return 2.*self.variances*dL_dKdiag[:,None]*X

     #---------------------------------------#
-    #             PSI statistics             #
+    #             variational                #
     #---------------------------------------#

-    def psi0(self, Z, mu, S, target):
-        self._psi_computations(Z, mu, S)
-        target += np.sum(self.variances * self.mu2_S, 1)
+    def gradients_Z_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
+        # Kmm
+        grad = self.gradients_X(dL_dKmm, Z, None)
+        # psi1
+        grad += self.gradients_X(dL_dpsi1.T, Z, mu)
+        # psi2
+        self._weave_dpsi2_dZ(dL_dpsi2, Z, mu, S, grad)
+        return grad
+    def gradients_muS_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, mu, S, Z):
+        target_mu, target_S = np.zeros(mu.shape), np.zeros(mu.shape)
+        # psi0
+        target_mu += dL_dpsi0[:, None] * (2.0 * mu * self.variances)
+        target_S += dL_dpsi0[:, None] * self.variances
+        # psi1
+        target_mu += (dL_dpsi1[:, :, None] * (Z * self.variances)).sum(1)
+        # psi2
+        self._weave_dpsi2_dmuS(dL_dpsi2, Z, mu, S, target_mu, target_S)
+        return target_mu, target_S
+    def psi0(self, Z, mu, S):
+        self._psi_computations(Z, mu, S)
+        return np.sum(self.variances * self.mu2_S, 1)
+    def psi1(self, Z, mu, S):
+        """the variance S does not affect psi1"""
+        self._psi1 = self.K(mu, Z)
+        return self._psi1
+    def psi2(self, Z, mu, S):
+        self._psi_computations(Z, mu, S)
+        return self._psi2

-    def dpsi0_dmuS(self, dL_dpsi0, Z, mu, S, target_mu, target_S):
-        target_mu += dL_dpsi0[:, None] * (2.0 * mu * self.variances)
-        target_S += dL_dpsi0[:, None] * self.variances
-    def psi1(self, Z, mu, S, target):
-        """the variance, it does nothing"""
-        self._psi1 = self.K(mu, Z, target)
-    def dpsi1_dmuS(self, dL_dpsi1, Z, mu, S, target_mu, target_S):
-        """Do nothing for S, it does not affect psi1"""
-        self._psi_computations(Z, mu, S)
-        target_mu += (dL_dpsi1[:, :, None] * (Z * self.variances)).sum(1)
-    def dpsi1_dZ(self, dL_dpsi1, Z, mu, S, target):
-        self.gradients_X(dL_dpsi1.T, Z, mu, target)
-    def psi2(self, Z, mu, S, target):
-        self._psi_computations(Z, mu, S)
-        target += self._psi2
     def psi2_new(self, Z, mu, S, target):
         tmp = np.zeros((mu.shape[0], Z.shape[0]))
@@ -172,7 +193,7 @@ class Linear(Kern):
         Zs_sq = Zs[:,None,:]*Zs[None,:,:]
         target_S += (dL_dpsi2[:,:,:,None]*Zs_sq[None,:,:,:]).sum(1).sum(1)

-    def dpsi2_dmuS(self, dL_dpsi2, Z, mu, S, target_mu, target_S):
+    def _weave_dpsi2_dmuS(self, dL_dpsi2, Z, mu, S, target_mu, target_S):
         """Think N,num_inducing,num_inducing,input_dim """
         self._psi_computations(Z, mu, S)
         AZZA = self.ZA.T[:, None, :, None] * self.ZA[None, :, None, :]
@@ -226,7 +247,7 @@ class Linear(Kern):
             type_converters=weave.converters.blitz,**weave_options)

-    def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
+    def _weave_dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
         self._psi_computations(Z, mu, S)
         #psi2_dZ = dL_dpsi2[:, :, :, None] * self.variances * self.ZAinner[:, :, None, :]
         #dummy_target = np.zeros_like(target)
@@ -261,9 +282,6 @@ class Linear(Kern):
             type_converters=weave.converters.blitz,**weave_options)

-    #---------------------------------------#
-    #            Precomputations             #
-    #---------------------------------------#
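For reference, the psi statistics the weave routines above compute for the linear kernel k(x, x') = sum_q variances_q * x_q * x'_q have closed forms under q(x_n) = N(mu_n, diag(S_n)); mu2_S above appears to be mu**2 + S. A naive NumPy transcription for checking (my own sketch, not the library code; mu and S are (N, Q), Z is (M, Q), variances is (Q,)):

    import numpy as np

    def linear_psi_stats(variances, Z, mu, S):
        ZA = Z * variances                             # (M, Q), A = diag(variances)
        psi0 = np.sum(variances * (mu**2 + S), 1)      # (N,)   E[k(x_n, x_n)]
        psi1 = mu.dot(ZA.T)                            # (N, M) E[k(x_n, z_m)]
        # psi2[n] = psi1[n] psi1[n]^T + ZA diag(S[n]) ZA^T
        psi2 = (psi1[:, :, None] * psi1[:, None, :]
                + np.einsum('mq,nq,kq->nmk', ZA, S, ZA))  # (N, M, M)
        return psi0, psi1, psi2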

View file

@@ -1,5 +1,4 @@
 from ..core.parameterization.parameter_core import Observable
-from ..core.parameterization.array_core import ParamList

 class Cacher(object):
     def __init__(self, operation, limit=5, reset_on_first=False):
@@ -13,7 +12,7 @@ class Cacher(object):
     def __call__(self, *args):
         if self._reset_on_first:
             assert isinstance(args[0], Observable)
-            args[0].add_observer(args[0], self.reset)
+            args[0].add_observer(self.reset)
             cached_args = args
         else:
             cached_args = args[1:]
@@ -30,21 +29,21 @@ class Cacher(object):
         else:
             if len(self.cached_inputs) == self.limit:
                 args_ = self.cached_inputs.pop(0)
-                [a.remove_observer(self) for a in args_]
+                [a.remove_observer(self.on_cache_changed) for a in args_]
                 self.inputs_changed.pop(0)
                 self.cached_outputs.pop(0)
             self.cached_inputs.append(cached_args)
             self.cached_outputs.append(self.operation(*args))
             self.inputs_changed.append(False)
-            [a.add_observer(self, self.on_cache_changed) for a in args]
+            [a.add_observer(self.on_cache_changed) for a in args]
         return self.cached_outputs[-1]

     def on_cache_changed(self, arg):
         self.inputs_changed = [any([a is arg for a in args]) or old_ic for args, old_ic in zip(self.cached_inputs, self.inputs_changed)]

     def reset(self, obj):
-        [[a.remove_observer(self) for a in args] for args in self.cached_inputs]
+        [[a.remove_observer(self.on_cache_changed) for a in args] for args in self.cached_inputs]
         self.cached_inputs = []
         self.cached_outputs = []
         self.inputs_changed = []
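Putting the two halves together, the Cacher now subscribes its own bound methods as observers of the arguments. A usage sketch (it assumes the cache-hit branch elided from this diff consults inputs_changed before reusing an output):

    import numpy as np

    def op(x):
        print('computing')
        return np.asarray(x) ** 2

    cached_op = Cacher(op, limit=5, reset_on_first=True)
    a = ObservableArray(np.arange(3.))
    cached_op(a)             # computes; registers reset and on_cache_changed on a
    cached_op(a)             # cache hit, nothing recomputed
    a._notify_observers()    # a announces a change: reset() empties the cache
    cached_op(a)             # recomputes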