some missing .Ks

This commit is contained in:
Ricardo 2014-03-05 11:32:54 +00:00
parent cde8722e1b
commit 8e22373a00

View file

@ -39,17 +39,17 @@ class Prod(Kern):
return self.k1.Kdiag(X[:,self.slice1]) * self.k2.Kdiag(X[:,self.slice2])
def update_gradients_full(self, dL_dK, X):
self.k1.update_gradients_full(dL_dK*self.k2(X[:,self.slice2]), X[:,self.slice1])
self.k2.update_gradients_full(dL_dK*self.k1(X[:,self.slice1]), X[:,self.slice2])
self.k1.update_gradients_full(dL_dK*self.k2.K(X[:,self.slice2]), X[:,self.slice1])
self.k2.update_gradients_full(dL_dK*self.k1.K(X[:,self.slice1]), X[:,self.slice2])
def gradients_X(self, dL_dK, X, X2=None):
target = np.zeros(X.shape)
if X2 is None:
target[:,self.slice1] += self.k1.gradients_X(dL_dK*self.k2(X[:,self.slice2]), X[:,self.slice1], None)
target[:,self.slice2] += self.k2.gradients_X(dL_dK*self.k1(X[:,self.slice1]), X[:,self.slice2], None)
target[:,self.slice1] += self.k1.gradients_X(dL_dK*self.k2.K(X[:,self.slice2]), X[:,self.slice1], None)
target[:,self.slice2] += self.k2.gradients_X(dL_dK*self.k1.K(X[:,self.slice1]), X[:,self.slice2], None)
else:
target[:,self.slice1] += self.k1.gradients_X(dL_dK*self.k2(X[:,self.slice2], X2[:,self.slice2]), X[:,self.slice1], X2[:,self.slice1])
target[:,self.slice2] += self.k2.gradients_X(dL_dK*self.k1(X[:,self.slice1], X2[:,self.slice1]), X[:,self.slice2], X2[:,self.slice2])
target[:,self.slice1] += self.k1.gradients_X(dL_dK*self.k2.K(X[:,self.slice2], X2[:,self.slice2]), X[:,self.slice1], X2[:,self.slice1])
target[:,self.slice2] += self.k2.gradients_X(dL_dK*self.k1.K(X[:,self.slice1], X2[:,self.slice1]), X[:,self.slice2], X2[:,self.slice2])
return target
def gradients_X_diag(self, dL_dKdiag, X):