[splitkern] bug fix

This commit is contained in:
Zhenwen Dai 2014-05-29 13:21:14 +01:00
parent f9291fe7da
commit c2568d2e9f

View file

@ -6,6 +6,8 @@ import numpy as np
from kern import Kern,CombinationKernel
from .independent_outputs import index_to_slices
import itertools
from GPy.kern import Linear,RBF
class SplitKern(CombinationKernel):
"""
@ -45,9 +47,9 @@ class SplitKern(CombinationKernel):
# diagonal blocks
[[target.__setitem__((s,s2), self.kern.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[i], slices2[i])] for i in xrange(min(len(slices),len(slices)))]
if len(slices)>1:
[target.__setitem__((s,s2), self.kern.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[1], slices2[0])]
[target.__setitem__((s,s2), self.kern_cross.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[1], slices2[0])]
if len(slices2)>1:
[target.__setitem__((s,s2), self.kern.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[0], slices2[1])]
[target.__setitem__((s,s2), self.kern_cross.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[0], slices2[1])]
return target
def Kdiag(self,X):
@ -60,7 +62,7 @@ class SplitKern(CombinationKernel):
def collate_grads(dL, X, X2, cross=False):
if cross:
self.kern_cross.update_gradients_full(dL,X,X2)
target[:] += self.kern_cross.gradient
target[:] += self.kern_cross.kern.gradient
else:
self.kern.update_gradients_full(dL,X,X2)
target[:] += self.kern.gradient
@ -102,22 +104,22 @@ class SplitKern_cross(Kern):
def update_gradients_full(self, dL_dK, X, X2=None):
    """Accumulate the wrapped kernel's parameter gradient and store it.

    The three partial derivatives computed below are the chain-rule
    factors of an output of the form ``k1 * k2 / k3`` with

    * ``k1 = self.kern.K(X, self.Xp)``
    * ``k2 = self.kern.K(self.Xp, X2)``
    * ``k3 = self.kern.K(self.Xp, self.Xp)``

    Each partial is pushed through ``self.kern.update_gradients_full`` and
    the three resulting ``self.kern.gradient`` values are summed; the sum
    is finally written back to ``self.kern.gradient``.

    :param dL_dK: gradient of the objective w.r.t. this kernel's output,
        indexed as (rows of X, rows of X2) by the einsum subscripts below.
    :param X: first set of inputs.
    :param X2: second set of inputs; ``X`` is used when this is ``None``.
    """
    if X2 is None:
        X2 = X

    k1 = self.kern.K(X, self.Xp)
    k2 = self.kern.K(self.Xp, X2)
    k3 = self.kern.K(self.Xp, self.Xp)

    # Only element [0, 0] of k3 (and row 0 of k2 / column 0 of k1) is read.
    # NOTE(review): this presumes self.Xp holds a single point -- confirm.
    k3_val = k3[0, 0]

    dL_dk1 = np.einsum('ij,j->i', dL_dK, k2[0]) / k3_val
    dL_dk2 = np.einsum('ij,i->j', dL_dK, k1[:, 0]) / k3_val
    dL_dk3 = np.einsum('ij,ij->', dL_dK, -np.dot(k1, k2) / (k3_val * k3_val))

    # Every call below overwrites self.kern.gradient, so the first result is
    # copied before the next call and the later ones are added onto it.
    self.kern.update_gradients_full(dL_dk1[:, None], X, self.Xp)
    grad = self.kern.gradient.copy()

    # The k2 term is evaluated against X2 (not X), matching how k2 was built.
    self.kern.update_gradients_full(dL_dk2[None, :], self.Xp, X2)
    grad += self.kern.gradient

    self.kern.update_gradients_full(np.array([[dL_dk3]]), self.Xp, self.Xp)
    grad += self.kern.gradient

    self.kern.gradient = grad
def update_gradients_diag(self, dL_dKdiag, X):
k1 = self.kern.K(X,self.Xp)