mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
changes to the hierarchical kernpart.
Looks to work now.
This commit is contained in:
parent
f19a26a006
commit
3ad48534c8
2 changed files with 10 additions and 9 deletions
|
|
@ -306,4 +306,4 @@ def hierarchical(k):
|
||||||
# for sl in k.input_slices:
|
# for sl in k.input_slices:
|
||||||
# assert (sl.start is None) and (sl.stop is None), "cannot adjust input slices! (TODO)"
|
# assert (sl.start is None) and (sl.stop is None), "cannot adjust input slices! (TODO)"
|
||||||
_parts = [parts.hierarchical.Hierarchical(k.parts)]
|
_parts = [parts.hierarchical.Hierarchical(k.parts)]
|
||||||
return kern(k.input_dim+1,_parts)
|
return kern(k.input_dim+len(k.parts),_parts)
|
||||||
|
|
|
||||||
|
|
@ -24,26 +24,26 @@ class Hierarchical(Kernpart):
|
||||||
return np.hstack([k._get_params() for k in self.parts])
|
return np.hstack([k._get_params() for k in self.parts])
|
||||||
|
|
||||||
def _set_params(self,x):
|
def _set_params(self,x):
|
||||||
[k._set_params(x[start:stop]) for start, stop in zip(self.param_starts, self.param_stops)]
|
[k._set_params(x[start:stop]) for k, start, stop in zip(self.parts, self.param_starts, self.param_stops)]
|
||||||
|
|
||||||
def _get_param_names(self):
|
def _get_param_names(self):
|
||||||
return self.k._get_param_names()
|
return sum([[str(i)+'_'+k.name+'_'+n for n in k._get_param_names()] for i,k in enumerate(self.parts)],[])
|
||||||
|
|
||||||
def _sort_slices(self,X,X2):
|
def _sort_slices(self,X,X2):
|
||||||
slices = [index_to_slices(x) for x in X[-self.levels:].T]
|
slices = [index_to_slices(x) for x in X[:,-self.levels:].T]
|
||||||
X = X[:-self.levels]
|
X = X[:,:-self.levels]
|
||||||
if X2 is None:
|
if X2 is None:
|
||||||
slices2 = slices
|
slices2 = slices
|
||||||
X2 = X
|
X2 = X
|
||||||
else:
|
else:
|
||||||
slices2 = [index_to_slices(x) for x in X2[-self.levels:].T]
|
slices2 = [index_to_slices(x) for x in X2[:,-self.levels:].T]
|
||||||
X2 = X2[:-self.levels]
|
X2 = X2[:,:-self.levels]
|
||||||
return X, X2, slices, slices2
|
return X, X2, slices, slices2
|
||||||
|
|
||||||
def K(self,X,X2,target):
|
def K(self,X,X2,target):
|
||||||
X, X2, slices, slices2 = self._sort_slices(X,X2)
|
X, X2, slices, slices2 = self._sort_slices(X,X2)
|
||||||
|
|
||||||
[[[k.K(X[s],X2[s2],target[s,s2]) for s in slices_i] for s2 in slices_j] for k,slices_i,slices_j in zip(self.parts,slices,slices2)]
|
[[[[k.K(X[s],X2[s2],target[s,s2]) for s in slices_i] for s2 in slices_j] for slices_i,slices_j in zip(slices_,slices2_)] for k, slices_, slices2_ in zip(self.parts,slices,slices2)]
|
||||||
|
|
||||||
def Kdiag(self,X,target):
|
def Kdiag(self,X,target):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
@ -51,7 +51,8 @@ class Hierarchical(Kernpart):
|
||||||
#[[self.k.Kdiag(X[s],target[s]) for s in slices_i] for slices_i in slices]
|
#[[self.k.Kdiag(X[s],target[s]) for s in slices_i] for slices_i in slices]
|
||||||
|
|
||||||
def dK_dtheta(self,dL_dK,X,X2,target):
|
def dK_dtheta(self,dL_dK,X,X2,target):
|
||||||
[[[k.dK_dtheta(dL_dK[s,s2],X[s],X2[s2],target[p_start:p_stop]) for s in slices_i] for s2 in slices_j] for k,slices_i,slices_j, p_start, p_stop in zip(self.parts, slices, slices2, self.param_starts, self.param_stops)]
|
X, X2, slices, slices2 = self._sort_slices(X,X2)
|
||||||
|
[[[[k.dK_dtheta(dL_dK[s,s2],X[s],X2[s2],target[p_start:p_stop]) for s in slices_i] for s2 in slices_j] for slices_i,slices_j in zip(slices_, slices2_)] for k, p_start, p_stop, slices_, slices2_ in zip(self.parts, self.param_starts, self.param_stops, slices, slices2)]
|
||||||
|
|
||||||
|
|
||||||
def dK_dX(self,dL_dK,X,X2,target):
|
def dK_dX(self,dL_dK,X,X2,target):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue