mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 09:42:39 +02:00
[ssgplvm] linear kernel
This commit is contained in:
parent
b65da11df5
commit
dad476faf6
7 changed files with 162 additions and 104 deletions
|
|
@ -44,8 +44,10 @@ class SSGPLVM(SparseGP):
|
|||
X_variance = np.random.uniform(0,.1,X.shape)
|
||||
|
||||
gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
|
||||
#gamma[:] = 0.5 + 0.01 * np.random.randn(X.shape[0], input_dim)
|
||||
gamma[:] = 0.5
|
||||
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
|
||||
gamma[gamma>=1. - 1e-9] = 1e-9
|
||||
gamma[gamma<1e-9] = 1e-9
|
||||
#gamma[:] = 0.5
|
||||
|
||||
if group_spike:
|
||||
gamma[:] = gamma.mean(axis=0)
|
||||
|
|
@ -57,19 +59,20 @@ class SSGPLVM(SparseGP):
|
|||
pi = np.empty((input_dim))
|
||||
pi[:] = 0.5
|
||||
|
||||
if mpi_comm != None:
|
||||
mpi_comm.Bcast(X, root=0)
|
||||
mpi_comm.Bcast(fracs, root=0)
|
||||
mpi_comm.Bcast(X_variance, root=0)
|
||||
mpi_comm.Bcast(gamma, root=0)
|
||||
mpi_comm.Bcast(Z, root=0)
|
||||
mpi_comm.Bcast(pi, root=0)
|
||||
# if mpi_comm != None:
|
||||
# mpi_comm.Bcast(X, root=0)
|
||||
# mpi_comm.Bcast(fracs, root=0)
|
||||
# mpi_comm.Bcast(X_variance, root=0)
|
||||
# mpi_comm.Bcast(gamma, root=0)
|
||||
# mpi_comm.Bcast(Z, root=0)
|
||||
# mpi_comm.Bcast(pi, root=0)
|
||||
|
||||
if likelihood is None:
|
||||
likelihood = Gaussian()
|
||||
|
||||
if kernel is None:
|
||||
kernel = kern.RBF(input_dim, lengthscale=fracs, ARD=True) # + kern.white(input_dim)
|
||||
kernel.set_for_SpikeAndSlab()
|
||||
|
||||
self.variational_prior = SpikeAndSlabPrior(pi=pi) # the prior probability of the latent binary variable b
|
||||
|
||||
|
|
@ -90,6 +93,7 @@ class SSGPLVM(SparseGP):
|
|||
self.X_local = self.X[Y_start:Y_end]
|
||||
self.Y_range = (Y_start, Y_end)
|
||||
self.Y_list = np.array(Y_list)
|
||||
[mpi_comm.Bcast(p, root=0) for p in self.flattened_parameters]
|
||||
|
||||
def set_X_gradients(self, X, X_grad):
|
||||
"""Set the gradients of the posterior distribution of X in its specific form."""
|
||||
|
|
@ -125,3 +129,16 @@ class SSGPLVM(SparseGP):
|
|||
|
||||
return dim_reduction_plots.plot_latent(self, plot_inducing=plot_inducing, *args, **kwargs)
|
||||
|
||||
def __getstate__(self):
|
||||
dc = super(SSGPLVM, self).__getstate__()
|
||||
del dc['mpi_comm']
|
||||
del dc['Y_local']
|
||||
del dc['X_local']
|
||||
return dc
|
||||
|
||||
def __setstate__(self, state):
|
||||
state['mpi_comm'] = None
|
||||
Y_range = state['Y_range']
|
||||
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
|
||||
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
|
||||
return super(SSGPLVM, self).__setstate__(state)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue