mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
[ssgplvm] group spike
This commit is contained in:
parent
7188e92efb
commit
b8508cc20c
2 changed files with 39 additions and 14 deletions
|
|
@ -103,7 +103,7 @@ class SSGPLVM(SparseGP_MPI):
|
|||
|
||||
"""
|
||||
def __init__(self, Y, input_dim, X=None, X_variance=None, Gamma=None, init='PCA', num_inducing=10,
|
||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True,normalizer=False, **kwargs):
|
||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=False,normalizer=False, **kwargs):
|
||||
|
||||
self.group_spike = group_spike
|
||||
|
||||
|
|
@ -144,15 +144,12 @@ class SSGPLVM(SparseGP_MPI):
|
|||
if pi is None:
|
||||
pi = np.empty((input_dim))
|
||||
pi[:] = 0.5
|
||||
self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi) # the prior probability of the latent binary variable b
|
||||
self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi, group_spike=group_spike) # the prior probability of the latent binary variable b
|
||||
|
||||
X = SpikeAndSlabPosterior(X, X_variance, gamma)
|
||||
X = SpikeAndSlabPosterior(X, X_variance, gamma, group_spike=group_spike)
|
||||
|
||||
super(SSGPLVM,self).__init__(X, Y, Z, kernel, likelihood, variational_prior=self.variational_prior, inference_method=inference_method, name=name, mpi_comm=mpi_comm, normalizer=normalizer, **kwargs)
|
||||
self.link_parameter(self.X, index=0)
|
||||
|
||||
if self.group_spike:
|
||||
[self.X.gamma[:,i].tie('tieGamma'+str(i)) for i in range(self.X.gamma.shape[1])] # Tie columns together
|
||||
|
||||
def set_X_gradients(self, X, X_grad):
|
||||
"""Set the gradients of the posterior distribution of X in its specific form."""
|
||||
|
|
@ -163,8 +160,10 @@ class SSGPLVM(SparseGP_MPI):
|
|||
return X.mean.gradient, X.variance.gradient, X.binary_prob.gradient
|
||||
|
||||
def parameters_changed(self):
|
||||
self.X.propogate_val()
|
||||
super(SSGPLVM,self).parameters_changed()
|
||||
if isinstance(self.inference_method, VarDTC_minibatch):
|
||||
self.X.collate_gradient()
|
||||
return
|
||||
|
||||
self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X)
|
||||
|
|
@ -173,6 +172,7 @@ class SSGPLVM(SparseGP_MPI):
|
|||
|
||||
# update for the KL divergence
|
||||
self.variational_prior.update_gradients_KL(self.X)
|
||||
self.X.collate_gradient()
|
||||
|
||||
def input_sensitivity(self):
|
||||
if self.kern.ARD:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue