minor changes on SSGPLVM

This commit is contained in:
Zhenwen Dai 2014-08-18 15:17:52 +01:00
parent 4fbec91eda
commit 6b8ac70210

View file

@ -24,7 +24,7 @@ class SSGPLVM(SparseGP):
""" """
def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10, def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs): Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True, **kwargs):
self.mpi_comm = mpi_comm self.mpi_comm = mpi_comm
self.__IN_OPTIMIZATION__ = False self.__IN_OPTIMIZATION__ = False
@ -49,9 +49,6 @@ class SSGPLVM(SparseGP):
if Z is None: if Z is None:
Z = np.random.permutation(X.copy())[:num_inducing] Z = np.random.permutation(X.copy())[:num_inducing]
assert Z.shape[1] == X.shape[1] assert Z.shape[1] == X.shape[1]
pi = np.empty((input_dim))
pi[:] = 0.5
if likelihood is None: if likelihood is None:
likelihood = Gaussian() likelihood = Gaussian()
@ -64,7 +61,10 @@ class SSGPLVM(SparseGP):
if inference_method is None: if inference_method is None:
inference_method = VarDTC_minibatch(mpi_comm=mpi_comm) inference_method = VarDTC_minibatch(mpi_comm=mpi_comm)
self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=True) # the prior probability of the latent binary variable b if pi is None:
pi = np.empty((input_dim))
pi[:] = 0.5
self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi) # the prior probability of the latent binary variable b
X = SpikeAndSlabPosterior(X, X_variance, gamma) X = SpikeAndSlabPosterior(X, X_variance, gamma)