mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
bug fix for tie framework. SSMRD ready!
This commit is contained in:
parent
ec64f653b5
commit
43f3bfc385
3 changed files with 14 additions and 7 deletions
|
|
@ -593,8 +593,7 @@ class Indexable(Nameable, Observable):
|
||||||
|
|
||||||
def tie_vector(self, *plist):
|
def tie_vector(self, *plist):
|
||||||
"""Tie a vector of parameters to other vectors of parameters"""
|
"""Tie a vector of parameters to other vectors of parameters"""
|
||||||
for p in plist:
|
self._highest_parent_.ties.tie_vector((self,)+plist)
|
||||||
self._highest_parent_.ties.tie_vector(self, p)
|
|
||||||
self._highest_parent_._trigger_params_changed()
|
self._highest_parent_._trigger_params_changed()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -218,10 +218,12 @@ class Tie(Parameterized):
|
||||||
|
|
||||||
def _get_labels_vector(self, p1,p2):
|
def _get_labels_vector(self, p1,p2):
|
||||||
label1 = []
|
label1 = []
|
||||||
self._traverse_param(lambda x: x.tie.flat, (p1,), label1)
|
for p in p1:
|
||||||
|
self._traverse_param(lambda x: x.tie.flat, (p,), label1)
|
||||||
label1 = np.hstack(label1)
|
label1 = np.hstack(label1)
|
||||||
label2 = []
|
label2 = []
|
||||||
self._traverse_param(lambda x: x.tie.flat, (p2,), label2)
|
for p in p2:
|
||||||
|
self._traverse_param(lambda x: x.tie.flat, (p,), label2)
|
||||||
label2 = np.hstack(label2)
|
label2 = np.hstack(label2)
|
||||||
expandlist = np.where(label1+label2==0)[0]
|
expandlist = np.where(label1+label2==0)[0]
|
||||||
labellist =label1.copy()
|
labellist =label1.copy()
|
||||||
|
|
@ -371,14 +373,20 @@ class Tie(Parameterized):
|
||||||
self.update_model(True)
|
self.update_model(True)
|
||||||
|
|
||||||
def tie_vector(self, plist):
|
def tie_vector(self, plist):
|
||||||
|
assert len(plist)>=2
|
||||||
p_splits = [self._keepParamList([p]) for p in plist]
|
p_splits = [self._keepParamList([p]) for p in plist]
|
||||||
|
for p_split2 in p_splits[1:]:
|
||||||
|
p_split1 = p_splits[0]
|
||||||
|
p1 = self._updateParamList(p_split1)
|
||||||
|
p2 = self._updateParamList(p_split2)
|
||||||
|
self._tie_vector(p1, p2)
|
||||||
|
|
||||||
def _tie_vector(self, p1, p2):
|
def _tie_vector(self, p1, p2):
|
||||||
"""tie a pair of vectors"""
|
"""tie a pair of vectors"""
|
||||||
self.update_model(False)
|
self.update_model(False)
|
||||||
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
||||||
p_split1 = self._keepParamList([p1])
|
p_split1 = self._keepParamList(p1)
|
||||||
p_split2 = self._keepParamList([p2])
|
p_split2 = self._keepParamList(p2)
|
||||||
if len(expandlist)>0:
|
if len(expandlist)>0:
|
||||||
tie_labels,idxlist = self._expand_tie_param(len(expandlist))
|
tie_labels,idxlist = self._expand_tie_param(len(expandlist))
|
||||||
labellist[expandlist] = tie_labels
|
labellist[expandlist] = tie_labels
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue