mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
[splitkern] bug fix
This commit is contained in:
parent
f9291fe7da
commit
c2568d2e9f
1 changed files with 15 additions and 13 deletions
|
|
@ -6,6 +6,8 @@ import numpy as np
|
||||||
from kern import Kern,CombinationKernel
|
from kern import Kern,CombinationKernel
|
||||||
from .independent_outputs import index_to_slices
|
from .independent_outputs import index_to_slices
|
||||||
import itertools
|
import itertools
|
||||||
|
from GPy.kern import Linear,RBF
|
||||||
|
|
||||||
|
|
||||||
class SplitKern(CombinationKernel):
|
class SplitKern(CombinationKernel):
|
||||||
"""
|
"""
|
||||||
|
|
@ -45,9 +47,9 @@ class SplitKern(CombinationKernel):
|
||||||
# diagonal blocks
|
# diagonal blocks
|
||||||
[[target.__setitem__((s,s2), self.kern.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[i], slices2[i])] for i in xrange(min(len(slices),len(slices)))]
|
[[target.__setitem__((s,s2), self.kern.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[i], slices2[i])] for i in xrange(min(len(slices),len(slices)))]
|
||||||
if len(slices)>1:
|
if len(slices)>1:
|
||||||
[target.__setitem__((s,s2), self.kern.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[1], slices2[0])]
|
[target.__setitem__((s,s2), self.kern_cross.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[1], slices2[0])]
|
||||||
if len(slices2)>1:
|
if len(slices2)>1:
|
||||||
[target.__setitem__((s,s2), self.kern.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[0], slices2[1])]
|
[target.__setitem__((s,s2), self.kern_cross.K(X[s,:],X2[s2,:])) for s,s2 in itertools.product(slices[0], slices2[1])]
|
||||||
return target
|
return target
|
||||||
|
|
||||||
def Kdiag(self,X):
|
def Kdiag(self,X):
|
||||||
|
|
@ -60,7 +62,7 @@ class SplitKern(CombinationKernel):
|
||||||
def collate_grads(dL, X, X2, cross=False):
|
def collate_grads(dL, X, X2, cross=False):
|
||||||
if cross:
|
if cross:
|
||||||
self.kern_cross.update_gradients_full(dL,X,X2)
|
self.kern_cross.update_gradients_full(dL,X,X2)
|
||||||
target[:] += self.kern_cross.gradient
|
target[:] += self.kern_cross.kern.gradient
|
||||||
else:
|
else:
|
||||||
self.kern.update_gradients_full(dL,X,X2)
|
self.kern.update_gradients_full(dL,X,X2)
|
||||||
target[:] += self.kern.gradient
|
target[:] += self.kern.gradient
|
||||||
|
|
@ -102,22 +104,22 @@ class SplitKern_cross(Kern):
|
||||||
def update_gradients_full(self, dL_dK, X, X2=None):
|
def update_gradients_full(self, dL_dK, X, X2=None):
|
||||||
if X2 is None:
|
if X2 is None:
|
||||||
X2 = X
|
X2 = X
|
||||||
|
|
||||||
k1 = self.kern.K(X,self.Xp)
|
k1 = self.kern.K(X,self.Xp)
|
||||||
k2 = self.kern.K(self.Xp,X2)
|
k2 = self.kern.K(self.Xp,X2)
|
||||||
k3 = self.kern.K(self.Xp,self.Xp)
|
k3 = self.kern.K(self.Xp,self.Xp)
|
||||||
dL_dk1 = np.einsum('ij,j->i',dL_dK,k2.flat)/k3.flat
|
dL_dk1 = np.einsum('ij,j->i',dL_dK,k2[0])/k3[0,0]
|
||||||
dL_dk2 = np.einsum('ij,i->j',dL_dK,k1.flat)/k3.flat
|
dL_dk2 = np.einsum('ij,i->j',dL_dK,k1[:,0])/k3[0,0]
|
||||||
dL_dk3 = np.einsum('ij,ij->',dL_dK,-np.dot(k1,k2)/(k3.flat*k3.flat))
|
dL_dk3 = np.einsum('ij,ij->',dL_dK,-np.dot(k1,k2)/(k3[0,0]*k3[0,0]))
|
||||||
|
|
||||||
self.kern.update_gradients_full(dL_dk1[:,None],X,self.Xp)
|
self.kern.update_gradients_full(dL_dk1[:,None],X,self.Xp)
|
||||||
grad1 = self.kern.gradient.copy()
|
grad = self.kern.gradient.copy()
|
||||||
self.kern.update_gradients_full(dL_dk2[None,:],self.Xp,X)
|
self.kern.update_gradients_full(dL_dk2[None,:],self.Xp,X2)
|
||||||
grad2 = self.kern.gradient.copy()
|
grad += self.kern.gradient.copy()
|
||||||
self.kern.update_gradients_full(np.array([[dL_dk3]]),self.Xp,self.Xp)
|
self.kern.update_gradients_full(np.array([[dL_dk3]]),self.Xp,self.Xp)
|
||||||
grad3 = self.kern.gradient.copy()
|
grad += self.kern.gradient.copy()
|
||||||
|
|
||||||
self.kern.gradient = grad1+grad2+grad3
|
self.kern.gradient = grad
|
||||||
|
|
||||||
def update_gradients_diag(self, dL_dKdiag, X):
|
def update_gradients_diag(self, dL_dKdiag, X):
|
||||||
k1 = self.kern.K(X,self.Xp)
|
k1 = self.kern.K(X,self.Xp)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue