changed to 'update_gradients_q_variational'

This commit is contained in:
Zhenwen Dai 2014-02-24 11:33:58 +00:00
parent c2750ae719
commit f311bfdf17
3 changed files with 7 additions and 8 deletions

View file

@ -182,7 +182,7 @@ class RBF(Kern):
return grad
def gradients_q_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, posterior_variational):
def update_gradients_q_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, posterior_variational):
mu = posterior_variational.mean
S = posterior_variational.variance
self._psi_computations(Z, mu, S)
@ -194,8 +194,9 @@ class RBF(Kern):
tmp = self._psi2[:, :, :, None] / self.lengthscale2 / self._psi2_denom
grad_mu += -2.*(dL_dpsi2[:, :, :, None] * tmp * self._psi2_mudist).sum(1).sum(1)
grad_S += (dL_dpsi2[:, :, :, None] * tmp * (2.*self._psi2_mudist_sq - 1)).sum(1).sum(1)
return grad_mu, grad_S
posterior_variational.mean.gradient = grad_mu
posterior_variational.variance.gradient = grad_S
def gradients_X(self, dL_dK, X, X2=None):
#if self._X is None or X.base is not self._X.base or X2 is not None: