[splitkern] support idx_p==0

This commit is contained in:
Zhenwen Dai 2014-05-30 18:56:23 +01:00
parent 47ba2542c2
commit 6e30f3c6c3
2 changed files with 12 additions and 0 deletions

View file

@ -20,6 +20,8 @@ def index_to_slices(index):
returns returns
>>> [[slice(0,2,None),slice(4,5,None)],[slice(2,4,None),slice(8,10,None)],[slice(5,8,None)]] >>> [[slice(0,2,None),slice(4,5,None)],[slice(2,4,None),slice(8,10,None)],[slice(5,8,None)]]
""" """
if len(index)==0:
return[]
#contruct the return structure #contruct the return structure
ind = np.asarray(index,dtype=np.int) ind = np.asarray(index,dtype=np.int)

View file

@ -20,6 +20,9 @@ class DiffGenomeKern(Kern):
assert X2==None assert X2==None
K = self.kern.K(X,X2) K = self.kern.K(X,X2)
if self.idx_p<=0 or self.idx_p>X.shape[0]/2:
return K
slices = index_to_slices(X[:,self.index_dim]) slices = index_to_slices(X[:,self.index_dim])
idx_start = slices[1][0].start idx_start = slices[1][0].start
idx_end = idx_start+self.idx_p idx_end = idx_start+self.idx_p
@ -33,6 +36,9 @@ class DiffGenomeKern(Kern):
def Kdiag(self,X): def Kdiag(self,X):
Kdiag = self.kern.Kdiag(X) Kdiag = self.kern.Kdiag(X)
if self.idx_p<=0 or self.idx_p>X.shape[0]/2:
return Kdiag
slices = index_to_slices(X[:,self.index_dim]) slices = index_to_slices(X[:,self.index_dim])
idx_start = slices[1][0].start idx_start = slices[1][0].start
idx_end = idx_start+self.idx_p idx_end = idx_start+self.idx_p
@ -42,6 +48,10 @@ class DiffGenomeKern(Kern):
def update_gradients_full(self,dL_dK,X,X2=None): def update_gradients_full(self,dL_dK,X,X2=None):
assert X2==None assert X2==None
if self.idx_p<=0 or self.idx_p>X.shape[0]/2:
self.kern.update_gradients_full(dL_dK, X)
return
slices = index_to_slices(X[:,self.index_dim]) slices = index_to_slices(X[:,self.index_dim])
idx_start = slices[1][0].start idx_start = slices[1][0].start
idx_end = idx_start+self.idx_p idx_end = idx_start+self.idx_p