diff --git a/GPy/core/parameterization/variational.py b/GPy/core/parameterization/variational.py index 09191c0b..377c1d4c 100644 --- a/GPy/core/parameterization/variational.py +++ b/GPy/core/parameterization/variational.py @@ -36,8 +36,9 @@ class NormalPrior(VariationalPrior): variational_posterior.variance.gradient -= (1. - (1. / (variational_posterior.variance))) * 0.5 class SpikeAndSlabPrior(VariationalPrior): - def __init__(self, pi=None, learnPi=False, variance = 1.0, name='SpikeAndSlabPrior', **kw): - super(SpikeAndSlabPrior, self).__init__(name=name, **kw) + def __init__(self, pi=None, learnPi=False, variance = 1.0, group_spike=False, name='SpikeAndSlabPrior', **kw): + super(SpikeAndSlabPrior, self).__init__(name=name, **kw) + self.group_spike = group_spike self.variance = Param('variance',variance) self.learnPi = learnPi if learnPi: @@ -50,7 +51,10 @@ class SpikeAndSlabPrior(VariationalPrior): def KL_divergence(self, variational_posterior): mu = variational_posterior.mean S = variational_posterior.variance - gamma = variational_posterior.gamma.values + if self.group_spike: + gamma = variational_posterior.gamma.values[0] + else: + gamma = variational_posterior.gamma.values if len(self.pi.shape)==2: idx = np.unique(variational_posterior.gamma._raveled_index()/gamma.shape[-1]) pi = self.pi[idx] @@ -65,14 +69,21 @@ class SpikeAndSlabPrior(VariationalPrior): def update_gradients_KL(self, variational_posterior): mu = variational_posterior.mean S = variational_posterior.variance - gamma = variational_posterior.gamma.values + if self.group_spike: + gamma = variational_posterior.gamma.values[0] + else: + gamma = variational_posterior.gamma.values if len(self.pi.shape)==2: idx = np.unique(variational_posterior.gamma._raveled_index()/gamma.shape[-1]) pi = self.pi[idx] else: pi = self.pi - variational_posterior.binary_prob.gradient -= np.log((1-pi)/pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2. + if self.group_spike: + dgamma = np.log((1-pi)/pi*gamma/(1.-gamma))/mu.shape[0] + else: + dgamma = np.log((1-pi)/pi*gamma/(1.-gamma)) + variational_posterior.binary_prob.gradient -= dgamma+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2. mu.gradient -= gamma*mu/self.variance S.gradient -= (1./self.variance - 1./S) * gamma /2. if self.learnPi: @@ -154,13 +165,27 @@ class SpikeAndSlabPosterior(VariationalPosterior): ''' The SpikeAndSlab distribution for variational approximations. ''' - def __init__(self, means, variances, binary_prob, name='latent space'): + def __init__(self, means, variances, binary_prob, group_spike=False, name='latent space'): """ binary_prob : the probability of the distribution on the slab part. """ super(SpikeAndSlabPosterior, self).__init__(means, variances, name) - self.gamma = Param("binary_prob",binary_prob,Logistic(0.,1.)) - self.link_parameter(self.gamma) + self.group_spike = group_spike + if group_spike: + self.gamma_group = Param("binary_prob_group",binary_prob.mean(axis=0),Logistic(0.,1.)) + self.gamma = Param("binary_prob",binary_prob, __fixed__) + self.link_parameters(self.gamma_group,self.gamma) + else: + self.gamma = Param("binary_prob",binary_prob,Logistic(0.,1.)) + self.link_parameter(self.gamma) + + def propogate_val(self): + if self.group_spike: + self.gamma.param_array.values.reshape(self.gamma.shape)[:] = self.gamma_group.values + + def collate_gradient(self): + if self.group_spike: + self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0) def set_gradients(self, grad): self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py index 44eea74c..c2b7c0d0 100644 --- a/GPy/models/ss_gplvm.py +++ b/GPy/models/ss_gplvm.py @@ -103,7 +103,7 @@ class SSGPLVM(SparseGP_MPI): """ def __init__(self, Y, input_dim, X=None, X_variance=None, Gamma=None, init='PCA', num_inducing=10, - Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True,normalizer=False, **kwargs): + Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=False,normalizer=False, **kwargs): self.group_spike = group_spike @@ -144,15 +144,12 @@ class SSGPLVM(SparseGP_MPI): if pi is None: pi = np.empty((input_dim)) pi[:] = 0.5 - self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi) # the prior probability of the latent binary variable b + self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi, group_spike=group_spike) # the prior probability of the latent binary variable b - X = SpikeAndSlabPosterior(X, X_variance, gamma) + X = SpikeAndSlabPosterior(X, X_variance, gamma, group_spike=group_spike) super(SSGPLVM,self).__init__(X, Y, Z, kernel, likelihood, variational_prior=self.variational_prior, inference_method=inference_method, name=name, mpi_comm=mpi_comm, normalizer=normalizer, **kwargs) self.link_parameter(self.X, index=0) - - if self.group_spike: - [self.X.gamma[:,i].tie('tieGamma'+str(i)) for i in range(self.X.gamma.shape[1])] # Tie columns together def set_X_gradients(self, X, X_grad): """Set the gradients of the posterior distribution of X in its specific form.""" @@ -163,8 +160,10 @@ class SSGPLVM(SparseGP_MPI): return X.mean.gradient, X.variance.gradient, X.binary_prob.gradient def parameters_changed(self): + self.X.propogate_val() super(SSGPLVM,self).parameters_changed() if isinstance(self.inference_method, VarDTC_minibatch): + self.X.collate_gradient() return self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X) @@ -173,6 +172,7 @@ class SSGPLVM(SparseGP_MPI): # update for the KL divergence self.variational_prior.update_gradients_KL(self.X) + self.X.collate_gradient() def input_sensitivity(self): if self.kern.ARD: