mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
finish tie framework restructure and add more tests
This commit is contained in:
parent
077232c26f
commit
c279da4c75
3 changed files with 45 additions and 113 deletions
|
|
@ -97,7 +97,6 @@ class Tie(Parameterized):
|
||||||
else:
|
else:
|
||||||
val = tieparam[tieparam.tie==l]
|
val = tieparam[tieparam.tie==l]
|
||||||
p[p.tie==l] = val
|
p[p.tie==l] = val
|
||||||
|
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(_sync_val_p, (p,self.tied_param,read), [])
|
self._traverse_param(_sync_val_p, (p,self.tied_param,read), [])
|
||||||
|
|
||||||
|
|
@ -112,6 +111,7 @@ class Tie(Parameterized):
|
||||||
cons = [c[0] if len(c)>0 else None for c in self.tied_param.constraints.properties_for(range(self.tied_param.size))]
|
cons = [c[0] if len(c)>0 else None for c in self.tied_param.constraints.properties_for(range(self.tied_param.size))]
|
||||||
def _sync_constraints_p(p, tieparam, read, cons):
|
def _sync_constraints_p(p, tieparam, read, cons):
|
||||||
if p.tie is not None:
|
if p.tie is not None:
|
||||||
|
p = p._original_
|
||||||
labels = np.unique(p.tie)
|
labels = np.unique(p.tie)
|
||||||
labels = labels[labels>0]
|
labels = labels[labels>0]
|
||||||
for l in labels:
|
for l in labels:
|
||||||
|
|
@ -127,6 +127,7 @@ class Tie(Parameterized):
|
||||||
p[p.tie==l].constrain(c)
|
p[p.tie==l].constrain(c)
|
||||||
if c != cons[idx]:
|
if c != cons[idx]:
|
||||||
tieparam[idx:idx+1].constrain(c)
|
tieparam[idx:idx+1].constrain(c)
|
||||||
|
cons[idx] = c
|
||||||
read[tieparam.tie==l] = 1
|
read[tieparam.tie==l] = 1
|
||||||
else:
|
else:
|
||||||
if cons[idx] is None:
|
if cons[idx] is None:
|
||||||
|
|
@ -134,7 +135,6 @@ class Tie(Parameterized):
|
||||||
p[p.tie==l].unconstrain()
|
p[p.tie==l].unconstrain()
|
||||||
else:
|
else:
|
||||||
if len(conslist)!=1 or conslist.keys()[0]!=cons[idx] or len(conslist[cons[idx]])!= (p.tie==l).sum():
|
if len(conslist)!=1 or conslist.keys()[0]!=cons[idx] or len(conslist[cons[idx]])!= (p.tie==l).sum():
|
||||||
print cons[idx]
|
|
||||||
p[p.tie==l].constrain(cons[idx])
|
p[p.tie==l].constrain(cons[idx])
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(_sync_constraints_p, (p,self.tied_param,read, cons), [])
|
self._traverse_param(_sync_constraints_p, (p,self.tied_param,read, cons), [])
|
||||||
|
|
@ -152,10 +152,10 @@ class Tie(Parameterized):
|
||||||
labels = p.ties._get_labels([p])
|
labels = p.ties._get_labels([p])
|
||||||
labels = labels[labels>0]
|
labels = labels[labels>0]
|
||||||
if len(labels)>0:
|
if len(labels)>0:
|
||||||
p._expand_tie_param(len(labels))
|
p.ties._expand_tie_param(len(labels))
|
||||||
vals = p.ties._get_sync_val(p, labels)
|
p.ties.tied_param.tie[:] = labels
|
||||||
p.tied_param[:] = vals
|
p.ties._sync_val([p],toTiedParam=True)
|
||||||
p.tied_param.tie[:] = labels
|
p.ties._sync_constraints([p], toTiedParam=True)
|
||||||
p._update_label_buf()
|
p._update_label_buf()
|
||||||
p.update_model(True)
|
p.update_model(True)
|
||||||
|
|
||||||
|
|
@ -193,59 +193,6 @@ class Tie(Parameterized):
|
||||||
p._update_label_buf()
|
p._update_label_buf()
|
||||||
self.update_model(True)
|
self.update_model(True)
|
||||||
|
|
||||||
def _get_sync_val(self, p, labels):
|
|
||||||
vals = np.empty((labels.size,))
|
|
||||||
read = np.zeros((labels.size,),dtype=np.uint8)
|
|
||||||
def _get_sync_v(p, labels, vals, read):
|
|
||||||
for i in xrange(labels.size):
|
|
||||||
if read[i]==1:
|
|
||||||
p[p.tie==labels[i]] = vals[i]
|
|
||||||
elif np.any(p.tie==labels[i]):
|
|
||||||
vals[i] = p[p.tie==labels[i]][0]
|
|
||||||
p[p.tie==labels[i]][0] = vals[i]
|
|
||||||
read[i] = 1
|
|
||||||
self._traverse_param(_get_sync_v, (p,labels,vals,read), [])
|
|
||||||
return vals
|
|
||||||
|
|
||||||
|
|
||||||
def _sync_val_group(self, plist):
|
|
||||||
val = np.hstack([p.param_array.flat for p in plist]).mean()
|
|
||||||
def _set_val(p):
|
|
||||||
p[:] = val
|
|
||||||
for p in plist:
|
|
||||||
self._traverse_param(_set_val, (p,), [])
|
|
||||||
return val
|
|
||||||
|
|
||||||
def _sync_constraint_group(self, plist, hastie=False, tie_con=None, warning=True):
|
|
||||||
if not hastie:
|
|
||||||
cons = []
|
|
||||||
for p in plist:
|
|
||||||
cons.extend(p.constraints.properties())
|
|
||||||
cons = list(set(cons))
|
|
||||||
if len(cons)==0:
|
|
||||||
tie_con = None
|
|
||||||
else:
|
|
||||||
tie_con = cons[0]
|
|
||||||
if tie_con is not None:
|
|
||||||
for p in plist:
|
|
||||||
if len(p.constraints.properties())!=1 or p.constraints[tie_con].size != p.size:
|
|
||||||
print 'WARNING: '+p.name+' have different constraints! They will be constrained '+str(tie_con)+'!'
|
|
||||||
p.constrain(tie_con)
|
|
||||||
return tie_con
|
|
||||||
elif hastie:
|
|
||||||
for p in plist:
|
|
||||||
if p.constraints.size>0:
|
|
||||||
print 'WARNING: '+p.name+' have different constraints! They will be unconstrained!'
|
|
||||||
p.unconstrain()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _sync_constraint_vector(self, p1, p2, expandlist, idxlist, warning=True):
|
|
||||||
if p1.constraints.items() != p2.constraints.properties():
|
|
||||||
print 'WARNING: '+p1.name+' and '+p2.name+' have different constraints! Only the constraints of '+p1.name+' will be considered!'
|
|
||||||
for c,ind in p1.constraints.iteritems():
|
|
||||||
idx = idxlist[np.in1d(expandlist,ind)]
|
|
||||||
self.tied_param[idx].constrain(c)
|
|
||||||
|
|
||||||
def _traverse_param(self, func, p, res):
|
def _traverse_param(self, func, p, res):
|
||||||
"""
|
"""
|
||||||
Traverse a param tree starting with *p*
|
Traverse a param tree starting with *p*
|
||||||
|
|
@ -296,19 +243,6 @@ class Tie(Parameterized):
|
||||||
for p in plist:
|
for p in plist:
|
||||||
self._traverse_param(_set_list, (p,[0]), [])
|
self._traverse_param(_set_list, (p,[0]), [])
|
||||||
|
|
||||||
def _get_vals(self, p):
|
|
||||||
vals = []
|
|
||||||
self._traverse_param(lambda x: x.flat, (p,), vals)
|
|
||||||
return np.hstack(vals)
|
|
||||||
|
|
||||||
def _sync_val_pair(self,p1,p2):
|
|
||||||
p1val = self._get_vals(p1)
|
|
||||||
def _set_val(p, offset, p2):
|
|
||||||
p.flat[:] = p2[offset[0]:offset[0]+p.size]
|
|
||||||
offset[0] = offset[0]+ p.size
|
|
||||||
self._traverse_param(_set_val, (p2, [0], p1val), [])
|
|
||||||
return p1val
|
|
||||||
|
|
||||||
def _replace_labels(self, p, label_pairs):
|
def _replace_labels(self, p, label_pairs):
|
||||||
def _replace_l(p):
|
def _replace_l(p):
|
||||||
for l1,l2 in label_pairs:
|
for l1,l2 in label_pairs:
|
||||||
|
|
@ -417,45 +351,20 @@ class Tie(Parameterized):
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.update_model(True)
|
self.update_model(True)
|
||||||
|
|
||||||
# def tie_together(self,plist):
|
|
||||||
# """tie a list of parameters"""
|
|
||||||
# self.update_model(False)
|
|
||||||
# labels = self._get_labels(plist)
|
|
||||||
# val = self._sync_val_group(plist)
|
|
||||||
# if labels[0]==0 and labels.size==1:
|
|
||||||
# # None of parameters in plist has been tied before.
|
|
||||||
# tie_labels,_ = self._expand_tie_param(1)
|
|
||||||
# self._set_labels(plist, tie_labels)
|
|
||||||
# tie_con = self._sync_constraint_group(plist)
|
|
||||||
# if tie_con is not None:
|
|
||||||
# self.tied_param[self.tied_param.tie==tie_labels[0]].constrain(tie_con)
|
|
||||||
# else:
|
|
||||||
# # Some of parameters has been tied already.
|
|
||||||
# # Merge the tie param
|
|
||||||
# tie_labels = labels[labels>0]
|
|
||||||
# if tie_labels.size>1:
|
|
||||||
# self._merge_tie_labels(tie_labels)
|
|
||||||
# self._set_labels(plist, [tie_labels[0]])
|
|
||||||
# tie_p = self.tied_param[self.tied_param.tie==tie_labels[0]]
|
|
||||||
# tie_con = tie_p.constraints.properties()[0] if tie_p.constraints.size>0 else None
|
|
||||||
# self._sync_constraint_group(plist, True, tie_con)
|
|
||||||
# self._update_label_buf()
|
|
||||||
# self.tied_param[self.tied_param.tie==tie_labels[0]] = val
|
|
||||||
# self.update_model(True)
|
|
||||||
|
|
||||||
def tie_vector(self, p1, p2):
|
def tie_vector(self, p1, p2):
|
||||||
"""tie a pair of vectors"""
|
"""tie a pair of vectors"""
|
||||||
self.update_model(False)
|
self.update_model(False)
|
||||||
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
||||||
p1vals = self._sync_val_pair(p1,p2)
|
p_split = self._keepParamList([p1,p2])
|
||||||
if len(expandlist)>0:
|
if len(expandlist)>0:
|
||||||
tie_labels,idxlist = self._expand_tie_param(len(expandlist))
|
tie_labels,idxlist = self._expand_tie_param(len(expandlist))
|
||||||
labellist[expandlist] = tie_labels
|
labellist[expandlist] = tie_labels
|
||||||
self.tied_param[idxlist] = p1vals[expandlist]
|
|
||||||
if len(removelist[0])>0:
|
if len(removelist[0])>0:
|
||||||
self._merge_tie_labelpair(removelist)
|
self._merge_tie_labelpair(removelist)
|
||||||
|
p1,p2 = self._updateParamList(p_split)
|
||||||
self._set_labels([p1,p2], labellist)
|
self._set_labels([p1,p2], labellist)
|
||||||
self._sync_constraint_vector(p1,p2,expandlist,idxlist)
|
self._sync_val([p1,p2],toTiedParam=True)
|
||||||
|
self._sync_constraints([p1,p2], toTiedParam=True)
|
||||||
self._update_label_buf()
|
self._update_label_buf()
|
||||||
self.update_model(True)
|
self.update_model(True)
|
||||||
|
|
||||||
|
|
@ -516,10 +425,14 @@ class Tie(Parameterized):
|
||||||
self._highest_parent_.param_array[self.label_buf==self.label_buf[self.buf_idx[i]]] = self.tied_param[i]
|
self._highest_parent_.param_array[self.label_buf==self.label_buf[self.buf_idx[i]]] = self.tied_param[i]
|
||||||
self._PROPAGATE_VAL_ = True
|
self._PROPAGATE_VAL_ = True
|
||||||
|
|
||||||
def checkValueConsistence(self):
|
#=========================================
|
||||||
|
# Functions for checking consistency
|
||||||
|
#=========================================
|
||||||
|
|
||||||
|
def checkValueConsistency(self):
|
||||||
return not self._check_change()
|
return not self._check_change()
|
||||||
|
|
||||||
def checkConstraintConsistence(self):
|
def checkConstraintConsistency(self):
|
||||||
if self.tied_param is not None:
|
if self.tied_param is not None:
|
||||||
tlist = self.getTiedParamList()
|
tlist = self.getTiedParamList()
|
||||||
for l in tlist:
|
for l in tlist:
|
||||||
|
|
|
||||||
|
|
@ -439,7 +439,7 @@ def sparse_GP_regression_2D(num_samples=400, num_inducing=50, max_iters=100, opt
|
||||||
if plot:
|
if plot:
|
||||||
m.plot()
|
m.plot()
|
||||||
|
|
||||||
print m
|
#print m
|
||||||
return m
|
return m
|
||||||
|
|
||||||
def uncertain_inputs_sparse_regression(max_iters=200, optimize=True, plot=True):
|
def uncertain_inputs_sparse_regression(max_iters=200, optimize=True, plot=True):
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
|
||||||
import GPy
|
import GPy
|
||||||
|
|
||||||
class TieTests(unittest.TestCase):
|
class TieTests(unittest.TestCase):
|
||||||
|
|
@ -10,8 +9,8 @@ class TieTests(unittest.TestCase):
|
||||||
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
|
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
|
||||||
m.Z.constrain_positive(warning=False)
|
m.Z.constrain_positive(warning=False)
|
||||||
m.Z.tie_together()
|
m.Z.tie_together()
|
||||||
self.assertTrue(m.ties.checkValueConsistence())
|
self.assertTrue(m.ties.checkValueConsistency())
|
||||||
self.assertTrue(m.ties.checkValueConsistence())
|
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||||
self.assertTrue(m.ties.checkTieTogether([m.Z]))
|
self.assertTrue(m.ties.checkTieTogether([m.Z]))
|
||||||
self.assertTrue(m.checkgrad())
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
|
|
@ -22,8 +21,8 @@ class TieTests(unittest.TestCase):
|
||||||
m.Z[2:4].tie_together()
|
m.Z[2:4].tie_together()
|
||||||
self.assertTrue(m.ties.checkTieTogether([m.Z[:2]]))
|
self.assertTrue(m.ties.checkTieTogether([m.Z[:2]]))
|
||||||
self.assertTrue(m.ties.checkTieTogether([m.Z[2:4]]))
|
self.assertTrue(m.ties.checkTieTogether([m.Z[2:4]]))
|
||||||
self.assertTrue(m.ties.checkValueConsistence())
|
self.assertTrue(m.ties.checkValueConsistency())
|
||||||
self.assertTrue(m.ties.checkValueConsistence())
|
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||||
self.assertTrue(m.checkgrad())
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
def test_tie_together_merge(self):
|
def test_tie_together_merge(self):
|
||||||
|
|
@ -32,8 +31,28 @@ class TieTests(unittest.TestCase):
|
||||||
m.Z[:2].tie_together()
|
m.Z[:2].tie_together()
|
||||||
m.Z[1:3].tie_together()
|
m.Z[1:3].tie_together()
|
||||||
self.assertTrue(m.ties.checkTieTogether([m.Z[:3]]))
|
self.assertTrue(m.ties.checkTieTogether([m.Z[:3]]))
|
||||||
self.assertTrue(m.ties.checkValueConsistence())
|
self.assertTrue(m.ties.checkValueConsistency())
|
||||||
self.assertTrue(m.ties.checkValueConsistence())
|
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||||
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
|
def test_tie_vector(self):
|
||||||
|
m = GPy.examples.regression.sparse_GP_regression_1D(optimize=False, plot=False, checkgrad=False)
|
||||||
|
m.Z.constrain_positive(warning=False)
|
||||||
|
m.Z[:2].tie_vector(m.Z[2:4])
|
||||||
|
self.assertTrue(m.ties.checkValueConsistency())
|
||||||
|
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||||
|
self.assertTrue(m.ties.checkTieVector([m.Z[:2],m.Z[2:4]]))
|
||||||
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
|
def test_tie_vector_merge(self):
|
||||||
|
m = GPy.examples.regression.sparse_GP_regression_2D(optimize=False, plot=False)
|
||||||
|
m.Z.constrain_positive(warning=False)
|
||||||
|
m.Z[:10].tie_vector(m.Z[10:20])
|
||||||
|
m.Z[20:30].tie_vector(m.Z[30:40])
|
||||||
|
m.Z[10:20].tie_vector(m.Z[20:30])
|
||||||
|
self.assertTrue(m.ties.checkValueConsistency())
|
||||||
|
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||||
|
self.assertTrue(m.ties.checkTieVector([m.Z[:10],m.Z[10:20],m.Z[20:30],m.Z[30:40]]))
|
||||||
self.assertTrue(m.checkgrad())
|
self.assertTrue(m.checkgrad())
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue