[ssgplvm] group spike

This commit is contained in:
Zhenwen Dai 2015-05-21 11:33:37 +01:00
parent 7188e92efb
commit b8508cc20c
2 changed files with 39 additions and 14 deletions

View file

@ -36,8 +36,9 @@ class NormalPrior(VariationalPrior):
variational_posterior.variance.gradient -= (1. - (1. / (variational_posterior.variance))) * 0.5
class SpikeAndSlabPrior(VariationalPrior):
def __init__(self, pi=None, learnPi=False, variance = 1.0, name='SpikeAndSlabPrior', **kw):
super(SpikeAndSlabPrior, self).__init__(name=name, **kw)
def __init__(self, pi=None, learnPi=False, variance = 1.0, group_spike=False, name='SpikeAndSlabPrior', **kw):
super(SpikeAndSlabPrior, self).__init__(name=name, **kw)
self.group_spike = group_spike
self.variance = Param('variance',variance)
self.learnPi = learnPi
if learnPi:
@ -50,7 +51,10 @@ class SpikeAndSlabPrior(VariationalPrior):
def KL_divergence(self, variational_posterior):
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.gamma.values
if self.group_spike:
gamma = variational_posterior.gamma.values[0]
else:
gamma = variational_posterior.gamma.values
if len(self.pi.shape)==2:
idx = np.unique(variational_posterior.gamma._raveled_index()/gamma.shape[-1])
pi = self.pi[idx]
@ -65,14 +69,21 @@ class SpikeAndSlabPrior(VariationalPrior):
def update_gradients_KL(self, variational_posterior):
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.gamma.values
if self.group_spike:
gamma = variational_posterior.gamma.values[0]
else:
gamma = variational_posterior.gamma.values
if len(self.pi.shape)==2:
idx = np.unique(variational_posterior.gamma._raveled_index()/gamma.shape[-1])
pi = self.pi[idx]
else:
pi = self.pi
variational_posterior.binary_prob.gradient -= np.log((1-pi)/pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.
if self.group_spike:
dgamma = np.log((1-pi)/pi*gamma/(1.-gamma))/mu.shape[0]
else:
dgamma = np.log((1-pi)/pi*gamma/(1.-gamma))
variational_posterior.binary_prob.gradient -= dgamma+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.
mu.gradient -= gamma*mu/self.variance
S.gradient -= (1./self.variance - 1./S) * gamma /2.
if self.learnPi:
@ -154,13 +165,27 @@ class SpikeAndSlabPosterior(VariationalPosterior):
'''
The SpikeAndSlab distribution for variational approximations.
'''
def __init__(self, means, variances, binary_prob, name='latent space'):
def __init__(self, means, variances, binary_prob, group_spike=False, name='latent space'):
"""
binary_prob : the probability of the distribution on the slab part.
"""
super(SpikeAndSlabPosterior, self).__init__(means, variances, name)
self.gamma = Param("binary_prob",binary_prob,Logistic(0.,1.))
self.link_parameter(self.gamma)
self.group_spike = group_spike
if group_spike:
self.gamma_group = Param("binary_prob_group",binary_prob.mean(axis=0),Logistic(0.,1.))
self.gamma = Param("binary_prob",binary_prob, __fixed__)
self.link_parameters(self.gamma_group,self.gamma)
else:
self.gamma = Param("binary_prob",binary_prob,Logistic(0.,1.))
self.link_parameter(self.gamma)
def propogate_val(self):
if self.group_spike:
self.gamma.param_array.values.reshape(self.gamma.shape)[:] = self.gamma_group.values
def collate_gradient(self):
if self.group_spike:
self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0)
def set_gradients(self, grad):
self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad