Added gradients_X_diag to symmetric.py

This commit is contained in:
Viktor Mirjanic 2023-01-16 23:48:48 +00:00
parent f63ed48b0d
commit db439d7bad

View file

@ -168,3 +168,8 @@ class Symmetric(Kern):
+ self.base_kernel.gradients_X(dL_dK, X_sym, X2_sym).dot(self.transform.T) + self.base_kernel.gradients_X(dL_dK, X_sym, X2_sym).dot(self.transform.T)
+ self.symmetry_sign * self.base_kernel.gradients_X(dL_dK, X, X2_sym) + self.symmetry_sign * self.base_kernel.gradients_X(dL_dK, X, X2_sym)
+ self.symmetry_sign * self.base_kernel.gradients_X(dL_dK, X_sym, X2).dot(self.transform.T)) + self.symmetry_sign * self.base_kernel.gradients_X(dL_dK, X_sym, X2).dot(self.transform.T))
def gradients_X_diag(self, dL_dKdiag, X):
X_sym = X.dot(self.transform)
return ((1 + self.symmetry_sign) * (self.base_kernel.gradients_X_diag(dL_dKdiag, X)
+ self.base_kernel.gradients_X_diag(dL_dKdiag, X_sym).dot(self.transform.T)))