diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py index d1b07415..027f9783 100644 --- a/GPy/models/ss_gplvm.py +++ b/GPy/models/ss_gplvm.py @@ -24,7 +24,7 @@ class SSGPLVM(SparseGP): """ def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10, - Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs): + Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True, **kwargs): self.mpi_comm = mpi_comm self.__IN_OPTIMIZATION__ = False @@ -49,9 +49,6 @@ class SSGPLVM(SparseGP): if Z is None: Z = np.random.permutation(X.copy())[:num_inducing] assert Z.shape[1] == X.shape[1] - - pi = np.empty((input_dim)) - pi[:] = 0.5 if likelihood is None: likelihood = Gaussian() @@ -64,7 +61,10 @@ class SSGPLVM(SparseGP): if inference_method is None: inference_method = VarDTC_minibatch(mpi_comm=mpi_comm) - self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=True) # the prior probability of the latent binary variable b + if pi is None: + pi = np.empty((input_dim)) + pi[:] = 0.5 + self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi) # the prior probability of the latent binary variable b X = SpikeAndSlabPosterior(X, X_variance, gamma)