mirror of https://github.com/SheffieldML/GPy.git
synced 2026-05-12 05:22:38 +02:00

[ssgplvm] group spike

This commit is contained in:
parent 7188e92efb
commit b8508cc20c

2 changed files with 39 additions and 14 deletions
GPy/core/parameterization/variational.py

@@ -36,8 +36,9 @@ class NormalPrior(VariationalPrior):
         variational_posterior.variance.gradient -= (1. - (1. / (variational_posterior.variance))) * 0.5

 class SpikeAndSlabPrior(VariationalPrior):
-    def __init__(self, pi=None, learnPi=False, variance = 1.0, name='SpikeAndSlabPrior', **kw):
+    def __init__(self, pi=None, learnPi=False, variance = 1.0, group_spike=False, name='SpikeAndSlabPrior', **kw):
         super(SpikeAndSlabPrior, self).__init__(name=name, **kw)
+        self.group_spike = group_spike
         self.variance = Param('variance',variance)
         self.learnPi = learnPi
         if learnPi:
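For orientation, a minimal sketch of constructing the prior with the new flag (import path as in the repository; shapes are illustrative, not from this commit):

    import numpy as np
    from GPy.core.parameterization.variational import SpikeAndSlabPrior

    # One prior inclusion probability per latent dimension (Q = 5 here).
    pi = np.full(5, 0.5)

    # group_spike=True records that the spike/slab indicators are shared
    # across data points rather than being free per point.
    prior = SpikeAndSlabPrior(pi=pi, learnPi=False, group_spike=True)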
@@ -50,7 +51,10 @@ class SpikeAndSlabPrior(VariationalPrior):
     def KL_divergence(self, variational_posterior):
         mu = variational_posterior.mean
         S = variational_posterior.variance
-        gamma = variational_posterior.gamma.values
+        if self.group_spike:
+            gamma = variational_posterior.gamma.values[0]
+        else:
+            gamma = variational_posterior.gamma.values
         if len(self.pi.shape)==2:
             idx = np.unique(variational_posterior.gamma._raveled_index()/gamma.shape[-1])
             pi = self.pi[idx]
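The reason for taking `values[0]`: with group_spike the rows of gamma are kept identical (one indicator per latent dimension, replicated over the N data points), so the indicator part of the KL is evaluated once on a representative row instead of N times. A toy illustration (not GPy code):

    import numpy as np

    gamma_full = np.tile([0.9, 0.2, 0.6], (4, 1))  # tied gamma, shape (N=4, Q=3)
    gamma = gamma_full[0]                          # representative row, shape (Q,)
    assert np.allclose(gamma_full, gamma)          # every row equals the shared row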
@@ -65,14 +69,21 @@ class SpikeAndSlabPrior(VariationalPrior):
     def update_gradients_KL(self, variational_posterior):
         mu = variational_posterior.mean
         S = variational_posterior.variance
-        gamma = variational_posterior.gamma.values
+        if self.group_spike:
+            gamma = variational_posterior.gamma.values[0]
+        else:
+            gamma = variational_posterior.gamma.values
         if len(self.pi.shape)==2:
             idx = np.unique(variational_posterior.gamma._raveled_index()/gamma.shape[-1])
             pi = self.pi[idx]
         else:
             pi = self.pi

-        variational_posterior.binary_prob.gradient -= np.log((1-pi)/pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.
+        if self.group_spike:
+            dgamma = np.log((1-pi)/pi*gamma/(1.-gamma))/mu.shape[0]
+        else:
+            dgamma = np.log((1-pi)/pi*gamma/(1.-gamma))
+        variational_posterior.binary_prob.gradient -= dgamma+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.
         mu.gradient -= gamma*mu/self.variance
         S.gradient -= (1./self.variance - 1./S) * gamma /2.
         if self.learnPi:
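For reference, the Bernoulli part of the KL term this gradient corresponds to (standard spike-and-slab variational bound; the sign follows the code's `-=` into the log-likelihood gradient):

    \mathrm{KL}\big(q(b)\,\|\,p(b)\big) = \gamma\log\frac{\gamma}{\pi} + (1-\gamma)\log\frac{1-\gamma}{1-\pi},
    \qquad
    \frac{\partial\,\mathrm{KL}}{\partial\gamma}
      = \log\frac{\gamma}{1-\gamma} + \log\frac{1-\pi}{\pi}
      = \log\Big(\frac{1-\pi}{\pi}\cdot\frac{\gamma}{1-\gamma}\Big)

which is exactly `dgamma`. With group_spike the tied value is replicated across the N rows of `binary_prob` and the per-row gradients are later summed back in `collate_gradient` (added below), so dividing by `mu.shape[0]` (= N) makes the prior term enter the summed gradient once rather than N times.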
@@ -154,13 +165,27 @@ class SpikeAndSlabPosterior(VariationalPosterior):
     '''
     The SpikeAndSlab distribution for variational approximations.
     '''
-    def __init__(self, means, variances, binary_prob, name='latent space'):
+    def __init__(self, means, variances, binary_prob, group_spike=False, name='latent space'):
         """
         binary_prob : the probability of the distribution on the slab part.
         """
         super(SpikeAndSlabPosterior, self).__init__(means, variances, name)
-        self.gamma = Param("binary_prob",binary_prob,Logistic(0.,1.))
-        self.link_parameter(self.gamma)
+        self.group_spike = group_spike
+        if group_spike:
+            self.gamma_group = Param("binary_prob_group",binary_prob.mean(axis=0),Logistic(0.,1.))
+            self.gamma = Param("binary_prob",binary_prob, __fixed__)
+            self.link_parameters(self.gamma_group,self.gamma)
+        else:
+            self.gamma = Param("binary_prob",binary_prob,Logistic(0.,1.))
+            self.link_parameter(self.gamma)
+
+    def propogate_val(self):
+        if self.group_spike:
+            self.gamma.param_array.values.reshape(self.gamma.shape)[:] = self.gamma_group.values
+
+    def collate_gradient(self):
+        if self.group_spike:
+            self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0)

     def set_gradients(self, grad):
         self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad
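The two new helpers implement a broadcast-and-sum pattern around the now-fixed per-point gamma: the free parameter is the single row gamma_group, and gamma is derived from it. A toy illustration of the chain rule they rely on (not GPy code):

    import numpy as np

    N, Q = 4, 3
    gamma_group = np.array([0.9, 0.2, 0.6])    # the free parameter, one value per dim

    # propogate_val: write the group values into every row of the (N, Q) array
    gamma = np.empty((N, Q))
    gamma[:] = gamma_group

    # collate_gradient: each row holds the same variable, so the gradient
    # with respect to gamma_group is the sum of the per-row gradients
    grad_gamma = np.random.randn(N, Q)         # stand-in for inference gradients
    grad_group = grad_gamma.sum(axis=0)        # shape (Q,)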
GPy/models/ss_gplvm.py

@@ -103,7 +103,7 @@ class SSGPLVM(SparseGP_MPI):

     """
     def __init__(self, Y, input_dim, X=None, X_variance=None, Gamma=None, init='PCA', num_inducing=10,
-                 Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=True,normalizer=False, **kwargs):
+                 Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike_and_Slab GPLVM', group_spike=False, mpi_comm=None, pi=None, learnPi=False,normalizer=False, **kwargs):

         self.group_spike = group_spike
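The only functional change in this hunk is the learnPi default flipping from True to False, so the prior inclusion probabilities stay fixed unless learning is requested explicitly. A hedged usage sketch (toy data; the GPy.models.SSGPLVM entry point is assumed):

    import numpy as np
    import GPy

    Y = np.random.randn(100, 12)   # toy dataset: 100 points, 12 output dimensions

    # After this commit, learning pi must be asked for explicitly:
    m = GPy.models.SSGPLVM(Y, input_dim=5, learnPi=True)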
@@ -144,15 +144,12 @@ class SSGPLVM(SparseGP_MPI):
         if pi is None:
             pi = np.empty((input_dim))
             pi[:] = 0.5
-        self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi) # the prior probability of the latent binary variable b
+        self.variational_prior = SpikeAndSlabPrior(pi=pi,learnPi=learnPi, group_spike=group_spike) # the prior probability of the latent binary variable b

-        X = SpikeAndSlabPosterior(X, X_variance, gamma)
+        X = SpikeAndSlabPosterior(X, X_variance, gamma, group_spike=group_spike)

         super(SSGPLVM,self).__init__(X, Y, Z, kernel, likelihood, variational_prior=self.variational_prior, inference_method=inference_method, name=name, mpi_comm=mpi_comm, normalizer=normalizer, **kwargs)
         self.link_parameter(self.X, index=0)

-        if self.group_spike:
-            [self.X.gamma[:,i].tie('tieGamma'+str(i)) for i in range(self.X.gamma.shape[1])] # Tie columns together
-
     def set_X_gradients(self, X, X_grad):
         """Set the gradients of the posterior distribution of X in its specific form."""
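Note the design change: the per-column .tie('tieGamma'+str(i)) mechanism is removed, and tying is now owned by the prior and posterior through the group_spike flag and the gamma_group parameter. Continuing the sketch above:

    # group_spike now flows into both SpikeAndSlabPrior and SpikeAndSlabPosterior:
    m = GPy.models.SSGPLVM(Y, input_dim=5, group_spike=True, num_inducing=10)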
@@ -163,8 +160,10 @@ class SSGPLVM(SparseGP_MPI):
         return X.mean.gradient, X.variance.gradient, X.binary_prob.gradient

     def parameters_changed(self):
+        self.X.propogate_val()
         super(SSGPLVM,self).parameters_changed()
         if isinstance(self.inference_method, VarDTC_minibatch):
+            self.X.collate_gradient()
             return

         self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X)
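propogate_val runs before the parent's parameters_changed so that inference sees the broadcast gamma values, and collate_gradient runs once gradients are available; in the VarDTC_minibatch branch the minibatch inference has presumably filled the gradients already, hence the call just before the early return. A quick consistency check one might run (continuing the sketch; checkgrad is GPy's built-in finite-difference test):

    # Finite-difference check of the tied-gamma gradient plumbing:
    m = GPy.models.SSGPLVM(np.random.randn(50, 8), input_dim=3, group_spike=True)
    m.checkgrad(verbose=True)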
@@ -173,6 +172,7 @@ class SSGPLVM(SparseGP_MPI):

         # update for the KL divergence
         self.variational_prior.update_gradients_KL(self.X)
+        self.X.collate_gradient()

     def input_sensitivity(self):
         if self.kern.ARD: