mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
to merge with new bug fix
This commit is contained in:
parent
f093e60b5d
commit
62241c4211
2 changed files with 63 additions and 11 deletions
|
|
@ -551,6 +551,7 @@ class Indexable(Nameable, Observable):
|
|||
return np.ones((self._highest_parent_.param_array.size,),dtype=np.bool)
|
||||
|
||||
def tie_together(self, *plist):
|
||||
"""Tie a list of parameters together"""
|
||||
plist = list(plist)
|
||||
plist.append(self)
|
||||
self._highest_parent_.ties.tie_together(plist)
|
||||
|
|
@ -560,6 +561,11 @@ class Indexable(Nameable, Observable):
|
|||
plist.append(self)
|
||||
self._highest_parent_.ties.untie(plist)
|
||||
|
||||
def tie_vector(self, *plist):
|
||||
"""Tie a vector of parameters to other vectors of parameters"""
|
||||
for p in plist:
|
||||
self._highest_parent_.ties.tie_vector(self,p)
|
||||
|
||||
#===========================================================================
|
||||
# Constrain operations -> done
|
||||
#===========================================================================
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class Tie(Parameterized):
|
|||
def _set_val(p):
|
||||
p[:] = val
|
||||
for p in plist:
|
||||
self._traverse_param(_set_val, p, [])
|
||||
self._traverse_param(_set_val, (p,), [])
|
||||
return val
|
||||
|
||||
def _sync_constraint_group(self, plist, hastie=False, tie_con=None, warning=True):
|
||||
|
|
@ -140,18 +140,33 @@ class Tie(Parameterized):
|
|||
Apply *func* to every leaves (param objects),
|
||||
and collect return values into *res*
|
||||
"""
|
||||
if isinstance(p, Param):
|
||||
res.append(func(p))
|
||||
if isinstance(p[0], Param):
|
||||
res.append(func(*p))
|
||||
else:
|
||||
for pc in p.parameters:
|
||||
self._traverse_param(func,pc,res)
|
||||
for pc in p[0].parameters:
|
||||
self._traverse_param(func, (pc,)+p[1:] ,res)
|
||||
|
||||
def _get_labels(self, plist):
|
||||
labels = []
|
||||
for p in plist:
|
||||
self._traverse_param(lambda x: x.tie.flat, p, labels)
|
||||
self._traverse_param(lambda x: x.tie.flat, (p,), labels)
|
||||
return np.unique(np.hstack(labels))
|
||||
|
||||
def _get_labels_vector(self, p1,p2):
|
||||
label1 = []
|
||||
self._traverse_param(lambda x: x.tie.flat, (p1,), label1)
|
||||
label1 = np.hstack(label1)
|
||||
label2 = []
|
||||
self._traverse_param(lambda x: x.tie.flat, (p2,), label2)
|
||||
label2 = np.hstack(label2)
|
||||
expandlist = np.where(label1*label2==0)[0]
|
||||
labellist =label1.copy()
|
||||
idx = np.logical_and(label1==0,label2>0)
|
||||
labellist[idx] = label2[idx]
|
||||
idx = np.logical_and(label1*label2>0,label1!=label2)
|
||||
removelist = (label1[idx],label2[idx])
|
||||
return expandlist,removelist,labellist
|
||||
|
||||
def _set_labels(self, plist, labels):
|
||||
"""
|
||||
If there is only one label, set all the param objects to that label,
|
||||
|
|
@ -159,15 +174,21 @@ class Tie(Parameterized):
|
|||
"""
|
||||
def _set_l1(p):
|
||||
p.tie[:] = labels[0]
|
||||
def _set_list(p, offset):
|
||||
p.tie.flat[:] = labels[offset[0]:offset[0]+p.size]
|
||||
offset[0] = offset[0]+ p.size
|
||||
if len(labels)==1:
|
||||
for p in plist:
|
||||
self._traverse_param(_set_l1, p, [])
|
||||
self._traverse_param(_set_l1, (p,), [])
|
||||
else:
|
||||
for p in plist:
|
||||
self._traverse_param(_set_list, (p,[0]), [])
|
||||
|
||||
def _replace_labels(self, p, label_pairs):
|
||||
def _replace_l(p):
|
||||
for l1,l2 in label_pairs:
|
||||
p.tie[p.tie==l1] = l2
|
||||
self._traverse_param(_replace_l, p, [])
|
||||
self._traverse_param(_replace_l, (p,), [])
|
||||
|
||||
def _expand_tie_param(self, num):
|
||||
"""Expand the tie param with the number of *num* parameters"""
|
||||
|
|
@ -210,6 +231,11 @@ class Tie(Parameterized):
|
|||
return
|
||||
self._remove_tie_param(labels[1:])
|
||||
self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]])
|
||||
|
||||
def _merge_tie_labelpair(self, labelpair):
|
||||
"""Merge the second list in labelpair to the first list"""
|
||||
self._remove_tie_param(labelpair[1])
|
||||
self._replace_labels(self._highest_parent_, zip(labelpair[1],labelpair[0]))
|
||||
|
||||
def _remove_unnecessary_ties(self):
|
||||
"""Remove the unnecessary ties"""
|
||||
|
|
@ -226,7 +252,13 @@ class Tie(Parameterized):
|
|||
self._untie_ = None
|
||||
else:
|
||||
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
|
||||
self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie), self._highest_parent_, [])
|
||||
def up(x):
|
||||
print self._highest_parent_._raveled_index_for(x)
|
||||
np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat)
|
||||
self._traverse_param(up, (self._highest_parent_,), [])
|
||||
# self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), [])
|
||||
print self.tied_param.tie
|
||||
print self.label_buf
|
||||
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
||||
self._untie_ = self.label_buf==0
|
||||
self._untie_[self.buf_idx] = True
|
||||
|
|
@ -234,7 +266,7 @@ class Tie(Parameterized):
|
|||
|
||||
def tie_together(self,plist):
|
||||
"""tie a list of parameters"""
|
||||
#self.updates = False
|
||||
self.updates = False
|
||||
labels = self._get_labels(plist)
|
||||
val = self._sync_val_group(plist)
|
||||
if labels[0]==0 and labels.size==1:
|
||||
|
|
@ -256,7 +288,21 @@ class Tie(Parameterized):
|
|||
self._sync_constraint_group(plist, True, tie_con)
|
||||
self._update_label_buf()
|
||||
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
||||
#self.updates = True
|
||||
self.updates = True
|
||||
|
||||
def tie_vector(self, p1, p2):
|
||||
"""tie a pair of vectors"""
|
||||
self.updates = False
|
||||
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
||||
if len(expandlist)>0:
|
||||
tie_labels = self._expand_tie_param(len(expandlist))
|
||||
labellist[expandlist] = tie_labels
|
||||
if len(removelist[0])>0:
|
||||
self._merge_tie_labelpair(removelist)
|
||||
print labellist
|
||||
self._set_labels([p1,p2], labellist)
|
||||
self._update_label_buf()
|
||||
self.updates = True
|
||||
|
||||
def untie(self,plist):
|
||||
"""Untie a list of parameters"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue