mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
to merge with new bug fix
This commit is contained in:
parent
f093e60b5d
commit
62241c4211
2 changed files with 63 additions and 11 deletions
|
|
@ -551,6 +551,7 @@ class Indexable(Nameable, Observable):
|
||||||
return np.ones((self._highest_parent_.param_array.size,),dtype=np.bool)
|
return np.ones((self._highest_parent_.param_array.size,),dtype=np.bool)
|
||||||
|
|
||||||
def tie_together(self, *plist):
|
def tie_together(self, *plist):
|
||||||
|
"""Tie a list of parameters together"""
|
||||||
plist = list(plist)
|
plist = list(plist)
|
||||||
plist.append(self)
|
plist.append(self)
|
||||||
self._highest_parent_.ties.tie_together(plist)
|
self._highest_parent_.ties.tie_together(plist)
|
||||||
|
|
@ -560,6 +561,11 @@ class Indexable(Nameable, Observable):
|
||||||
plist.append(self)
|
plist.append(self)
|
||||||
self._highest_parent_.ties.untie(plist)
|
self._highest_parent_.ties.untie(plist)
|
||||||
|
|
||||||
|
def tie_vector(self, *plist):
|
||||||
|
"""Tie a vector of parameters to other vectors of parameters"""
|
||||||
|
for p in plist:
|
||||||
|
self._highest_parent_.ties.tie_vector(self,p)
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Constrain operations -> done
|
# Constrain operations -> done
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,7 @@ class Tie(Parameterized):
|
||||||
def _set_val(p):
|
def _set_val(p):
|
||||||
p[:] = val
|
p[:] = val
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(_set_val, p, [])
|
self._traverse_param(_set_val, (p,), [])
|
||||||
return val
|
return val
|
||||||
|
|
||||||
def _sync_constraint_group(self, plist, hastie=False, tie_con=None, warning=True):
|
def _sync_constraint_group(self, plist, hastie=False, tie_con=None, warning=True):
|
||||||
|
|
@ -140,18 +140,33 @@ class Tie(Parameterized):
|
||||||
Apply *func* to every leaves (param objects),
|
Apply *func* to every leaves (param objects),
|
||||||
and collect return values into *res*
|
and collect return values into *res*
|
||||||
"""
|
"""
|
||||||
if isinstance(p, Param):
|
if isinstance(p[0], Param):
|
||||||
res.append(func(p))
|
res.append(func(*p))
|
||||||
else:
|
else:
|
||||||
for pc in p.parameters:
|
for pc in p[0].parameters:
|
||||||
self._traverse_param(func,pc,res)
|
self._traverse_param(func, (pc,)+p[1:] ,res)
|
||||||
|
|
||||||
def _get_labels(self, plist):
|
def _get_labels(self, plist):
|
||||||
labels = []
|
labels = []
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(lambda x: x.tie.flat, p, labels)
|
self._traverse_param(lambda x: x.tie.flat, (p,), labels)
|
||||||
return np.unique(np.hstack(labels))
|
return np.unique(np.hstack(labels))
|
||||||
|
|
||||||
|
def _get_labels_vector(self, p1,p2):
|
||||||
|
label1 = []
|
||||||
|
self._traverse_param(lambda x: x.tie.flat, (p1,), label1)
|
||||||
|
label1 = np.hstack(label1)
|
||||||
|
label2 = []
|
||||||
|
self._traverse_param(lambda x: x.tie.flat, (p2,), label2)
|
||||||
|
label2 = np.hstack(label2)
|
||||||
|
expandlist = np.where(label1*label2==0)[0]
|
||||||
|
labellist =label1.copy()
|
||||||
|
idx = np.logical_and(label1==0,label2>0)
|
||||||
|
labellist[idx] = label2[idx]
|
||||||
|
idx = np.logical_and(label1*label2>0,label1!=label2)
|
||||||
|
removelist = (label1[idx],label2[idx])
|
||||||
|
return expandlist,removelist,labellist
|
||||||
|
|
||||||
def _set_labels(self, plist, labels):
|
def _set_labels(self, plist, labels):
|
||||||
"""
|
"""
|
||||||
If there is only one label, set all the param objects to that label,
|
If there is only one label, set all the param objects to that label,
|
||||||
|
|
@ -159,15 +174,21 @@ class Tie(Parameterized):
|
||||||
"""
|
"""
|
||||||
def _set_l1(p):
|
def _set_l1(p):
|
||||||
p.tie[:] = labels[0]
|
p.tie[:] = labels[0]
|
||||||
|
def _set_list(p, offset):
|
||||||
|
p.tie.flat[:] = labels[offset[0]:offset[0]+p.size]
|
||||||
|
offset[0] = offset[0]+ p.size
|
||||||
if len(labels)==1:
|
if len(labels)==1:
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(_set_l1, p, [])
|
self._traverse_param(_set_l1, (p,), [])
|
||||||
|
else:
|
||||||
|
for p in plist:
|
||||||
|
self._traverse_param(_set_list, (p,[0]), [])
|
||||||
|
|
||||||
def _replace_labels(self, p, label_pairs):
|
def _replace_labels(self, p, label_pairs):
|
||||||
def _replace_l(p):
|
def _replace_l(p):
|
||||||
for l1,l2 in label_pairs:
|
for l1,l2 in label_pairs:
|
||||||
p.tie[p.tie==l1] = l2
|
p.tie[p.tie==l1] = l2
|
||||||
self._traverse_param(_replace_l, p, [])
|
self._traverse_param(_replace_l, (p,), [])
|
||||||
|
|
||||||
def _expand_tie_param(self, num):
|
def _expand_tie_param(self, num):
|
||||||
"""Expand the tie param with the number of *num* parameters"""
|
"""Expand the tie param with the number of *num* parameters"""
|
||||||
|
|
@ -210,6 +231,11 @@ class Tie(Parameterized):
|
||||||
return
|
return
|
||||||
self._remove_tie_param(labels[1:])
|
self._remove_tie_param(labels[1:])
|
||||||
self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]])
|
self._replace_labels(self._highest_parent_, [(l,labels[0]) for l in labels[1:]])
|
||||||
|
|
||||||
|
def _merge_tie_labelpair(self, labelpair):
|
||||||
|
"""Merge the second list in labelpair to the first list"""
|
||||||
|
self._remove_tie_param(labelpair[1])
|
||||||
|
self._replace_labels(self._highest_parent_, zip(labelpair[1],labelpair[0]))
|
||||||
|
|
||||||
def _remove_unnecessary_ties(self):
|
def _remove_unnecessary_ties(self):
|
||||||
"""Remove the unnecessary ties"""
|
"""Remove the unnecessary ties"""
|
||||||
|
|
@ -226,7 +252,13 @@ class Tie(Parameterized):
|
||||||
self._untie_ = None
|
self._untie_ = None
|
||||||
else:
|
else:
|
||||||
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
|
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
|
||||||
self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie), self._highest_parent_, [])
|
def up(x):
|
||||||
|
print self._highest_parent_._raveled_index_for(x)
|
||||||
|
np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat)
|
||||||
|
self._traverse_param(up, (self._highest_parent_,), [])
|
||||||
|
# self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), [])
|
||||||
|
print self.tied_param.tie
|
||||||
|
print self.label_buf
|
||||||
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
||||||
self._untie_ = self.label_buf==0
|
self._untie_ = self.label_buf==0
|
||||||
self._untie_[self.buf_idx] = True
|
self._untie_[self.buf_idx] = True
|
||||||
|
|
@ -234,7 +266,7 @@ class Tie(Parameterized):
|
||||||
|
|
||||||
def tie_together(self,plist):
|
def tie_together(self,plist):
|
||||||
"""tie a list of parameters"""
|
"""tie a list of parameters"""
|
||||||
#self.updates = False
|
self.updates = False
|
||||||
labels = self._get_labels(plist)
|
labels = self._get_labels(plist)
|
||||||
val = self._sync_val_group(plist)
|
val = self._sync_val_group(plist)
|
||||||
if labels[0]==0 and labels.size==1:
|
if labels[0]==0 and labels.size==1:
|
||||||
|
|
@ -256,7 +288,21 @@ class Tie(Parameterized):
|
||||||
self._sync_constraint_group(plist, True, tie_con)
|
self._sync_constraint_group(plist, True, tie_con)
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
||||||
#self.updates = True
|
self.updates = True
|
||||||
|
|
||||||
|
def tie_vector(self, p1, p2):
|
||||||
|
"""tie a pair of vectors"""
|
||||||
|
self.updates = False
|
||||||
|
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
||||||
|
if len(expandlist)>0:
|
||||||
|
tie_labels = self._expand_tie_param(len(expandlist))
|
||||||
|
labellist[expandlist] = tie_labels
|
||||||
|
if len(removelist[0])>0:
|
||||||
|
self._merge_tie_labelpair(removelist)
|
||||||
|
print labellist
|
||||||
|
self._set_labels([p1,p2], labellist)
|
||||||
|
self._update_label_buf()
|
||||||
|
self.updates = True
|
||||||
|
|
||||||
def untie(self,plist):
|
def untie(self,plist):
|
||||||
"""Untie a list of parameters"""
|
"""Untie a list of parameters"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue