mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
finish preliminary implementation of tie_vector
This commit is contained in:
parent
c9f3d652b7
commit
64c271086e
2 changed files with 46 additions and 25 deletions
|
|
@ -593,6 +593,7 @@ class Indexable(Nameable, Observable):
|
||||||
"""Tie a vector of parameters to other vectors of parameters"""
|
"""Tie a vector of parameters to other vectors of parameters"""
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._highest_parent_.ties.tie_vector(self,p)
|
self._highest_parent_.ties.tie_vector(self,p)
|
||||||
|
self._highest_parent_._trigger_params_changed()
|
||||||
|
|
||||||
#===========================================================================
|
#===========================================================================
|
||||||
# Constrain operations -> done
|
# Constrain operations -> done
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ class Tie(Parameterized):
|
||||||
================================
|
================================
|
||||||
|
|
||||||
TODO:
|
TODO:
|
||||||
1. Add the support for multiple parameter tie_together and tie_vector
|
1. Add the support for multiple parameter tie_together and tie_vector [Preliminary]
|
||||||
2. Properly handling parameters with constraints [DONE]
|
2. Properly handling parameters with constraints [DONE]
|
||||||
3. Properly handling the merging of two models [DONE]
|
3. Properly handling the merging of two models [DONE]
|
||||||
4. Properly handling initialization [DONE]
|
4. Properly handling initialization [DONE]
|
||||||
|
|
@ -73,9 +73,9 @@ class Tie(Parameterized):
|
||||||
def mergeTies(self, p):
|
def mergeTies(self, p):
|
||||||
"""Merge the tie tree with another tie tree"""
|
"""Merge the tie tree with another tie tree"""
|
||||||
assert hasattr(p,'ties') and isinstance(p.ties,Tie), str(type(p))
|
assert hasattr(p,'ties') and isinstance(p.ties,Tie), str(type(p))
|
||||||
self.updates = False
|
self.update_model(False)
|
||||||
if p.ties.tied_param is not None:
|
if p.ties.tied_param is not None:
|
||||||
tie_labels = self._expand_tie_param(p.ties.tied_param.size)
|
tie_labels,_ = self._expand_tie_param(p.ties.tied_param.size)
|
||||||
self.tied_param[-p.ties.tied_param.size:] = p.ties.tied_param
|
self.tied_param[-p.ties.tied_param.size:] = p.ties.tied_param
|
||||||
pairs = zip(self.tied_param.tie,tie_labels)
|
pairs = zip(self.tied_param.tie,tie_labels)
|
||||||
self._replace_labels(p, pairs)
|
self._replace_labels(p, pairs)
|
||||||
|
|
@ -83,7 +83,7 @@ class Tie(Parameterized):
|
||||||
p.remove_parameter(p.ties)
|
p.remove_parameter(p.ties)
|
||||||
del p.ties
|
del p.ties
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.updates = True
|
self.update_model(True)
|
||||||
|
|
||||||
def splitTies(self, p):
|
def splitTies(self, p):
|
||||||
"""Split the tie subtree with the tie tree"""
|
"""Split the tie subtree with the tie tree"""
|
||||||
|
|
@ -91,7 +91,7 @@ class Tie(Parameterized):
|
||||||
p.add_parameter(p.ties, -1)
|
p.add_parameter(p.ties, -1)
|
||||||
p.add_observer(p.ties, p.ties._parameters_changed_notification, priority=-500)
|
p.add_observer(p.ties, p.ties._parameters_changed_notification, priority=-500)
|
||||||
if self.tied_param is not None:
|
if self.tied_param is not None:
|
||||||
self.updates = False
|
self.update_model(False)
|
||||||
labels = self._get_labels([p])
|
labels = self._get_labels([p])
|
||||||
labels = labels[labels>0]
|
labels = labels[labels>0]
|
||||||
if len(labels)>0:
|
if len(labels)>0:
|
||||||
|
|
@ -101,7 +101,7 @@ class Tie(Parameterized):
|
||||||
p.tied_param.tie[:] = self.tied_param.tie[idx]
|
p.tied_param.tie[:] = self.tied_param.tie[idx]
|
||||||
self._remove_unnecessary_ties()
|
self._remove_unnecessary_ties()
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.updates = True
|
self.update_model(True)
|
||||||
|
|
||||||
def _sync_val_group(self, plist):
|
def _sync_val_group(self, plist):
|
||||||
val = np.hstack([p.param_array.flat for p in plist]).mean()
|
val = np.hstack([p.param_array.flat for p in plist]).mean()
|
||||||
|
|
@ -134,6 +134,13 @@ class Tie(Parameterized):
|
||||||
p.unconstrain()
|
p.unconstrain()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _sync_constraint_vector(self, p1, p2, expandlist, idxlist, warning=True):
|
||||||
|
if p1.constraints.items() != p2.constraints.properties():
|
||||||
|
print 'WARNING: '+p1.name+' and '+p2.name+' have different constraints! Only the constraints of '+p1.name+' will be considered!'
|
||||||
|
for c,ind in p1.constraints.iteritems():
|
||||||
|
idx = idxlist[np.in1d(expandlist,ind)]
|
||||||
|
self.tied_param[idx].constrain(c)
|
||||||
|
|
||||||
def _traverse_param(self, func, p, res):
|
def _traverse_param(self, func, p, res):
|
||||||
"""
|
"""
|
||||||
Traverse a param tree starting with *p*
|
Traverse a param tree starting with *p*
|
||||||
|
|
@ -184,6 +191,19 @@ class Tie(Parameterized):
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(_set_list, (p,[0]), [])
|
self._traverse_param(_set_list, (p,[0]), [])
|
||||||
|
|
||||||
|
def _get_vals(self, p):
|
||||||
|
vals = []
|
||||||
|
self._traverse_param(lambda x: x.flat, (p,), vals)
|
||||||
|
return np.hstack(vals)
|
||||||
|
|
||||||
|
def _sync_val_pair(self,p1,p2):
|
||||||
|
p1val = self._get_vals(p1)
|
||||||
|
def _set_val(p, offset, p2):
|
||||||
|
p.flat[:] = p2[offset[0]:offset[0]+p.size]
|
||||||
|
offset[0] = offset[0]+ p.size
|
||||||
|
self._traverse_param(_set_val, (p2, [0], p1val), [])
|
||||||
|
return p1val
|
||||||
|
|
||||||
def _replace_labels(self, p, label_pairs):
|
def _replace_labels(self, p, label_pairs):
|
||||||
def _replace_l(p):
|
def _replace_l(p):
|
||||||
for l1,l2 in label_pairs:
|
for l1,l2 in label_pairs:
|
||||||
|
|
@ -194,21 +214,25 @@ class Tie(Parameterized):
|
||||||
"""Expand the tie param with the number of *num* parameters"""
|
"""Expand the tie param with the number of *num* parameters"""
|
||||||
if self.tied_param is None:
|
if self.tied_param is None:
|
||||||
start_label = 1
|
start_label = 1
|
||||||
|
labellist = np.array(range(start_label,start_label+num),dtype=np.int)
|
||||||
|
idxlist = np.array(range(0,num),dtype=np.int)
|
||||||
new_buf = np.empty((num,))
|
new_buf = np.empty((num,))
|
||||||
self.tied_param = Param('tied',new_buf)
|
self.tied_param = Param('tied',new_buf)
|
||||||
self.tied_param.tie[:] = range(start_label,start_label+num)
|
self.tied_param.tie[:] = labellist
|
||||||
else:
|
else:
|
||||||
start_label = self.tied_param.tie.max()+1
|
start_label = self.tied_param.tie.max()+1
|
||||||
new_buf = np.empty((self.tied_param.size+num,))
|
new_buf = np.empty((self.tied_param.size+num,))
|
||||||
new_buf[:self.tied_param.size] = self.tied_param.param_array.copy()
|
new_buf[:self.tied_param.size] = self.tied_param.param_array.copy()
|
||||||
old_tie_ = self.tied_param.tie.copy()
|
old_tie_ = self.tied_param.tie.copy()
|
||||||
old_size = self.tied_param.size
|
old_size = self.tied_param.size
|
||||||
|
labellist = np.array(range(start_label,start_label+num),dtype=np.int)
|
||||||
|
idxlist = np.array(range(old_size,old_size+num),dtype=np.int)
|
||||||
self.remove_parameter(self.tied_param)
|
self.remove_parameter(self.tied_param)
|
||||||
self.tied_param = Param('tied',new_buf)
|
self.tied_param = Param('tied',new_buf)
|
||||||
self.tied_param.tie[:old_size] = old_tie_
|
self.tied_param.tie[:old_size] = old_tie_
|
||||||
self.tied_param.tie[old_size:] = range(start_label,start_label+num)
|
self.tied_param.tie[old_size:] = labellist
|
||||||
self.add_parameter(self.tied_param)
|
self.add_parameter(self.tied_param)
|
||||||
return range(start_label,start_label+num)
|
return labellist, idxlist
|
||||||
|
|
||||||
def _remove_tie_param(self, labels):
|
def _remove_tie_param(self, labels):
|
||||||
"""Remove the tie param corresponding to *labels*"""
|
"""Remove the tie param corresponding to *labels*"""
|
||||||
|
|
@ -252,13 +276,7 @@ class Tie(Parameterized):
|
||||||
self._untie_ = None
|
self._untie_ = None
|
||||||
else:
|
else:
|
||||||
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
|
self.label_buf = np.zeros((self._highest_parent_.param_array.size,),dtype=np.uint32)
|
||||||
def up(x):
|
self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), [])
|
||||||
print self._highest_parent_._raveled_index_for(x)
|
|
||||||
np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat)
|
|
||||||
self._traverse_param(up, (self._highest_parent_,), [])
|
|
||||||
# self._traverse_param(lambda x:np.put(self.label_buf,self._highest_parent_._raveled_index_for(x),x.tie.flat), (self._highest_parent_,), [])
|
|
||||||
print self.tied_param.tie
|
|
||||||
print self.label_buf
|
|
||||||
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
self.buf_idx = self._highest_parent_._raveled_index_for(self.tied_param)
|
||||||
self._untie_ = self.label_buf==0
|
self._untie_ = self.label_buf==0
|
||||||
self._untie_[self.buf_idx] = True
|
self._untie_[self.buf_idx] = True
|
||||||
|
|
@ -266,12 +284,12 @@ class Tie(Parameterized):
|
||||||
|
|
||||||
def tie_together(self,plist):
|
def tie_together(self,plist):
|
||||||
"""tie a list of parameters"""
|
"""tie a list of parameters"""
|
||||||
self.updates = False
|
self.update_model(False)
|
||||||
labels = self._get_labels(plist)
|
labels = self._get_labels(plist)
|
||||||
val = self._sync_val_group(plist)
|
val = self._sync_val_group(plist)
|
||||||
if labels[0]==0 and labels.size==1:
|
if labels[0]==0 and labels.size==1:
|
||||||
# None of parameters in plist has been tied before.
|
# None of parameters in plist has been tied before.
|
||||||
tie_labels = self._expand_tie_param(1)
|
tie_labels,_ = self._expand_tie_param(1)
|
||||||
self._set_labels(plist, tie_labels)
|
self._set_labels(plist, tie_labels)
|
||||||
tie_con = self._sync_constraint_group(plist)
|
tie_con = self._sync_constraint_group(plist)
|
||||||
if tie_con is not None:
|
if tie_con is not None:
|
||||||
|
|
@ -288,30 +306,32 @@ class Tie(Parameterized):
|
||||||
self._sync_constraint_group(plist, True, tie_con)
|
self._sync_constraint_group(plist, True, tie_con)
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
||||||
self.updates = True
|
self.update_model(True)
|
||||||
|
|
||||||
def tie_vector(self, p1, p2):
|
def tie_vector(self, p1, p2):
|
||||||
"""tie a pair of vectors"""
|
"""tie a pair of vectors"""
|
||||||
self.updates = False
|
self.update_model(False)
|
||||||
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
||||||
|
p1vals = self._sync_val_pair(p1,p2)
|
||||||
if len(expandlist)>0:
|
if len(expandlist)>0:
|
||||||
tie_labels = self._expand_tie_param(len(expandlist))
|
tie_labels,idxlist = self._expand_tie_param(len(expandlist))
|
||||||
labellist[expandlist] = tie_labels
|
labellist[expandlist] = tie_labels
|
||||||
|
self.tied_param[idxlist] = p1vals[expandlist]
|
||||||
if len(removelist[0])>0:
|
if len(removelist[0])>0:
|
||||||
self._merge_tie_labelpair(removelist)
|
self._merge_tie_labelpair(removelist)
|
||||||
print labellist
|
|
||||||
self._set_labels([p1,p2], labellist)
|
self._set_labels([p1,p2], labellist)
|
||||||
|
self._sync_constraint_vector(p1,p2,expandlist,idxlist)
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.updates = True
|
self.update_model(True)
|
||||||
|
|
||||||
def untie(self,plist):
|
def untie(self,plist):
|
||||||
"""Untie a list of parameters"""
|
"""Untie a list of parameters"""
|
||||||
self.updates = False
|
self.update_model(False)
|
||||||
self._set_labels(plist,[0])
|
self._set_labels(plist,[0])
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self._remove_unnecessary_ties()
|
self._remove_unnecessary_ties()
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.updates = True
|
self.update_model(True)
|
||||||
|
|
||||||
def _check_change(self):
|
def _check_change(self):
|
||||||
changed = False
|
changed = False
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue