mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
minor changes on SSGPLVM
This commit is contained in:
parent
4fbec91eda
commit
6b8ac70210
1 changed files with 5 additions and 5 deletions
|
|
@ -24,7 +24,7 @@ class SSGPLVM(SparseGP):
|
|||
|
||||
"""
|
||||
def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
|
||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs):
|
||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True, **kwargs):
|
||||
|
||||
self.mpi_comm = mpi_comm
|
||||
self.__IN_OPTIMIZATION__ = False
|
||||
|
|
@ -50,9 +50,6 @@ class SSGPLVM(SparseGP):
|
|||
Z = np.random.permutation(X.copy())[:num_inducing]
|
||||
assert Z.shape[1] == X.shape[1]
|
||||
|
||||
pi = np.empty((input_dim))
|
||||
pi[:] = 0.5
|
||||
|
||||
if likelihood is None:
|
||||
likelihood = Gaussian()
|
||||
|
||||
|
|
@ -64,7 +61,10 @@ class SSGPLVM(SparseGP):
|
|||
if inference_method is None:
|
||||
inference_method = VarDTC_minibatch(mpi_comm=mpi_comm)
|
||||
|
||||
self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=True) # the prior probability of the latent binary variable b
|
||||
if pi is None:
|
||||
pi = np.empty((input_dim))
|
||||
pi[:] = 0.5
|
||||
self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi) # the prior probability of the latent binary variable b
|
||||
|
||||
X = SpikeAndSlabPosterior(X, X_variance, gamma)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue