mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
minor changes on SSGPLVM
This commit is contained in:
parent
4fbec91eda
commit
6b8ac70210
1 changed files with 5 additions and 5 deletions
|
|
@ -24,7 +24,7 @@ class SSGPLVM(SparseGP):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
|
def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
|
||||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs):
|
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True, **kwargs):
|
||||||
|
|
||||||
self.mpi_comm = mpi_comm
|
self.mpi_comm = mpi_comm
|
||||||
self.__IN_OPTIMIZATION__ = False
|
self.__IN_OPTIMIZATION__ = False
|
||||||
|
|
@ -50,9 +50,6 @@ class SSGPLVM(SparseGP):
|
||||||
Z = np.random.permutation(X.copy())[:num_inducing]
|
Z = np.random.permutation(X.copy())[:num_inducing]
|
||||||
assert Z.shape[1] == X.shape[1]
|
assert Z.shape[1] == X.shape[1]
|
||||||
|
|
||||||
pi = np.empty((input_dim))
|
|
||||||
pi[:] = 0.5
|
|
||||||
|
|
||||||
if likelihood is None:
|
if likelihood is None:
|
||||||
likelihood = Gaussian()
|
likelihood = Gaussian()
|
||||||
|
|
||||||
|
|
@ -64,7 +61,10 @@ class SSGPLVM(SparseGP):
|
||||||
if inference_method is None:
|
if inference_method is None:
|
||||||
inference_method = VarDTC_minibatch(mpi_comm=mpi_comm)
|
inference_method = VarDTC_minibatch(mpi_comm=mpi_comm)
|
||||||
|
|
||||||
self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=True) # the prior probability of the latent binary variable b
|
if pi is None:
|
||||||
|
pi = np.empty((input_dim))
|
||||||
|
pi[:] = 0.5
|
||||||
|
self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi) # the prior probability of the latent binary variable b
|
||||||
|
|
||||||
X = SpikeAndSlabPosterior(X, X_variance, gamma)
|
X = SpikeAndSlabPosterior(X, X_variance, gamma)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue