mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 01:32:40 +02:00
[SSGPLVM] Learn prior parameters
This commit is contained in:
parent
0f6004034e
commit
0258abf5c4
2 changed files with 8 additions and 3 deletions
|
|
@ -34,12 +34,12 @@ class NormalPrior(VariationalPrior):
|
|||
variational_posterior.variance.gradient -= (1. - (1. / (variational_posterior.variance))) * 0.5
|
||||
|
||||
class SpikeAndSlabPrior(VariationalPrior):
|
||||
def __init__(self, variance = 1.0, pi = 0.5, name='SpikeAndSlabPrior', **kw):
|
||||
def __init__(self, pi, variance = 1.0, name='SpikeAndSlabPrior', **kw):
|
||||
super(VariationalPrior, self).__init__(name=name, **kw)
|
||||
assert variance==1.0, "Not Implemented!"
|
||||
self.pi = Param('pi', pi)
|
||||
self.variance = Param('variance',variance)
|
||||
self.add_parameters(self.pi, self.variance)
|
||||
self.add_parameters(self.pi)
|
||||
|
||||
def KL_divergence(self, variational_posterior):
|
||||
mu = variational_posterior.mean
|
||||
|
|
@ -58,6 +58,8 @@ class SpikeAndSlabPrior(VariationalPrior):
|
|||
gamma.gradient -= np.log((1-self.pi)/self.pi*gamma/(1.-gamma))+(np.square(mu)+S-np.log(S)-1.)/2.
|
||||
mu.gradient -= gamma*mu
|
||||
S.gradient -= (1. - (1. / (S))) * gamma /2.
|
||||
self.pi.gradient = (gamma/self.pi - (1.-gamma)/(1.-self.pi)).sum(axis=0)
|
||||
|
||||
|
||||
|
||||
class VariationalPosterior(Parameterized):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue