fix the bugs about tie found today

This commit is contained in:
Zhenwen Dai 2014-09-22 16:45:54 +01:00
parent ca892e32e1
commit 494f38f788
3 changed files with 77 additions and 10 deletions

View file

@ -240,9 +240,10 @@ class Tie(Parameterized):
for p in plist:
self._traverse_param(_set_l1, (p,), [])
else:
idx = [0]
for p in plist:
self._traverse_param(_set_list, (p,[0]), [])
self._traverse_param(_set_list, (p,idx), [])
def _replace_labels(self, p, label_pairs):
def _replace_l(p):
for l1,l2 in label_pairs:
@ -258,6 +259,7 @@ class Tie(Parameterized):
new_buf = np.empty((num,))
self.tied_param = Param('tied',new_buf)
self.tied_param.tie[:] = labellist
self.link_parameter(self.tied_param)
else:
start_label = self.tied_param.tie.max()+1
new_buf = np.empty((self.tied_param.size+num,))
@ -266,11 +268,13 @@ class Tie(Parameterized):
old_size = self.tied_param.size
labellist = np.array(range(start_label,start_label+num),dtype=np.int)
idxlist = np.array(range(old_size,old_size+num),dtype=np.int)
cons = self.tied_param.constraints.copy()
self.unlink_parameter(self.tied_param)
self.tied_param = Param('tied',new_buf)
self.tied_param.tie[:old_size] = old_tie_
self.tied_param.tie[old_size:] = labellist
self.link_parameter(self.tied_param)
self.link_parameter(self.tied_param)
self.tied_param.constraints.update(cons)
return labellist, idxlist
def _remove_tie_param(self, labels):
@ -283,10 +287,17 @@ class Tie(Parameterized):
idx = np.logical_not(np.in1d(self.tied_param.tie,labels))
new_buf[:] = self.tied_param[idx]
old_tie_ = self.tied_param.tie.copy()
self.unlink_parameter(self.tied_param)
cons = {}
for c,ind in self.tied_param.constraints.iteritems():
buf = np.zeros((old_tie_.size,),dtype=np.uint8)
buf[ind] = 1
if (buf[idx]==1).sum()>0:
cons[c] = np.where(buf[idx]==1)[0]
self.unlink_parameter(self.tied_param)
self.tied_param = Param('tied',new_buf)
self.tied_param.tie[:] = old_tie_[idx]
self.link_parameter(self.tied_param)
[self.tied_param.constraints.add(c,ind) for c,ind in cons.iteritems()]
def _merge_tie_labels(self, labels):
"""Merge all the labels in the list to the first one"""
@ -322,7 +333,10 @@ class Tie(Parameterized):
assert(np.all(self.tied_param.tie>0))
def _keepParamList(self,plist):
return [(p._original_, p._current_slice_) for p in plist]
paramlist = []
for p in plist:
self._traverse_param(lambda p: (p._original_, p._current_slice_), (p,), paramlist)
return paramlist
def _updateParamList(self, p_split):
return [p_org[p_slice] for p_org,p_slice in p_split]
@ -355,16 +369,20 @@ class Tie(Parameterized):
"""tie a pair of vectors"""
self.update_model(False)
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
p_split = self._keepParamList([p1,p2])
p_split1 = self._keepParamList([p1])
p_split2 = self._keepParamList([p2])
if len(expandlist)>0:
tie_labels,idxlist = self._expand_tie_param(len(expandlist))
labellist[expandlist] = tie_labels
if len(removelist[0])>0:
self._merge_tie_labelpair(removelist)
p1,p2 = self._updateParamList(p_split)
self._set_labels([p1,p2], labellist)
self._sync_val([p1,p2],toTiedParam=True)
self._sync_constraints([p1,p2], toTiedParam=True)
p1 = self._updateParamList(p_split1)
p2 = self._updateParamList(p_split2)
ps = p1+p2
self._set_labels(p1, labellist)
self._set_labels(p2, labellist)
self._sync_val(ps,toTiedParam=True)
self._sync_constraints(ps, toTiedParam=True)
self._update_label_buf()
self.update_model(True)

View file

@ -162,6 +162,32 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40,
from ..util.misc import param_to_array
import numpy as np
_np.random.seed(0)
data = GPy.util.datasets.oil()
kernel = GPy.kern.RBF(Q, 1., 1./_np.random.uniform(0,1,(Q,)), ARD=True)# + GPy.kern.Bias(Q, _np.exp(-2))
Y = data['X'][:N]
m = GPy.models.BayesianGPLVM(Y, Q, kernel=kernel, num_inducing=num_inducing, **k)
m.data_labels = data['Y'][:N].argmax(axis=1)
if optimize:
m.optimize('bfgs', messages=verbose, max_iters=max_iters, gtol=.05)
if plot:
fig, (latent_axes, sense_axes) = plt.subplots(1, 2)
m.plot_latent(ax=latent_axes, labels=m.data_labels)
data_show = GPy.plotting.matplot_dep.visualize.vector_show((m.Y[0,:]))
lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean)[0:1,:], # @UnusedVariable
m, data_show, latent_axes=latent_axes, sense_axes=sense_axes, labels=m.data_labels)
raw_input('Press enter to finish')
plt.close(fig)
return m
def bgplvm_oil_100(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40, max_iters=1000, **k):
import GPy
from matplotlib import pyplot as plt
from ..util.misc import param_to_array
_np.random.seed(0)
data = GPy.util.datasets.oil_100()

View file

@ -2,6 +2,7 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import unittest
import numpy as np
import GPy
class TieTests(unittest.TestCase):
@ -54,6 +55,28 @@ class TieTests(unittest.TestCase):
self.assertTrue(m.ties.checkConstraintConsistency())
self.assertTrue(m.ties.checkTieVector([m.Z[:10],m.Z[10:20],m.Z[20:30],m.Z[30:40]]))
self.assertTrue(m.checkgrad())
def test_remove_tie(self):
x = np.random.rand(100,1)
y = np.random.rand(100,1)
m = GPy.models.SparseGPRegression(x,y,kernel=GPy.kern.RBF(1)+GPy.kern.Matern32(1))
m.kern.rbf.lengthscale.tie_together(m.kern.Mat32.lengthscale)
m.Z[:1].tie_together(m.Z[1:2])
m.kern.rbf.variance.tie_together(m.kern.Mat32.variance)
m.kern.rbf.lengthscale.untie()
self.assertTrue(m.ties.checkValueConsistency())
self.assertTrue(m.ties.checkConstraintConsistency())
self.assertTrue(m.ties.checkTieTogether([m.kern.rbf.variance,m.kern.Mat32.variance]))
self.assertTrue(m.ties.checkTieVector([m.Z[:1],m.Z[1:2]]))
self.assertTrue(m.checkgrad())
def test_tie_variational_posterior(self):
m = GPy.examples.dimensionality_reduction.bgplvm_oil_100(plot=False,optimize=False)
m.X[:10].tie_vector(m.X[10:20])
self.assertTrue(m.ties.checkValueConsistency())
self.assertTrue(m.ties.checkConstraintConsistency())
self.assertTrue(m.ties.checkTieVector([m.X[:10],m.X[10:20]]))
self.assertTrue(m.checkgrad())
if __name__ == "__main__":
print "Running unit tests, please be (very) patient..."