implement splitTies

This commit is contained in:
Zhenwen Dai 2014-09-04 14:07:46 +01:00
parent a506f4a605
commit 11bb1e4c4c
2 changed files with 44 additions and 0 deletions

View file

@ -216,6 +216,13 @@ class Parameterized(Parameterizable):
self._highest_parent_._connect_parameters()
self._highest_parent_._connect_fixes()
self._highest_parent_._notify_parent_change()
if isinstance(param,Parameterized):
from ties_and_remappings import Tie
if not isinstance(param,Tie):
self._highest_parent_.ties.splitTies(param)
else:
self._highest_parent_.ties._update_label_buf()
def _connect_parameters(self, ignore_added_names=False):
# connect parameterlist to this parameterized object

View file

@ -84,6 +84,24 @@ class Tie(Parameterized):
del p.ties
self._update_label_buf()
self.updates = True
def splitTies(self, p):
"""Split the tie subtree with the tie tree"""
p.ties = Tie()
p.add_parameter(p.ties, -1)
p.add_observer(p.ties, p.ties._parameters_changed_notification, priority=-500)
if self.tied_param is not None:
self.updates = False
labels = self._get_labels([p])
labels = labels[labels>0]
if len(labels)>0:
p._expand_tie_param(len(labels))
idx = np.in1d(self.tied_param.tie,labels)
p.tied_param[:] = self.tied_param[idx]
p.tied_param.tie[:] = self.tied_param.tie[idx]
self._remove_unnecessary_ties()
self._update_label_buf()
self.updates = True
def _sync_val_group(self, plist):
val = np.asarray([p.param_array for p in plist]).mean()
@ -169,6 +187,14 @@ class Tie(Parameterized):
return
self._remove_tie_param(labels[1:])
self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]])
def _remove_unnecessary_ties(self):
"""Remove the unnecessary ties"""
if self.tied_param is not None:
labels = [l for l in self.tied_param.tie if (self.label_buf==l).sum()<=2]
if len(labels)>0:
self._remove_tie_param(labels)
self._replace_labels(self._highest_parent_, zip(labels,[0]*len(labels)))
def _update_label_buf(self):
if self.tied_param is None:
@ -181,6 +207,7 @@ class Tie(Parameterized):
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
self._untie_ = self.label_buf==0
self._untie_[self.buf_idx] = True
assert(np.all(self.tied_param.tie>0))
def tie_together(self,plist):
"""tie a list of parameters"""
@ -201,6 +228,16 @@ class Tie(Parameterized):
self._update_label_buf()
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
self.updates = True
def untie(self,plist):
"""Untie a list of parameters"""
self.updates = False
self._set_labels(plist,[0])
self._update_label_buf()
self._remove_unnecessary_ties()
self._update_label_buf()
self.updates = True
def _check_change(self):
changed = False