[ssgplvm] linear kernel

This commit is contained in:
Zhenwen Dai 2014-05-15 11:43:29 +01:00
parent b65da11df5
commit dad476faf6
7 changed files with 162 additions and 104 deletions

View file

@ -44,8 +44,10 @@ class SSGPLVM(SparseGP):
X_variance = np.random.uniform(0,.1,X.shape)
gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
#gamma[:] = 0.5 + 0.01 * np.random.randn(X.shape[0], input_dim)
gamma[:] = 0.5
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
gamma[gamma>=1. - 1e-9] = 1e-9
gamma[gamma<1e-9] = 1e-9
#gamma[:] = 0.5
if group_spike:
gamma[:] = gamma.mean(axis=0)
@ -57,19 +59,20 @@ class SSGPLVM(SparseGP):
pi = np.empty((input_dim))
pi[:] = 0.5
if mpi_comm != None:
mpi_comm.Bcast(X, root=0)
mpi_comm.Bcast(fracs, root=0)
mpi_comm.Bcast(X_variance, root=0)
mpi_comm.Bcast(gamma, root=0)
mpi_comm.Bcast(Z, root=0)
mpi_comm.Bcast(pi, root=0)
# if mpi_comm != None:
# mpi_comm.Bcast(X, root=0)
# mpi_comm.Bcast(fracs, root=0)
# mpi_comm.Bcast(X_variance, root=0)
# mpi_comm.Bcast(gamma, root=0)
# mpi_comm.Bcast(Z, root=0)
# mpi_comm.Bcast(pi, root=0)
if likelihood is None:
likelihood = Gaussian()
if kernel is None:
kernel = kern.RBF(input_dim, lengthscale=fracs, ARD=True) # + kern.white(input_dim)
kernel.set_for_SpikeAndSlab()
self.variational_prior = SpikeAndSlabPrior(pi=pi) # the prior probability of the latent binary variable b
@ -90,6 +93,7 @@ class SSGPLVM(SparseGP):
self.X_local = self.X[Y_start:Y_end]
self.Y_range = (Y_start, Y_end)
self.Y_list = np.array(Y_list)
[mpi_comm.Bcast(p, root=0) for p in self.flattened_parameters]
def set_X_gradients(self, X, X_grad):
"""Set the gradients of the posterior distribution of X in its specific form."""
@ -125,3 +129,16 @@ class SSGPLVM(SparseGP):
return dim_reduction_plots.plot_latent(self, plot_inducing=plot_inducing, *args, **kwargs)
def __getstate__(self):
dc = super(SSGPLVM, self).__getstate__()
del dc['mpi_comm']
del dc['Y_local']
del dc['X_local']
return dc
def __setstate__(self, state):
state['mpi_comm'] = None
Y_range = state['Y_range']
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
return super(SSGPLVM, self).__setstate__(state)