mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
chg: fixed naming in variational priors
:
This commit is contained in:
parent
59be79e962
commit
543a04dc1b
1 changed files with 10 additions and 7 deletions
|
|
@ -10,7 +10,7 @@ from .param import Param
|
||||||
from paramz.transformations import Logexp, Logistic,__fixed__
|
from paramz.transformations import Logexp, Logistic,__fixed__
|
||||||
|
|
||||||
class VariationalPrior(Parameterized):
|
class VariationalPrior(Parameterized):
|
||||||
def __init__(self, name='latent space', **kw):
|
def __init__(self, name='latent prior', **kw):
|
||||||
super(VariationalPrior, self).__init__(name=name, **kw)
|
super(VariationalPrior, self).__init__(name=name, **kw)
|
||||||
|
|
||||||
def KL_divergence(self, variational_posterior):
|
def KL_divergence(self, variational_posterior):
|
||||||
|
|
@ -23,6 +23,9 @@ class VariationalPrior(Parameterized):
|
||||||
raise NotImplementedError("override this for variational inference of latent space")
|
raise NotImplementedError("override this for variational inference of latent space")
|
||||||
|
|
||||||
class NormalPrior(VariationalPrior):
|
class NormalPrior(VariationalPrior):
|
||||||
|
def __init__(self, name='normal_prior', **kw):
|
||||||
|
super(VariationalPrior, self).__init__(name=name, **kw)
|
||||||
|
|
||||||
def KL_divergence(self, variational_posterior):
|
def KL_divergence(self, variational_posterior):
|
||||||
var_mean = np.square(variational_posterior.mean).sum()
|
var_mean = np.square(variational_posterior.mean).sum()
|
||||||
var_S = (variational_posterior.variance - np.log(variational_posterior.variance)).sum()
|
var_S = (variational_posterior.variance - np.log(variational_posterior.variance)).sum()
|
||||||
|
|
@ -58,7 +61,7 @@ class SpikeAndSlabPrior(VariationalPrior):
|
||||||
pi = self.pi[idx]
|
pi = self.pi[idx]
|
||||||
else:
|
else:
|
||||||
pi = self.pi
|
pi = self.pi
|
||||||
|
|
||||||
var_mean = np.square(mu)/self.variance
|
var_mean = np.square(mu)/self.variance
|
||||||
var_S = (S/self.variance - np.log(S))
|
var_S = (S/self.variance - np.log(S))
|
||||||
var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
|
var_gamma = (gamma*np.log(gamma/pi)).sum()+((1-gamma)*np.log((1-gamma)/(1-pi))).sum()
|
||||||
|
|
@ -163,12 +166,12 @@ class NormalPosterior(VariationalPosterior):
|
||||||
"""Compute the KL divergence to another NormalPosterior Object. This only holds, if the two NormalPosterior objects have the same shape, as we do computational tricks for the multivariate normal KL divergence.
|
"""Compute the KL divergence to another NormalPosterior Object. This only holds, if the two NormalPosterior objects have the same shape, as we do computational tricks for the multivariate normal KL divergence.
|
||||||
"""
|
"""
|
||||||
return .5*(
|
return .5*(
|
||||||
np.sum(self.variance/other.variance)
|
np.sum(self.variance/other.variance)
|
||||||
+ ((other.mean-self.mean)**2/other.variance).sum()
|
+ ((other.mean-self.mean)**2/other.variance).sum()
|
||||||
- self.num_data * self.input_dim
|
- self.num_data * self.input_dim
|
||||||
+ np.sum(np.log(other.variance)) - np.sum(np.log(self.variance))
|
+ np.sum(np.log(other.variance)) - np.sum(np.log(self.variance))
|
||||||
)
|
)
|
||||||
|
|
||||||
class SpikeAndSlabPosterior(VariationalPosterior):
|
class SpikeAndSlabPosterior(VariationalPosterior):
|
||||||
'''
|
'''
|
||||||
The SpikeAndSlab distribution for variational approximations.
|
The SpikeAndSlab distribution for variational approximations.
|
||||||
|
|
@ -190,11 +193,11 @@ class SpikeAndSlabPosterior(VariationalPosterior):
|
||||||
else:
|
else:
|
||||||
self.gamma = Param("binary_prob",binary_prob,Logistic(1e-10,1.-1e-10))
|
self.gamma = Param("binary_prob",binary_prob,Logistic(1e-10,1.-1e-10))
|
||||||
self.link_parameter(self.gamma)
|
self.link_parameter(self.gamma)
|
||||||
|
|
||||||
def propogate_val(self):
|
def propogate_val(self):
|
||||||
if self.group_spike:
|
if self.group_spike:
|
||||||
self.gamma.values[:] = self.gamma_group.values
|
self.gamma.values[:] = self.gamma_group.values
|
||||||
|
|
||||||
def collate_gradient(self):
|
def collate_gradient(self):
|
||||||
if self.group_spike:
|
if self.group_spike:
|
||||||
self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0)
|
self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue