messing with kernels

This commit is contained in:
James Hensman 2014-02-25 17:15:38 +00:00
parent 6a667e749f
commit 80acca640f
8 changed files with 66 additions and 57 deletions

View file

@ -39,28 +39,21 @@ class Kern(Parameterized):
def update_gradients_full(self, dL_dK, X, X2):
"""Set the gradients of all parameters when doing full (N) inference."""
raise NotImplementedError
def update_gradients_sparse(self, dL_dKmm, dL_dKnm, dL_dKdiag, X, Z):
target = np.zeros(self.size)
self.update_gradients_diag(dL_dKdiag, X)
self._collect_gradient(target)
self.update_gradients_full(dL_dKnm, X, Z)
self._collect_gradient(target)
self.update_gradients_full(dL_dKmm, Z, None)
self._collect_gradient(target)
self._set_gradient(target)
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
"""
Set the gradients of all parameters when doing inference with
uncertain inputs, using expectations of the kernel.
"""
raise NotImplementedError
def gradients_Z_expectations(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
raise NotImplementedError
def gradients_qX_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
"""
Compute the gradients wrt the parameters of the variational
distruibution q(X), chain-ruling via the expectations of the kernel
"""
raise NotImplementedError
def update_gradients_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
"""Set the gradients of all parameters when doing variational (M) inference with uncertain inputs."""
raise NotImplementedError
def gradients_Z_sparse(self, dL_dKmm, dL_dKnm, dL_dKdiag, X, Z):
grad = self.gradients_X(dL_dKmm, Z)
grad += self.gradients_X(dL_dKnm.T, Z, X)
return grad
def gradients_Z_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
raise NotImplementedError
def gradients_q_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
raise NotImplementedError
def plot_ARD(self, *args, **kw):
if "matplotlib" in sys.modules:
from ...plotting.matplot_dep import kernel_plots
@ -68,13 +61,13 @@ class Kern(Parameterized):
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ...plotting.matplot_dep import kernel_plots
return kernel_plots.plot_ARD(self,*args,**kw)
def input_sensitivity(self):
"""
Returns the sensitivity for each dimension of this kernel.
"""
return np.zeros(self.input_dim)
def __add__(self, other):
""" Overloading of the '+' operator. for more control, see self.add """
return self.add(other)