mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
chg: fixed naming in variational priors
:
This commit is contained in:
parent
59be79e962
commit
543a04dc1b
1 changed files with 10 additions and 7 deletions
|
|
@ -10,7 +10,7 @@ from .param import Param
|
||||||
from paramz.transformations import Logexp, Logistic,__fixed__
|
from paramz.transformations import Logexp, Logistic,__fixed__
|
||||||
|
|
||||||
class VariationalPrior(Parameterized):
|
class VariationalPrior(Parameterized):
|
||||||
def __init__(self, name='latent space', **kw):
|
def __init__(self, name='latent prior', **kw):
|
||||||
super(VariationalPrior, self).__init__(name=name, **kw)
|
super(VariationalPrior, self).__init__(name=name, **kw)
|
||||||
|
|
||||||
def KL_divergence(self, variational_posterior):
|
def KL_divergence(self, variational_posterior):
|
||||||
|
|
@ -23,6 +23,9 @@ class VariationalPrior(Parameterized):
|
||||||
raise NotImplementedError("override this for variational inference of latent space")
|
raise NotImplementedError("override this for variational inference of latent space")
|
||||||
|
|
||||||
class NormalPrior(VariationalPrior):
|
class NormalPrior(VariationalPrior):
|
||||||
|
def __init__(self, name='normal_prior', **kw):
|
||||||
|
super(VariationalPrior, self).__init__(name=name, **kw)
|
||||||
|
|
||||||
def KL_divergence(self, variational_posterior):
|
def KL_divergence(self, variational_posterior):
|
||||||
var_mean = np.square(variational_posterior.mean).sum()
|
var_mean = np.square(variational_posterior.mean).sum()
|
||||||
var_S = (variational_posterior.variance - np.log(variational_posterior.variance)).sum()
|
var_S = (variational_posterior.variance - np.log(variational_posterior.variance)).sum()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue