chg: fixed naming in variational priors

:
This commit is contained in:
Max Zwiessele 2016-10-03 12:19:24 +01:00
parent 59be79e962
commit 543a04dc1b

View file

@ -10,7 +10,7 @@ from .param import Param
from paramz.transformations import Logexp, Logistic,__fixed__ from paramz.transformations import Logexp, Logistic,__fixed__
class VariationalPrior(Parameterized): class VariationalPrior(Parameterized):
def __init__(self, name='latent space', **kw): def __init__(self, name='latent prior', **kw):
super(VariationalPrior, self).__init__(name=name, **kw) super(VariationalPrior, self).__init__(name=name, **kw)
def KL_divergence(self, variational_posterior): def KL_divergence(self, variational_posterior):
@ -23,6 +23,9 @@ class VariationalPrior(Parameterized):
raise NotImplementedError("override this for variational inference of latent space") raise NotImplementedError("override this for variational inference of latent space")
class NormalPrior(VariationalPrior): class NormalPrior(VariationalPrior):
def __init__(self, name='normal_prior', **kw):
super(VariationalPrior, self).__init__(name=name, **kw)
def KL_divergence(self, variational_posterior): def KL_divergence(self, variational_posterior):
var_mean = np.square(variational_posterior.mean).sum() var_mean = np.square(variational_posterior.mean).sum()
var_S = (variational_posterior.variance - np.log(variational_posterior.variance)).sum() var_S = (variational_posterior.variance - np.log(variational_posterior.variance)).sum()
@ -58,7 +61,7 @@ class SpikeAndSlabPrior(VariationalPrior):
pi = self.pi[idx] pi = self.pi[idx]
else: else:
pi = self.pi pi = self.pi
var_mean = np.square(mu)/self.variance var_mean = np.square(mu)/self.variance
var_S = (S/self.variance - np.log(S)) var_S = (S/self.variance - np.log(S))
var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum() var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
@ -163,12 +166,12 @@ class NormalPosterior(VariationalPosterior):
"""Compute the KL divergence to another NormalPosterior Object. This only holds, if the two NormalPosterior objects have the same shape, as we do computational tricks for the multivariate normal KL divergence. """Compute the KL divergence to another NormalPosterior Object. This only holds, if the two NormalPosterior objects have the same shape, as we do computational tricks for the multivariate normal KL divergence.
""" """
return .5*( return .5*(
np.sum(self.variance/other.variance) np.sum(self.variance/other.variance)
+ ((other.mean-self.mean)**2/other.variance).sum() + ((other.mean-self.mean)**2/other.variance).sum()
- self.num_data * self.input_dim - self.num_data * self.input_dim
+ np.sum(np.log(other.variance)) - np.sum(np.log(self.variance)) + np.sum(np.log(other.variance)) - np.sum(np.log(self.variance))
) )
class SpikeAndSlabPosterior(VariationalPosterior): class SpikeAndSlabPosterior(VariationalPosterior):
''' '''
The SpikeAndSlab distribution for variational approximations. The SpikeAndSlab distribution for variational approximations.
@ -190,11 +193,11 @@ class SpikeAndSlabPosterior(VariationalPosterior):
else: else:
self.gamma = Param("binary_prob",binary_prob,Logistic(1e-10,1.-1e-10)) self.gamma = Param("binary_prob",binary_prob,Logistic(1e-10,1.-1e-10))
self.link_parameter(self.gamma) self.link_parameter(self.gamma)
def propogate_val(self): def propogate_val(self):
if self.group_spike: if self.group_spike:
self.gamma.values[:] = self.gamma_group.values self.gamma.values[:] = self.gamma_group.values
def collate_gradient(self): def collate_gradient(self):
if self.group_spike: if self.group_spike:
self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0) self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0)