plotting conflict fixed

This commit is contained in:
James Hensman 2014-02-24 13:55:11 +00:00
commit d3eaef5c99
13 changed files with 196 additions and 158 deletions

View file

@ -182,7 +182,7 @@ class RBF(Kern):
return grad
def gradients_q_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, posterior_variational):
def update_gradients_q_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, posterior_variational):
mu = posterior_variational.mean
S = posterior_variational.variance
self._psi_computations(Z, mu, S)
@ -194,8 +194,9 @@ class RBF(Kern):
tmp = self._psi2[:, :, :, None] / self.lengthscale2 / self._psi2_denom
grad_mu += -2.*(dL_dpsi2[:, :, :, None] * tmp * self._psi2_mudist).sum(1).sum(1)
grad_S += (dL_dpsi2[:, :, :, None] * tmp * (2.*self._psi2_mudist_sq - 1)).sum(1).sum(1)
return grad_mu, grad_S
posterior_variational.mean.gradient = grad_mu
posterior_variational.variance.gradient = grad_S
def gradients_X(self, dL_dK, X, X2=None):
#if self._X is None or X.base is not self._X.base or X2 is not None: