generalize the interface of mpi

This commit is contained in:
Zhenwen Dai 2014-08-26 18:12:41 +01:00
parent a853d060fb
commit f29753e9e6
4 changed files with 143 additions and 80 deletions

View file

@ -55,8 +55,13 @@ class SpikeAndSlabPrior(VariationalPrior):
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
if len(self.pi.shape)==2:
idx = np.unique(gamma._raveled_index()/gamma.shape[-1])
pi = self.pi[idx]
else:
pi = self.pi
gamma.gradient -= np.log((1-self.pi)/self.pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.
gamma.gradient -= np.log((1-pi)/pi*gamma/(1.-gamma))+((np.square(mu)+S)/self.variance-np.log(S)+np.log(self.variance)-1.)/2.
mu.gradient -= gamma*mu/self.variance
S.gradient -= (1./self.variance - 1./S) * gamma /2.
if self.learnPi:
@ -65,7 +70,7 @@ class SpikeAndSlabPrior(VariationalPrior):
elif len(self.pi.shape)==1:
self.pi.gradient = (gamma/self.pi - (1.-gamma)/(1.-self.pi)).sum(axis=0)
else:
self.pi.gradient = (gamma/self.pi - (1.-gamma)/(1.-self.pi))
self.pi[idx].gradient = (gamma/self.pi[idx] - (1.-gamma)/(1.-self.pi[idx]))
class VariationalPosterior(Parameterized):
def __init__(self, means=None, variances=None, name='latent space', *a, **kw):