mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
plotting conflict fixed
This commit is contained in:
commit
d3eaef5c99
13 changed files with 196 additions and 158 deletions
|
|
@ -182,7 +182,7 @@ class RBF(Kern):
|
|||
|
||||
return grad
|
||||
|
||||
def gradients_q_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, posterior_variational):
|
||||
def update_gradients_q_variational(self, dL_dKmm, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, posterior_variational):
|
||||
mu = posterior_variational.mean
|
||||
S = posterior_variational.variance
|
||||
self._psi_computations(Z, mu, S)
|
||||
|
|
@ -194,8 +194,9 @@ class RBF(Kern):
|
|||
tmp = self._psi2[:, :, :, None] / self.lengthscale2 / self._psi2_denom
|
||||
grad_mu += -2.*(dL_dpsi2[:, :, :, None] * tmp * self._psi2_mudist).sum(1).sum(1)
|
||||
grad_S += (dL_dpsi2[:, :, :, None] * tmp * (2.*self._psi2_mudist_sq - 1)).sum(1).sum(1)
|
||||
|
||||
return grad_mu, grad_S
|
||||
|
||||
posterior_variational.mean.gradient = grad_mu
|
||||
posterior_variational.variance.gradient = grad_S
|
||||
|
||||
def gradients_X(self, dL_dK, X, X2=None):
|
||||
#if self._X is None or X.base is not self._X.base or X2 is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue