mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
fix the bugs about tie found today
This commit is contained in:
parent
ca892e32e1
commit
494f38f788
3 changed files with 77 additions and 10 deletions
|
|
@ -240,9 +240,10 @@ class Tie(Parameterized):
|
|||
for p in plist:
|
||||
self._traverse_param(_set_l1, (p,), [])
|
||||
else:
|
||||
idx = [0]
|
||||
for p in plist:
|
||||
self._traverse_param(_set_list, (p,[0]), [])
|
||||
|
||||
self._traverse_param(_set_list, (p,idx), [])
|
||||
|
||||
def _replace_labels(self, p, label_pairs):
|
||||
def _replace_l(p):
|
||||
for l1,l2 in label_pairs:
|
||||
|
|
@ -258,6 +259,7 @@ class Tie(Parameterized):
|
|||
new_buf = np.empty((num,))
|
||||
self.tied_param = Param('tied',new_buf)
|
||||
self.tied_param.tie[:] = labellist
|
||||
self.link_parameter(self.tied_param)
|
||||
else:
|
||||
start_label = self.tied_param.tie.max()+1
|
||||
new_buf = np.empty((self.tied_param.size+num,))
|
||||
|
|
@ -266,11 +268,13 @@ class Tie(Parameterized):
|
|||
old_size = self.tied_param.size
|
||||
labellist = np.array(range(start_label,start_label+num),dtype=np.int)
|
||||
idxlist = np.array(range(old_size,old_size+num),dtype=np.int)
|
||||
cons = self.tied_param.constraints.copy()
|
||||
self.unlink_parameter(self.tied_param)
|
||||
self.tied_param = Param('tied',new_buf)
|
||||
self.tied_param.tie[:old_size] = old_tie_
|
||||
self.tied_param.tie[old_size:] = labellist
|
||||
self.link_parameter(self.tied_param)
|
||||
self.link_parameter(self.tied_param)
|
||||
self.tied_param.constraints.update(cons)
|
||||
return labellist, idxlist
|
||||
|
||||
def _remove_tie_param(self, labels):
|
||||
|
|
@ -283,10 +287,17 @@ class Tie(Parameterized):
|
|||
idx = np.logical_not(np.in1d(self.tied_param.tie,labels))
|
||||
new_buf[:] = self.tied_param[idx]
|
||||
old_tie_ = self.tied_param.tie.copy()
|
||||
self.unlink_parameter(self.tied_param)
|
||||
cons = {}
|
||||
for c,ind in self.tied_param.constraints.iteritems():
|
||||
buf = np.zeros((old_tie_.size,),dtype=np.uint8)
|
||||
buf[ind] = 1
|
||||
if (buf[idx]==1).sum()>0:
|
||||
cons[c] = np.where(buf[idx]==1)[0]
|
||||
self.unlink_parameter(self.tied_param)
|
||||
self.tied_param = Param('tied',new_buf)
|
||||
self.tied_param.tie[:] = old_tie_[idx]
|
||||
self.link_parameter(self.tied_param)
|
||||
[self.tied_param.constraints.add(c,ind) for c,ind in cons.iteritems()]
|
||||
|
||||
def _merge_tie_labels(self, labels):
|
||||
"""Merge all the labels in the list to the first one"""
|
||||
|
|
@ -322,7 +333,10 @@ class Tie(Parameterized):
|
|||
assert(np.all(self.tied_param.tie>0))
|
||||
|
||||
def _keepParamList(self,plist):
|
||||
return [(p._original_, p._current_slice_) for p in plist]
|
||||
paramlist = []
|
||||
for p in plist:
|
||||
self._traverse_param(lambda p: (p._original_, p._current_slice_), (p,), paramlist)
|
||||
return paramlist
|
||||
|
||||
def _updateParamList(self, p_split):
|
||||
return [p_org[p_slice] for p_org,p_slice in p_split]
|
||||
|
|
@ -355,16 +369,20 @@ class Tie(Parameterized):
|
|||
"""tie a pair of vectors"""
|
||||
self.update_model(False)
|
||||
expandlist,removelist,labellist = self._get_labels_vector(p1, p2)
|
||||
p_split = self._keepParamList([p1,p2])
|
||||
p_split1 = self._keepParamList([p1])
|
||||
p_split2 = self._keepParamList([p2])
|
||||
if len(expandlist)>0:
|
||||
tie_labels,idxlist = self._expand_tie_param(len(expandlist))
|
||||
labellist[expandlist] = tie_labels
|
||||
if len(removelist[0])>0:
|
||||
self._merge_tie_labelpair(removelist)
|
||||
p1,p2 = self._updateParamList(p_split)
|
||||
self._set_labels([p1,p2], labellist)
|
||||
self._sync_val([p1,p2],toTiedParam=True)
|
||||
self._sync_constraints([p1,p2], toTiedParam=True)
|
||||
p1 = self._updateParamList(p_split1)
|
||||
p2 = self._updateParamList(p_split2)
|
||||
ps = p1+p2
|
||||
self._set_labels(p1, labellist)
|
||||
self._set_labels(p2, labellist)
|
||||
self._sync_val(ps,toTiedParam=True)
|
||||
self._sync_constraints(ps, toTiedParam=True)
|
||||
self._update_label_buf()
|
||||
self.update_model(True)
|
||||
|
||||
|
|
|
|||
|
|
@ -162,6 +162,32 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40,
|
|||
from ..util.misc import param_to_array
|
||||
import numpy as np
|
||||
|
||||
_np.random.seed(0)
|
||||
data = GPy.util.datasets.oil()
|
||||
|
||||
kernel = GPy.kern.RBF(Q, 1., 1./_np.random.uniform(0,1,(Q,)), ARD=True)# + GPy.kern.Bias(Q, _np.exp(-2))
|
||||
Y = data['X'][:N]
|
||||
m = GPy.models.BayesianGPLVM(Y, Q, kernel=kernel, num_inducing=num_inducing, **k)
|
||||
m.data_labels = data['Y'][:N].argmax(axis=1)
|
||||
|
||||
if optimize:
|
||||
m.optimize('bfgs', messages=verbose, max_iters=max_iters, gtol=.05)
|
||||
|
||||
if plot:
|
||||
fig, (latent_axes, sense_axes) = plt.subplots(1, 2)
|
||||
m.plot_latent(ax=latent_axes, labels=m.data_labels)
|
||||
data_show = GPy.plotting.matplot_dep.visualize.vector_show((m.Y[0,:]))
|
||||
lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean)[0:1,:], # @UnusedVariable
|
||||
m, data_show, latent_axes=latent_axes, sense_axes=sense_axes, labels=m.data_labels)
|
||||
raw_input('Press enter to finish')
|
||||
plt.close(fig)
|
||||
return m
|
||||
|
||||
def bgplvm_oil_100(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40, max_iters=1000, **k):
|
||||
import GPy
|
||||
from matplotlib import pyplot as plt
|
||||
from ..util.misc import param_to_array
|
||||
|
||||
_np.random.seed(0)
|
||||
data = GPy.util.datasets.oil_100()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import GPy
|
||||
|
||||
class TieTests(unittest.TestCase):
|
||||
|
|
@ -54,6 +55,28 @@ class TieTests(unittest.TestCase):
|
|||
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||
self.assertTrue(m.ties.checkTieVector([m.Z[:10],m.Z[10:20],m.Z[20:30],m.Z[30:40]]))
|
||||
self.assertTrue(m.checkgrad())
|
||||
|
||||
def test_remove_tie(self):
|
||||
x = np.random.rand(100,1)
|
||||
y = np.random.rand(100,1)
|
||||
m = GPy.models.SparseGPRegression(x,y,kernel=GPy.kern.RBF(1)+GPy.kern.Matern32(1))
|
||||
m.kern.rbf.lengthscale.tie_together(m.kern.Mat32.lengthscale)
|
||||
m.Z[:1].tie_together(m.Z[1:2])
|
||||
m.kern.rbf.variance.tie_together(m.kern.Mat32.variance)
|
||||
m.kern.rbf.lengthscale.untie()
|
||||
self.assertTrue(m.ties.checkValueConsistency())
|
||||
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||
self.assertTrue(m.ties.checkTieTogether([m.kern.rbf.variance,m.kern.Mat32.variance]))
|
||||
self.assertTrue(m.ties.checkTieVector([m.Z[:1],m.Z[1:2]]))
|
||||
self.assertTrue(m.checkgrad())
|
||||
|
||||
def test_tie_variational_posterior(self):
|
||||
m = GPy.examples.dimensionality_reduction.bgplvm_oil_100(plot=False,optimize=False)
|
||||
m.X[:10].tie_vector(m.X[10:20])
|
||||
self.assertTrue(m.ties.checkValueConsistency())
|
||||
self.assertTrue(m.ties.checkConstraintConsistency())
|
||||
self.assertTrue(m.ties.checkTieVector([m.X[:10],m.X[10:20]]))
|
||||
self.assertTrue(m.checkgrad())
|
||||
|
||||
if __name__ == "__main__":
|
||||
print "Running unit tests, please be (very) patient..."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue