diff --git a/GPy/kern/src/symmetric.py b/GPy/kern/src/symmetric.py index c7207023..8797a5a3 100644 --- a/GPy/kern/src/symmetric.py +++ b/GPy/kern/src/symmetric.py @@ -168,3 +168,8 @@ class Symmetric(Kern): + self.base_kernel.gradients_X(dL_dK, X_sym, X2_sym).dot(self.transform.T) + self.symmetry_sign * self.base_kernel.gradients_X(dL_dK, X, X2_sym) + self.symmetry_sign * self.base_kernel.gradients_X(dL_dK, X_sym, X2).dot(self.transform.T)) + +def gradients_X_diag(self, dL_dKdiag, X): + X_sym = X.dot(self.transform) + return ((1 + self.symmetry_sign) * (self.base_kernel.gradients_X_diag(dL_dKdiag, X) + + self.base_kernel.gradients_X_diag(dL_dKdiag, X_sym).dot(self.transform.T)))