From 11bb1e4c4cffcd31c34367a2faf27291263bba66 Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Thu, 4 Sep 2014 14:07:46 +0100 Subject: [PATCH] implement splitTies --- GPy/core/parameterization/parameterized.py | 7 ++++ .../parameterization/ties_and_remappings.py | 37 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/GPy/core/parameterization/parameterized.py b/GPy/core/parameterization/parameterized.py index 1c5708f9..f2fd4a27 100644 --- a/GPy/core/parameterization/parameterized.py +++ b/GPy/core/parameterization/parameterized.py @@ -216,6 +216,13 @@ class Parameterized(Parameterizable): self._highest_parent_._connect_parameters() self._highest_parent_._connect_fixes() self._highest_parent_._notify_parent_change() + + if isinstance(param,Parameterized): + from ties_and_remappings import Tie + if not isinstance(param,Tie): + self._highest_parent_.ties.splitTies(param) + else: + self._highest_parent_.ties._update_label_buf() def _connect_parameters(self, ignore_added_names=False): # connect parameterlist to this parameterized object diff --git a/GPy/core/parameterization/ties_and_remappings.py b/GPy/core/parameterization/ties_and_remappings.py index e5065067..a068bfca 100644 --- a/GPy/core/parameterization/ties_and_remappings.py +++ b/GPy/core/parameterization/ties_and_remappings.py @@ -84,6 +84,24 @@ class Tie(Parameterized): del p.ties self._update_label_buf() self.updates = True + + def splitTies(self, p): + """Split the tie subtree with the tie tree""" + p.ties = Tie() + p.add_parameter(p.ties, -1) + p.add_observer(p.ties, p.ties._parameters_changed_notification, priority=-500) + if self.tied_param is not None: + self.updates = False + labels = self._get_labels([p]) + labels = labels[labels>0] + if len(labels)>0: + p._expand_tie_param(len(labels)) + idx = np.in1d(self.tied_param.tie,labels) + p.tied_param[:] = self.tied_param[idx] + p.tied_param.tie[:] = self.tied_param.tie[idx] + self._remove_unnecessary_ties() + self._update_label_buf() + self.updates = True def _sync_val_group(self, plist): val = np.asarray([p.param_array for p in plist]).mean() @@ -169,6 +187,14 @@ class Tie(Parameterized): return self._remove_tie_param(labels[1:]) self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]]) + + def _remove_unnecessary_ties(self): + """Remove the unnecessary ties""" + if self.tied_param is not None: + labels = [l for l in self.tied_param.tie if (self.label_buf==l).sum()<=2] + if len(labels)>0: + self._remove_tie_param(labels) + self._replace_labels(self._highest_parent_, zip(labels,[0]*len(labels))) def _update_label_buf(self): if self.tied_param is None: @@ -181,6 +207,7 @@ class Tie(Parameterized): self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param) self._untie_ = self.label_buf==0 self._untie_[self.buf_idx] = True + assert(np.all(self.tied_param.tie>0)) def tie_together(self,plist): """tie a list of parameters""" @@ -201,6 +228,16 @@ class Tie(Parameterized): self._update_label_buf() self.tied_param[self.tied_param.tie==tie_labels[0]] = val self.updates = True + + def untie(self,plist): + """Untie a list of parameters""" + self.updates = False + self._set_labels(plist,[0]) + self._update_label_buf() + self._remove_unnecessary_ties() + self._update_label_buf() + self.updates = True + def _check_change(self): changed = False