further bug fix for sparsegp_mpi

This commit is contained in:
Zhenwen Dai 2014-08-27 09:45:06 +01:00
parent 1e1bbb2a26
commit 7ed0e70a46
3 changed files with 24 additions and 16 deletions

View file

@ -46,9 +46,15 @@ class SpikeAndSlabPrior(VariationalPrior):
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
if len(self.pi.shape)==2:
idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
pi = self.pi[idx]
else:
pi = self.pi
var_mean = np.square(mu)/self.variance
var_S = (S/self.variance - np.log(S))
var_gamma = (gamma*np.log(gamma/self.pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-self.pi))).sum()
var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
return var_gamma+ (gamma* (np.log(self.variance)-1. +var_mean + var_S)).sum()/2.
def update_gradients_KL(self, variational_posterior):