fixed target slicing bug in prod kernel

This commit is contained in:
James Hensman 2013-08-21 09:53:49 +01:00
parent 9e0795afc4
commit 65c0c72361
5 changed files with 35 additions and 30 deletions

View file

@ -80,8 +80,8 @@ class Prod(Kernpart):
def dK_dX(self,dL_dK,X,X2,target):
"""derivative of the covariance matrix with respect to X."""
self._K_computations(X,X2)
self.k1.dK_dX(dL_dK*self._K2, X[:,self.slice1], X2[:,self.slice1], target)
self.k2.dK_dX(dL_dK*self._K1, X[:,self.slice2], X2[:,self.slice2], target)
self.k1.dK_dX(dL_dK*self._K2, X[:,self.slice1], X2[:,self.slice1], target[:,self.slice1])
self.k2.dK_dX(dL_dK*self._K1, X[:,self.slice2], X2[:,self.slice2], target[:,self.slice2])
def dKdiag_dX(self, dL_dKdiag, X, target):
K1 = np.zeros(X.shape[0])
@ -89,8 +89,8 @@ class Prod(Kernpart):
self.k1.Kdiag(X[:,self.slice1],K1)
self.k2.Kdiag(X[:,self.slice2],K2)
self.k1.dK_dX(dL_dKdiag*K2, X[:,self.slice1], target)
self.k2.dK_dX(dL_dKdiag*K1, X[:,self.slice2], target)
self.k1.dK_dX(dL_dKdiag*K2, X[:,self.slice1], target[:,self.slice1])
self.k2.dK_dX(dL_dKdiag*K1, X[:,self.slice2], target[:,self.slice2])
def _K_computations(self,X,X2):
if not (np.array_equal(X,self._X) and np.array_equal(X2,self._X2) and np.array_equal(self._params , self._get_params())):