chg: fixed naming in variational priors

:
This commit is contained in:
Max Zwiessele 2016-10-03 12:19:24 +01:00
parent 59be79e962
commit 543a04dc1b

View file

@ -10,7 +10,7 @@ from .param import Param
from paramz.transformations import Logexp, Logistic,__fixed__
class VariationalPrior(Parameterized):
def __init__(self, name='latent space', **kw):
def __init__(self, name='latent prior', **kw):
super(VariationalPrior, self).__init__(name=name, **kw)
def KL_divergence(self, variational_posterior):
@ -23,6 +23,9 @@ class VariationalPrior(Parameterized):
raise NotImplementedError("override this for variational inference of latent space")
class NormalPrior(VariationalPrior):
def __init__(self, name='normal_prior', **kw):
super(VariationalPrior, self).__init__(name=name, **kw)
def KL_divergence(self, variational_posterior):
var_mean = np.square(variational_posterior.mean).sum()
var_S = (variational_posterior.variance - np.log(variational_posterior.variance)).sum()
@ -58,7 +61,7 @@ class SpikeAndSlabPrior(VariationalPrior):
pi = self.pi[idx]
else:
pi = self.pi
var_mean = np.square(mu)/self.variance
var_S = (S/self.variance - np.log(S))
var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
@ -163,12 +166,12 @@ class NormalPosterior(VariationalPosterior):
"""Compute the KL divergence to another NormalPosterior Object. This only holds, if the two NormalPosterior objects have the same shape, as we do computational tricks for the multivariate normal KL divergence.
"""
return .5*(
np.sum(self.variance/other.variance)
+ ((other.mean-self.mean)**2/other.variance).sum()
np.sum(self.variance/other.variance)
+ ((other.mean-self.mean)**2/other.variance).sum()
- self.num_data * self.input_dim
+ np.sum(np.log(other.variance)) - np.sum(np.log(self.variance))
)
class SpikeAndSlabPosterior(VariationalPosterior):
'''
The SpikeAndSlab distribution for variational approximations.
@ -190,11 +193,11 @@ class SpikeAndSlabPosterior(VariationalPosterior):
else:
self.gamma = Param("binary_prob",binary_prob,Logistic(1e-10,1.-1e-10))
self.link_parameter(self.gamma)
def propogate_val(self):
if self.group_spike:
self.gamma.values[:] = self.gamma_group.values
def collate_gradient(self):
if self.group_spike:
self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0)