diff --git a/GPy/core/parameterization/parameter_core.py b/GPy/core/parameterization/parameter_core.py index 30d52f05..4fb4047f 100644 --- a/GPy/core/parameterization/parameter_core.py +++ b/GPy/core/parameterization/parameter_core.py @@ -593,6 +593,7 @@ class Indexable(Nameable, Observable): """Tie a vector of parameters to other vectors of parameters""" for p in plist: self._highest_parent_.ties.tie_vector(self,p) + self._highest_parent_._trigger_params_changed() #=========================================================================== # Constrain operations -> done diff --git a/GPy/core/parameterization/ties_and_remappings.py b/GPy/core/parameterization/ties_and_remappings.py index 01cc7561..f662685e 100644 --- a/GPy/core/parameterization/ties_and_remappings.py +++ b/GPy/core/parameterization/ties_and_remappings.py @@ -53,7 +53,7 @@ class Tie(Parameterized): ================================ TODO: - 1. Add the support for multiple parameter tie_together and tie_vector + 1. Add the support for multiple parameter tie_together and tie_vector [Preliminary] 2. Properly handling parameters with constraints [DONE] 3. Properly handling the merging of two models [DONE] 4. Properly handling initialization [DONE] @@ -73,9 +73,9 @@ class Tie(Parameterized): def mergeTies(self, p): """Merge the tie tree with another tie tree""" assert hasattr(p,'ties') and isinstance(p.ties,Tie), str(type(p)) - self.updates = False + self.update_model(False) if p.ties.tied_param is not None: - tie_labels = self._expand_tie_param(p.ties.tied_param.size) + tie_labels,_ = self._expand_tie_param(p.ties.tied_param.size) self.tied_param[-p.ties.tied_param.size:] = p.ties.tied_param pairs = zip(self.tied_param.tie,tie_labels) self._replace_labels(p, pairs) @@ -83,7 +83,7 @@ class Tie(Parameterized): p.remove_parameter(p.ties) del p.ties self._update_label_buf() - self.updates = True + self.update_model(True) def splitTies(self, p): """Split the tie subtree with the tie tree""" @@ -91,7 +91,7 @@ class Tie(Parameterized): p.add_parameter(p.ties, -1) p.add_observer(p.ties, p.ties._parameters_changed_notification, priority=-500) if self.tied_param is not None: - self.updates = False + self.update_model(False) labels = self._get_labels([p]) labels = labels[labels>0] if len(labels)>0: @@ -101,7 +101,7 @@ class Tie(Parameterized): p.tied_param.tie[:] = self.tied_param.tie[idx] self._remove_unnecessary_ties() self._update_label_buf() - self.updates = True + self.update_model(True) def _sync_val_group(self, plist): val = np.hstack([p.param_array.flat for p in plist]).mean() @@ -133,6 +133,13 @@ class Tie(Parameterized): print 'WARNING: '+p.name+' have different constraints! They will be unconstrained!' p.unconstrain() return None + + def _sync_constraint_vector(self, p1, p2, expandlist, idxlist, warning=True): + if p1.constraints.items() != p2.constraints.properties(): + print 'WARNING: '+p1.name+' and '+p2.name+' have different constraints! Only the constraints of '+p1.name+' will be considered!' + for c,ind in p1.constraints.iteritems(): + idx = idxlist[np.in1d(expandlist,ind)] + self.tied_param[idx].constrain(c) def _traverse_param(self, func, p, res): """ @@ -183,6 +190,19 @@ class Tie(Parameterized): else: for p in plist: self._traverse_param(_set_list, (p,[0]), []) + + def _get_vals(self, p): + vals = [] + self._traverse_param(lambda x: x.flat, (p,), vals) + return np.hstack(vals) + + def _sync_val_pair(self,p1,p2): + p1val = self._get_vals(p1) + def _set_val(p, offset, p2): + p.flat[:] = p2[offset[0]:offset[0]+p.size] + offset[0] = offset[0]+ p.size + self._traverse_param(_set_val, (p2, [0], p1val), []) + return p1val def _replace_labels(self, p, label_pairs): def _replace_l(p): @@ -194,21 +214,25 @@ class Tie(Parameterized): """Expand the tie param with the number of *num* parameters""" if self.tied_param is None: start_label = 1 + labellist = np.array(range(start_label,start_label+num),dtype=np.int) + idxlist = np.array(range(0,num),dtype=np.int) new_buf = np.empty((num,)) self.tied_param = Param('tied',new_buf) - self.tied_param.tie[:] = range(start_label,start_label+num) + self.tied_param.tie[:] = labellist else: start_label = self.tied_param.tie.max()+1 new_buf = np.empty((self.tied_param.size+num,)) new_buf[:self.tied_param.size] = self.tied_param.param_array.copy() old_tie_ = self.tied_param.tie.copy() old_size = self.tied_param.size + labellist = np.array(range(start_label,start_label+num),dtype=np.int) + idxlist = np.array(range(old_size,old_size+num),dtype=np.int) self.remove_parameter(self.tied_param) self.tied_param = Param('tied',new_buf) self.tied_param.tie[:old_size] = old_tie_ - self.tied_param.tie[old_size:] = range(start_label,start_label+num) + self.tied_param.tie[old_size:] = labellist self.add_parameter(self.tied_param) - return range(start_label,start_label+num) + return labellist, idxlist def _remove_tie_param(self, labels): """Remove the tie param corresponding to *labels*""" @@ -252,13 +276,7 @@ class Tie(Parameterized): self._untie_ = None else: self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32) - def up(x): - print self._highest_parent_._raveled_index_for(x) - np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat) - self._traverse_param(up, (self._highest_parent_,), []) -# self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), []) - print self.tied_param.tie - print self.label_buf + self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), []) self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param) self._untie_ = self.label_buf==0 self._untie_[self.buf_idx] = True @@ -266,12 +284,12 @@ class Tie(Parameterized): def tie_together(self,plist): """tie a list of parameters""" - self.updates = False + self.update_model(False) labels = self._get_labels(plist) val = self._sync_val_group(plist) if labels[0]==0 and labels.size==1: # None of parameters in plist has been tied before. - tie_labels = self._expand_tie_param(1) + tie_labels,_ = self._expand_tie_param(1) self._set_labels(plist, tie_labels) tie_con = self._sync_constraint_group(plist) if tie_con is not None: @@ -288,30 +306,32 @@ class Tie(Parameterized): self._sync_constraint_group(plist, True, tie_con) self._update_label_buf() self.tied_param[self.tied_param.tie==tie_labels[0]] = val - self.updates = True + self.update_model(True) def tie_vector(self, p1, p2): """tie a pair of vectors""" - self.updates = False + self.update_model(False) expandlist,removelist,labellist = self._get_labels_vector(p1, p2) + p1vals = self._sync_val_pair(p1,p2) if len(expandlist)>0: - tie_labels = self._expand_tie_param(len(expandlist)) + tie_labels,idxlist = self._expand_tie_param(len(expandlist)) labellist[expandlist] = tie_labels + self.tied_param[idxlist] = p1vals[expandlist] if len(removelist[0])>0: self._merge_tie_labelpair(removelist) - print labellist self._set_labels([p1,p2], labellist) + self._sync_constraint_vector(p1,p2,expandlist,idxlist) self._update_label_buf() - self.updates = True + self.update_model(True) def untie(self,plist): """Untie a list of parameters""" - self.updates = False + self.update_model(False) self._set_labels(plist,[0]) self._update_label_buf() self._remove_unnecessary_ties() self._update_label_buf() - self.updates = True + self.update_model(True) def _check_change(self): changed = False