to merge with new bug fix

This commit is contained in:
Zhenwen Dai 2014-09-05 14:36:05 +01:00
parent f093e60b5d
commit 62241c4211
2 changed files with 63 additions and 11 deletions

View file

@ -551,6 +551,7 @@ class Indexable(Nameable, Observable):
return np.ones((self._highest_parent_.param_array.size,),dtype=np.bool)
def tie_together(self, *plist):
"""Tie a list of parameters together"""
plist = list(plist)
plist.append(self)
self._highest_parent_.ties.tie_together(plist)
@ -560,6 +561,11 @@ class Indexable(Nameable, Observable):
plist.append(self)
self._highest_parent_.ties.untie(plist)
def tie_vector(self, *plist):
"""Tie a vector of parameters to other vectors of parameters"""
for p in plist:
self._highest_parent_.ties.tie_vector(self,p)
#===========================================================================
# Constrain operations -> done
#===========================================================================

View file

@ -108,7 +108,7 @@ class Tie(Parameterized):
def _set_val(p):
p[:] = val
for p in plist:
self._traverse_param(_set_val, p, [])
self._traverse_param(_set_val, (p,), [])
return val
def _sync_constraint_group(self, plist, hastie=False, tie_con=None, warning=True):
@ -140,18 +140,33 @@ class Tie(Parameterized):
Apply *func* to every leaves (param objects),
and collect return values into *res*
"""
if isinstance(p, Param):
res.append(func(p))
if isinstance(p[0], Param):
res.append(func(*p))
else:
for pc in p.parameters:
self._traverse_param(func,pc,res)
for pc in p[0].parameters:
self._traverse_param(func, (pc,)+p[1:] ,res)
def _get_labels(self, plist):
labels = []
for p in plist:
self._traverse_param(lambda x: x.tie.flat, p, labels)
self._traverse_param(lambda x: x.tie.flat, (p,), labels)
return np.unique(np.hstack(labels))
def _get_labels_vector(self, p1,p2):
label1 = []
self._traverse_param(lambda x: x.tie.flat, (p1,), label1)
label1 = np.hstack(label1)
label2 = []
self._traverse_param(lambda x: x.tie.flat, (p2,), label2)
label2 = np.hstack(label2)
expandlist = np.where(label1*label2==0)[0]
labellist =label1.copy()
idx = np.logical_and(label1==0,label2>0)
labellist[idx] = label2[idx]
idx = np.logical_and(label1*label2>0,label1!=label2)
removelist = (label1[idx],label2[idx])
return expandlist,removelist,labellist
def _set_labels(self, plist, labels):
"""
If there is only one label, set all the param objects to that label,
@ -159,15 +174,21 @@ class Tie(Parameterized):
"""
def _set_l1(p):
p.tie[:] = labels[0]
def _set_list(p, offset):
p.tie.flat[:] = labels[offset[0]:offset[0]+p.size]
offset[0] = offset[0]+ p.size
if len(labels)==1:
for p in plist:
self._traverse_param(_set_l1, p, [])
self._traverse_param(_set_l1, (p,), [])
else:
for p in plist:
self._traverse_param(_set_list, (p,[0]), [])
def _replace_labels(self, p, label_pairs):
def _replace_l(p):
for l1,l2 in label_pairs:
p.tie[p.tie==l1] = l2
self._traverse_param(_replace_l, p, [])
self._traverse_param(_replace_l, (p,), [])
def _expand_tie_param(self, num):
"""Expand the tie param with the number of *num* parameters"""
@ -210,6 +231,11 @@ class Tie(Parameterized):
return
self._remove_tie_param(labels[1:])
self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]])
def _merge_tie_labelpair(self, labelpair):
"""Merge the second list in labelpair to the first list"""
self._remove_tie_param(labelpair[1])
self._replace_labels(self._highest_parent_, zip(labelpair[1],labelpair[0]))
def _remove_unnecessary_ties(self):
"""Remove the unnecessary ties"""
@ -226,7 +252,13 @@ class Tie(Parameterized):
self._untie_ = None
else:
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie), self._highest_parent_, [])
def up(x):
print self._highest_parent_._raveled_index_for(x)
np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat)
self._traverse_param(up, (self._highest_parent_,), [])
# self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), [])
print self.tied_param.tie
print self.label_buf
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
self._untie_ = self.label_buf==0
self._untie_[self.buf_idx] = True
@ -234,7 +266,7 @@ class Tie(Parameterized):
def tie_together(self,plist):
"""tie a list of parameters"""
#self.updates = False
self.updates = False
labels = self._get_labels(plist)
val = self._sync_val_group(plist)
if labels[0]==0 and labels.size==1:
@ -256,7 +288,21 @@ class Tie(Parameterized):
self._sync_constraint_group(plist, True, tie_con)
self._update_label_buf()
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
#self.updates = True
self.updates = True
def tie_vector(self, p1, p2):
"""tie a pair of vectors"""
self.updates = False
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
if len(expandlist)>0:
tie_labels = self._expand_tie_param(len(expandlist))
labellist[expandlist] = tie_labels
if len(removelist[0])>0:
self._merge_tie_labelpair(removelist)
print labellist
self._set_labels([p1,p2], labellist)
self._update_label_buf()
self.updates = True
def untie(self,plist):
"""Untie a list of parameters"""