mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 11:32:39 +02:00
slicing now returns the right shape, when computing derivative wrt X or Z
This commit is contained in:
parent
f88b4c92e8
commit
62d594d977
3 changed files with 50 additions and 18 deletions
|
|
@ -58,7 +58,13 @@ class Add(CombinationKernel):
|
|||
:type X2: np.ndarray (num_inducing x input_dim)"""
|
||||
|
||||
target = np.zeros(X.shape)
|
||||
[target.__setitem__([Ellipsis, p.active_dims], target[:, p.active_dims]+p.gradients_X(dL_dK, X, X2)) for p in self.parts]
|
||||
[target.__iadd__(p.gradients_X(dL_dK, X, X2)) for p in self.parts]
|
||||
return target
|
||||
|
||||
def gradients_X_diag(self, dL_dKdiag, X):
|
||||
target = np.zeros(X.shape)
|
||||
[target.__iadd__(p.gradients_X_diag(dL_dKdiag, X)) for p in self.parts]
|
||||
#[target.__setitem__([Ellipsis, p.active_dims], target[:, p.active_dims]+p.gradients_X(dL_dK, X, X2)) for p in self.parts]
|
||||
return target
|
||||
|
||||
def psi0(self, Z, variational_posterior):
|
||||
|
|
@ -131,7 +137,7 @@ class Add(CombinationKernel):
|
|||
eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.variance * 2.
|
||||
else:
|
||||
eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.psi1(Z, variational_posterior) * 2.
|
||||
target[:, p1.active_dims] += p1.gradients_Z_expectations(eff_dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
target += p1.gradients_Z_expectations(eff_dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
return target
|
||||
|
||||
def gradients_qX_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
|
|
@ -151,8 +157,8 @@ class Add(CombinationKernel):
|
|||
else:
|
||||
eff_dL_dpsi1 += dL_dpsi2.sum(1) * p2.psi1(Z, variational_posterior) * 2.
|
||||
a, b = p1.gradients_qX_expectations(dL_dpsi0, eff_dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
target_mu[:, p1.active_dims] += a
|
||||
target_S[:, p1.active_dims] += b
|
||||
target_mu += a
|
||||
target_S += b
|
||||
return target_mu, target_S
|
||||
|
||||
def _getstate(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue