diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index 0ededab3..271609b9 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -551,6 +551,7 @@ class Indexable(Nameable, Observable): return np.ones((self._highest_parent_.param_array.size,),dtype=np.bool) def tie_together(self, *plist): + """Tie a list of parameters together""" plist = list(plist) plist.append(self) self._highest_parent_.ties.tie_together(plist) @@ -560,6 +561,11 @@ class Indexable(Nameable, Observable): plist.append(self) self._highest_parent_.ties.untie(plist) + def tie_vector(self, *plist): + """Tie a vector of parameters to other vectors of parameters""" + for p in plist: + self._highest_parent_.ties.tie_vector(self,p) + #=========================================================================== # Constrain operations -> done #=========================================================================== diff --git a/GPy/core/parameterization/ties_and_remappings.py b/GPy/core/parameterization/ties_and_remappings.py index 52a1ca45..01cc7561 100644 --- a/GPy/core/parameterization/ties_and_remappings.py +++ b/GPy/core/parameterization/ties_and_remappings.py @@ -108,7 +108,7 @@ class Tie(Parameterized): def _set_val(p): p[:] = val for p in plist: - self._traverse_param(_set_val, p, []) + self._traverse_param(_set_val, (p,), []) return val def _sync_constraint_group(self, plist, hastie=False, tie_con=None, warning=True): @@ -140,18 +140,33 @@ class Tie(Parameterized): Apply *func* to every leaves (param objects), and collect return values into *res* """ - if isinstance(p, Param): - res.append(func(p)) + if isinstance(p[0], Param): + res.append(func(*p)) else: - for pc in p.parameters: - self._traverse_param(func,pc,res) + for pc in p[0].parameters: + self._traverse_param(func, (pc,)+p[1:] ,res) def _get_labels(self, plist): labels = [] for p in plist: - self._traverse_param(lambda x: x.tie.flat, p, labels) + self._traverse_param(lambda x: x.tie.flat, (p,), labels) return np.unique(np.hstack(labels)) + def _get_labels_vector(self, p1,p2): + label1 = [] + self._traverse_param(lambda x: x.tie.flat, (p1,), label1) + label1 = np.hstack(label1) + label2 = [] + self._traverse_param(lambda x: x.tie.flat, (p2,), label2) + label2 = np.hstack(label2) + expandlist = np.where(label1*label2==0)[0] + labellist =label1.copy() + idx = np.logical_and(label1==0,label2>0) + labellist[idx] = label2[idx] + idx = np.logical_and(label1*label2>0,label1!=label2) + removelist = (label1[idx],label2[idx]) + return expandlist,removelist,labellist + def _set_labels(self, plist, labels): """ If there is only one label, set all the param objects to that label, @@ -159,15 +174,21 @@ class Tie(Parameterized): """ def _set_l1(p): p.tie[:] = labels[0] + def _set_list(p, offset): + p.tie.flat[:] = labels[offset[0]:offset[0]+p.size] + offset[0] = offset[0]+ p.size if len(labels)==1: for p in plist: - self._traverse_param(_set_l1, p, []) + self._traverse_param(_set_l1, (p,), []) + else: + for p in plist: + self._traverse_param(_set_list, (p,[0]), []) def _replace_labels(self, p, label_pairs): def _replace_l(p): for l1,l2 in label_pairs: p.tie[p.tie==l1] = l2 - self._traverse_param(_replace_l, p, []) + self._traverse_param(_replace_l, (p,), []) def _expand_tie_param(self, num): """Expand the tie param with the number of *num* parameters""" @@ -210,6 +231,11 @@ class Tie(Parameterized): return self._remove_tie_param(labels[1:]) self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]]) + + def _merge_tie_labelpair(self, labelpair): + """Merge the second list in labelpair to the first list""" + self._remove_tie_param(labelpair[1]) + self._replace_labels(self._highest_parent_, zip(labelpair[1],labelpair[0])) def _remove_unnecessary_ties(self): """Remove the unnecessary ties""" @@ -226,7 +252,13 @@ class Tie(Parameterized): self._untie_ = None else: self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32) - self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie), self._highest_parent_, []) + def up(x): + print self._highest_parent_._raveled_index_for(x) + np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat) + self._traverse_param(up, (self._highest_parent_,), []) +# self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), []) + print self.tied_param.tie + print self.label_buf self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param) self._untie_ = self.label_buf==0 self._untie_[self.buf_idx] = True @@ -234,7 +266,7 @@ class Tie(Parameterized): def tie_together(self,plist): """tie a list of parameters""" - #self.updates = False + self.updates = False labels = self._get_labels(plist) val = self._sync_val_group(plist) if labels[0]==0 and labels.size==1: @@ -256,7 +288,21 @@ class Tie(Parameterized): self._sync_constraint_group(plist, True, tie_con) self._update_label_buf() self.tied_param[self.tied_param.tie==tie_labels[0]] = val - #self.updates = True + self.updates = True + + def tie_vector(self, p1, p2): + """tie a pair of vectors""" + self.updates = False + expandlist,removelist,labellist = self._get_labels_vector(p1, p2) + if len(expandlist)>0: + tie_labels = self._expand_tie_param(len(expandlist)) + labellist[expandlist] = tie_labels + if len(removelist[0])>0: + self._merge_tie_labelpair(removelist) + print labellist + self._set_labels([p1,p2], labellist) + self._update_label_buf() + self.updates = True def untie(self,plist): """Untie a list of parameters"""