automatic slicing

Max Zwiessele 2014-03-11 16:24:09 +00:00
parent e078bb47e1
commit 01f5d789c5
3 changed files with 72 additions and 144 deletions
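In short: rather than every compound kernel indexing X[:, i_s] by hand through stored input_slices, each Kern now carries its own active_dims, and a wrapper installed by the KernCallsViaSlicerMeta metaclass slices inputs before a kernel's methods see them. Below is a minimal sketch of that idea under assumed names (slice_calls and MyKern are hypothetical); the real wrapper lives in kernel_slice_operations and covers many more methods:

import numpy as np
from functools import wraps

def slice_calls(fn):
    # Hypothetical decorator: restrict X (and X2) to the kernel's
    # active_dims before the wrapped method runs, unless the caller
    # already sliced and set the _sliced_X flag (as Add.psi2 does below).
    @wraps(fn)
    def wrapper(self, X, X2=None, *args, **kw):
        if not self._sliced_X:
            X = X[:, self.active_dims]
            X2 = X2 if X2 is None else X2[:, self.active_dims]
        return fn(self, X, X2, *args, **kw)
    return wrapper

class MyKern(object):
    def __init__(self, active_dims):
        self.active_dims = np.asarray(active_dims)
        self._sliced_X = False

    @slice_calls
    def K(self, X, X2=None):
        # X arrives already restricted to this kernel's dimensions
        if X2 is None:
            X2 = X
        return X.dot(X2.T)  # a toy linear kernel

k = MyKern(active_dims=[0, 2])
X = np.random.randn(5, 4)
print(k.K(X).shape)  # (5, 5), computed on columns 0 and 2 only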

GPy/kern/_src/add.py

@@ -23,7 +23,7 @@ class Add(CombinationKernel):
         elif not isinstance(which_parts, (list, tuple)):
             # if only one part is given
             which_parts = [which_parts]
-        return sum([p.K(X, X2) for p in which_parts])
+        return reduce(np.add, (p.K(X, X2) for p in which_parts))
 
     def gradients_X(self, dL_dK, X, X2=None):
         """Compute the gradient of the objective function with respect to X.
@@ -49,14 +49,14 @@ class Add(CombinationKernel):
     def psi0(self, Z, variational_posterior):
-        return np.sum([p.psi0(Z[:, i_s], variational_posterior[:, i_s]) for p, i_s in zip(self._parameters_, self.input_slices)],0)
+        return reduce(np.add, (p.psi0(Z, variational_posterior) for p in self.parts))
 
     def psi1(self, Z, variational_posterior):
-        return np.sum([p.psi1(Z[:, i_s], variational_posterior[:, i_s]) for p, i_s in zip(self._parameters_, self.input_slices)], 0)
+        return reduce(np.add, (p.psi1(Z, variational_posterior) for p in self.parts))
 
     def psi2(self, Z, variational_posterior):
-        psi2 = np.sum([p.psi2(Z[:, i_s], variational_posterior[:, i_s]) for p, i_s in zip(self._parameters_, self.input_slices)], 0)
+        psi2 = reduce(np.add, (p.psi2(Z, variational_posterior) for p in self.parts))
+        return psi2
         # compute the "cross" terms
         from static import White, Bias
         from rbf import RBF
@@ -64,18 +64,23 @@ class Add(CombinationKernel):
         from linear import Linear
         #ffrom fixed import Fixed
-        for (p1, i1), (p2, i2) in itertools.combinations(itertools.izip(self._parameters_, self.input_slices), 2):
+        for p1, p2 in itertools.combinations(self.parts, 2):
+            i1, i2 = p1.active_dims, p2.active_dims
             # white doesn;t combine with anything
             if isinstance(p1, White) or isinstance(p2, White):
                 pass
             # rbf X bias
             #elif isinstance(p1, (Bias, Fixed)) and isinstance(p2, (RBF, RBFInv)):
             elif isinstance(p1, Bias) and isinstance(p2, (RBF, Linear)):
-                tmp = p2.psi1(Z[:,i2], variational_posterior[:, i_s])
+                # manual override for slicing:
+                p2._sliced_X = p1._sliced_X = True
+                tmp = p2.psi1(Z[:,i2], variational_posterior[:, i1])
                 psi2 += p1.variance * (tmp[:, :, None] + tmp[:, None, :])
             #elif isinstance(p2, (Bias, Fixed)) and isinstance(p1, (RBF, RBFInv)):
             elif isinstance(p2, Bias) and isinstance(p1, (RBF, Linear)):
-                tmp = p1.psi1(Z[:,i1], variational_posterior[:, i_s])
+                # manual override for slicing:
+                p2._sliced_X = p1._sliced_X = True
+                tmp = p1.psi1(Z[:,i1], variational_posterior[:, i2])
                 psi2 += p2.variance * (tmp[:, :, None] + tmp[:, None, :])
             else:
                 raise NotImplementedError, "psi2 cannot be computed for this kernel"
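The "manual override for slicing" above is needed because the cross terms call psi1 with inputs that are already cut down to the other part's dimensions; without setting _sliced_X, the automatic slicer would index them a second time. A rough sketch of the pattern, continuing the hypothetical MyKern example from the top:

# Hypothetical illustration of the override used by the cross terms.
Z = np.random.randn(6, 4)
k = MyKern(active_dims=[1, 3])

Zs = Z[:, k.active_dims]   # the caller slices by hand...
k._sliced_X = True         # ...and tells the wrapper not to slice again
try:
    K_cross = k.K(Zs)      # the wrapper passes Zs through unchanged
finally:
    k._sliced_X = False    # restore, as the real wrapper does in its finally block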
@@ -83,11 +88,10 @@ class Add(CombinationKernel):
     def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
         from static import White, Bias
-        for p1, is1 in zip(self._parameters_, self.input_slices):
+        for p1 in self.parts:
             #compute the effective dL_dpsi1. Extra terms appear becaue of the cross terms in psi2!
             eff_dL_dpsi1 = dL_dpsi1.copy()
-            for p2, is2 in zip(self._parameters_, self.input_slices):
+            for p2 in self.parts:
                 if p2 is p1:
                     continue
                 if isinstance(p2, White):
@@ -95,42 +99,35 @@ class Add(CombinationKernel):
                 elif isinstance(p2, Bias):
                     eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.variance * 2.
                 else:
-                    eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.psi1(Z[:,is2], variational_posterior[:, is1]) * 2.
-            p1.update_gradients_expectations(dL_dpsi0, eff_dL_dpsi1, dL_dpsi2, Z[:,is1], variational_posterior[:, is1])
+                    eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.psi1(Z, variational_posterior) * 2.
+            p1.update_gradients_expectations(dL_dpsi0, eff_dL_dpsi1, dL_dpsi2, Z, variational_posterior)
 
     def gradients_Z_expectations(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
         from static import White, Bias
         target = np.zeros(Z.shape)
-        for p1, is1 in zip(self._parameters_, self.input_slices):
+        for p1 in self.parts:
             #compute the effective dL_dpsi1. extra terms appear becaue of the cross terms in psi2!
             eff_dL_dpsi1 = dL_dpsi1.copy()
-            for p2, is2 in zip(self._parameters_, self.input_slices):
+            for p2 in self.parts:
                 if p2 is p1:
                     continue
                 if isinstance(p2, White):
                     continue
                 elif isinstance(p2, Bias):
-                    eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.variance * 2.
+                    eff_dL_dpsi1 += 0#dL_dpsi2.sum(1) * p2.variance * 2.
                 else:
-                    eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.psi1(Z[:,is2], variational_posterior[:, is2]) * 2.
-            target += p1.gradients_Z_expectations(eff_dL_dpsi1, dL_dpsi2, Z[:,is1], variational_posterior[:, is1])
+                    eff_dL_dpsi1 += 0#dL_dpsi2.sum(1) * p2.psi1(Z, variational_posterior) * 2.
+            target[:, p1.active_dims] += p1.gradients_Z_expectations(eff_dL_dpsi1, dL_dpsi2, Z, variational_posterior)
         return target
 
     def gradients_qX_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
         from static import White, Bias
         target_mu = np.zeros(variational_posterior.shape)
         target_S = np.zeros(variational_posterior.shape)
-        for p1, is1 in zip(self._parameters_, self.input_slices):
+        for p1 in self._parameters_:
             #compute the effective dL_dpsi1. extra terms appear becaue of the cross terms in psi2!
             eff_dL_dpsi1 = dL_dpsi1.copy()
-            for p2, is2 in zip(self._parameters_, self.input_slices):
+            for p2 in self._parameters_:
                 if p2 is p1:
                     continue
                 if isinstance(p2, White):
@@ -138,35 +135,20 @@ class Add(CombinationKernel):
                 elif isinstance(p2, Bias):
                     eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.variance * 2.
                 else:
-                    eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.psi1(Z[:,is2], variational_posterior[:, is2]) * 2.
-            a, b = p1.gradients_qX_expectations(dL_dpsi0, eff_dL_dpsi1, dL_dpsi2, Z[:,is1], variational_posterior[:, is1])
-            target_mu += a
-            target_S += b
+                    eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.psi1(Z, variational_posterior) * 2.
+            a, b = p1.gradients_qX_expectations(dL_dpsi0, eff_dL_dpsi1, dL_dpsi2, Z, variational_posterior)
+            target_mu[:, p1.active_dims] += a
+            target_S[:, p1.active_dims] += b
         return target_mu, target_S
 
-    def input_sensitivity(self):
-        in_sen = np.zeros((self.num_params, self.input_dim))
-        for i, [p, i_s] in enumerate(zip(self._parameters_, self.input_slices)):
-            in_sen[i, i_s] = p.input_sensitivity()
-        return in_sen
-
     def _getstate(self):
         """
         Get the current state of the class,
         here just all the indices, rest can get recomputed
         """
-        return Parameterized._getstate(self) + [#self._parameters_,
-                self.input_dim,
-                self.input_slices,
-                self._param_slices_
-                ]
+        return super(Add, self)._getstate()
 
     def _setstate(self, state):
-        self._param_slices_ = state.pop()
-        self.input_slices = state.pop()
-        self.input_dim = state.pop()
-        Parameterized._setstate(self, state)
+        super(Add, self)._setstate(state)

GPy/kern/_src/kern.py

@@ -3,25 +3,18 @@
 import sys
 import numpy as np
-from ...core.parameterization.parameterized import ParametersChangedMeta, Parameterized
+from ...core.parameterization.parameterized import Parameterized
+from kernel_slice_operations import KernCallsViaSlicerMeta
 from ...util.caching import Cache_this
 
-class KernCallsViaSlicerMeta(ParametersChangedMeta):
-    def __call__(self, *args, **kw):
-        instance = super(KernCallsViaSlicerMeta, self).__call__(*args, **kw)
-        instance.K = instance._slice_wrapper(instance.K)
-        instance.Kdiag = instance._slice_wrapper(instance.Kdiag, True)
-        instance.update_gradients_full = instance._slice_wrapper(instance.update_gradients_full, False, True)
-        instance.update_gradients_diag = instance._slice_wrapper(instance.update_gradients_diag, True, True)
-        instance.gradients_X = instance._slice_wrapper(instance.gradients_X, False, True)
-        instance.gradients_X_diag = instance._slice_wrapper(instance.gradients_X_diag, True, True)
-        instance.psi0 = instance._slice_wrapper(instance.psi0, False, False)
-        instance.psi1 = instance._slice_wrapper(instance.psi1, False, False)
-        instance.psi2 = instance._slice_wrapper(instance.psi2, False, False)
-        return instance
-
 class Kern(Parameterized):
+    #===========================================================================
+    # This adds input slice support. The rather ugly code for slicing can be
+    # found in kernel_slice_operations
     __metaclass__ = KernCallsViaSlicerMeta
+    #===========================================================================
     def __init__(self, input_dim, name, *a, **kw):
         """
         The base class for a kernel: a positive definite function
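The KernCallsViaSlicerMeta metaclass (deleted above and re-imported from kernel_slice_operations) wraps bound methods at instantiation time, so every concrete kernel picks up slicing without each subclass decorating its own methods. A stripped-down sketch of that wrap-on-instantiation pattern, with hypothetical names and a trivial logging wrapper standing in for the slicer:

def logged(fn):
    # Stand-in for the slicing wrapper: announce the call, then delegate.
    def wrapper(*args, **kw):
        print("calling %s" % fn.__name__)
        return fn(*args, **kw)
    return wrapper

class WrapCallsMeta(type):
    def __call__(cls, *args, **kw):
        # type.__call__ builds the instance; we then rebind chosen methods.
        instance = super(WrapCallsMeta, cls).__call__(*args, **kw)
        instance.K = logged(instance.K)
        return instance

class MyKern2(object):
    __metaclass__ = WrapCallsMeta  # Python 2 syntax, as in this codebase
    def K(self, X, X2=None):
        return 0.0

MyKern2().K(None)  # prints "calling K", then returns 0.0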
@@ -40,76 +33,11 @@ class Kern(Parameterized):
         self.active_dims = input_dim
         self.input_dim = len(self.active_dims)
         self._sliced_X = False
-        self._sliced_X2 = False
 
     @Cache_this(limit=10)#, ignore_args = (0,))
     def _slice_X(self, X):
         return X[:, self.active_dims]
 
-    def _slice_wrapper(self, operation, diag=False, derivative=False):
-        """
-        This method wraps the functions in kernel to make sure all kernels allways see their respective input dimension.
-        The different switches are:
-        diag: if X2 exists
-        derivative: if firest arg is dL_dK
-        """
-        if derivative:
-            if diag:
-                def x_slice_wrapper(dL_dK, X, *args, **kw):
-                    X = self._slice_X(X) if not self._sliced_X else X
-                    self._sliced_X = True
-                    try:
-                        ret = operation(dL_dK, X, *args, **kw)
-                    except:
-                        raise
-                    finally:
-                        self._sliced_X = False
-                    return ret
-            else:
-                def x_slice_wrapper(dL_dK, X, X2=None, *args, **kw):
-                    X, X2 = self._slice_X(X) if not self._sliced_X else X, self._slice_X(X2) if X2 is not None and not self._sliced_X2 else X2
-                    self._sliced_X = True
-                    self._sliced_X2 = True
-                    try:
-                        ret = operation(dL_dK, X, X2, *args, **kw)
-                    except:
-                        raise
-                    finally:
-                        self._sliced_X = False
-                        self._sliced_X2 = False
-                    return ret
-        else:
-            if diag:
-                def x_slice_wrapper(X, *args, **kw):
-                    X = self._slice_X(X) if not self._sliced_X else X
-                    self._sliced_X = True
-                    try:
-                        ret = operation(X, *args, **kw)
-                    except:
-                        raise
-                    finally:
-                        self._sliced_X = False
-                    return ret
-            else:
-                def x_slice_wrapper(X, X2=None, *args, **kw):
-                    X, X2 = self._slice_X(X) if not self._sliced_X else X, self._slice_X(X2) if X2 is not None and not self._sliced_X2 else X2
-                    self._sliced_X = True
-                    self._sliced_X2 = True
-                    try:
-                        ret = operation(X, X2, *args, **kw)
-                    except: raise
-                    finally:
-                        self._sliced_X = False
-                        self._sliced_X2 = False
-                    return ret
-        x_slice_wrapper._operation = operation
-        x_slice_wrapper.__name__ = ("slicer("+operation.__name__
-                                    +(","+str(bool(diag)) if diag else'')
-                                    +(','+str(bool(derivative)) if derivative else '')
-                                    +')')
-        x_slice_wrapper.__doc__ = "**sliced**\n" + (operation.__doc__ or "")
-        return x_slice_wrapper
-
     def K(self, X, X2):
         """
         Compute the kernel function.
@@ -241,6 +169,21 @@ class Kern(Parameterized):
         else: kernels.append(other)
         return Prod(self, other, name)
 
+    def _getstate(self):
+        """
+        Get the current state of the class,
+        here just all the indices, rest can get recomputed
+        """
+        return super(Kern, self)._getstate() + [
+                self.active_dims,
+                self.input_dim,
+                self._sliced_X]
+
+    def _setstate(self, state):
+        self._sliced_X = state.pop()
+        self.input_dim = state.pop()
+        self.active_dims = state.pop()
+        super(Kern, self)._setstate(state)
+
 class CombinationKernel(Kern):
     def __init__(self, kernels, name):
@@ -258,3 +201,9 @@ class CombinationKernel(Kern):
     def update_gradients_diag(self, dL_dK, X):
         [p.update_gradients_diag(dL_dK, X) for p in self.parts]
 
+    def input_sensitivity(self):
+        in_sen = np.zeros((self.num_params, self.input_dim))
+        for i, p in enumerate(self.parts):
+            in_sen[i, p.active_dims] = p.input_sensitivity()
+        return in_sen
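input_sensitivity now scatters each part's per-dimension sensitivity into the columns named by that part's active_dims instead of relying on input_slices. A toy illustration of the scatter, with made-up parts and shapes:

import numpy as np

# Hypothetical setup: two parts over 4 input dimensions.
parts = [([0, 2], np.array([0.5, 1.5])),       # (active_dims, sensitivities)
         ([1, 2, 3], np.array([2.0, 0.1, 0.7]))]
in_sen = np.zeros((len(parts), 4))
for i, (active_dims, sens) in enumerate(parts):
    in_sen[i, active_dims] = sens              # scatter into the part's columns
print(in_sen)
# [[ 0.5  0.   1.5  0. ]
#  [ 0.   2.   0.1  0.7]]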

GPy/kern/_src/rbf.py

@@ -56,28 +56,28 @@ class RBF(Stationary):
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
             _, _dpsi1_dvariance, _, _, _, _, _dpsi1_dlengthscale = ssrbf_psi_comp._psi1computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
             _, _dpsi2_dvariance, _, _, _, _, _dpsi2_dlengthscale = ssrbf_psi_comp._psi2computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
 
             #contributions from psi0:
             self.variance.gradient = np.sum(dL_dpsi0)
 
             #from psi1
             self.variance.gradient += np.sum(dL_dpsi1 * _dpsi1_dvariance)
             if self.ARD:
                 self.lengthscale.gradient = (dL_dpsi1[:,:,None]*_dpsi1_dlengthscale).reshape(-1,self.input_dim).sum(axis=0)
             else:
                 self.lengthscale.gradient = (dL_dpsi1[:,:,None]*_dpsi1_dlengthscale).sum()
 
             #from psi2
             self.variance.gradient += (dL_dpsi2 * _dpsi2_dvariance).sum()
             if self.ARD:
                 self.lengthscale.gradient += (dL_dpsi2[:,:,:,None] * _dpsi2_dlengthscale).reshape(-1,self.input_dim).sum(axis=0)
             else:
                 self.lengthscale.gradient += (dL_dpsi2[:,:,:,None] * _dpsi2_dlengthscale).sum()
 
         elif isinstance(variational_posterior, variational.NormalPosterior):
-            l2 = self.lengthscale **2
+            l2 = self.lengthscale**2
+            if l2.size != self.input_dim:
+                l2 = l2*np.ones(self.input_dim)
 
             #contributions from psi0:
             self.variance.gradient = np.sum(dL_dpsi0)
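In the non-ARD case self.lengthscale holds a single shared value, so l2 has size one; the added check broadcasts it to one entry per input dimension, keeping the per-dimension psi2 computations below shape-correct. A quick standalone check of the idiom, with assumed values:

import numpy as np

input_dim = 3
lengthscale = np.array([2.0])   # non-ARD: one lengthscale shared by all dims
l2 = lengthscale**2
if l2.size != input_dim:
    l2 = l2*np.ones(input_dim)  # broadcast to one value per input dimension
print(l2)                       # [ 4.  4.  4.]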
@@ -92,11 +92,9 @@ class RBF(Stationary):
             else:
                 self.lengthscale.gradient += dpsi1_dlength.sum()
             self.variance.gradient += np.sum(dL_dpsi1 * psi1) / self.variance
 
             #from psi2
             S = variational_posterior.variance
             _, Zdist_sq, _, mudist_sq, psi2 = self._psi2computations(Z, variational_posterior)
             if not self.ARD:
                 self.lengthscale.gradient += self._weave_psi2_lengthscale_grads(dL_dpsi2, psi2, Zdist_sq, S, mudist_sq, l2).sum()
             else:
@@ -112,17 +110,16 @@ class RBF(Stationary):
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
             _, _, _, _, _, _dpsi1_dZ, _ = ssrbf_psi_comp._psi1computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
             _, _, _, _, _, _dpsi2_dZ, _ = ssrbf_psi_comp._psi2computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
 
             #psi1
             grad = (dL_dpsi1[:, :, None] * _dpsi1_dZ).sum(axis=0)
 
             #psi2
             grad += (dL_dpsi2[:, :, :, None] * _dpsi2_dZ).sum(axis=0).sum(axis=1)
 
             return grad
         elif isinstance(variational_posterior, variational.NormalPosterior):
             l2 = self.lengthscale **2
 
             #psi1
@@ -145,10 +142,10 @@ class RBF(Stationary):
         # Spike-and-Slab GPLVM
         if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
             ndata = variational_posterior.mean.shape[0]
             _, _, _dpsi1_dgamma, _dpsi1_dmu, _dpsi1_dS, _, _ = ssrbf_psi_comp._psi1computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
             _, _, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _, _ = ssrbf_psi_comp._psi2computations(self.variance, self.lengthscale, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)
 
             #psi1
             grad_mu = (dL_dpsi1[:, :, None] * _dpsi1_dmu).sum(axis=1)
             grad_S = (dL_dpsi1[:, :, None] * _dpsi1_dS).sum(axis=1)
@@ -157,11 +154,11 @@ class RBF(Stationary):
             grad_mu += (dL_dpsi2[:, :, :, None] * _dpsi2_dmu).reshape(ndata,-1,self.input_dim).sum(axis=1)
             grad_S += (dL_dpsi2[:, :, :, None] * _dpsi2_dS).reshape(ndata,-1,self.input_dim).sum(axis=1)
             grad_gamma += (dL_dpsi2[:,:,:, None] * _dpsi2_dgamma).reshape(ndata,-1,self.input_dim).sum(axis=1)
 
             return grad_mu, grad_S, grad_gamma
         elif isinstance(variational_posterior, variational.NormalPosterior):
             l2 = self.lengthscale **2
 
             #psi1
             denom, dist, dist_sq, psi1 = self._psi1computations(Z, variational_posterior)