[bgplvm] technical new stuff

This commit is contained in:
mzwiessele 2015-07-29 10:48:05 +02:00
parent 4ca4916cc0
commit fca2440943
5 changed files with 15 additions and 5 deletions

View file

@ -7,6 +7,6 @@ from .parameterization.param import Param, ParamConcatenation
from .parameterization.observable_array import ObsAr
from .gp import GP
from .svgp import SVGP
#from .svgp import SVGP
from .sparse_gp import SparseGP
from .mapping import *

View file

@ -161,6 +161,16 @@ class NormalPosterior(VariationalPosterior):
from ...plotting.matplot_dep import variational_plots
return variational_plots.plot(self, *args, **kwargs)
def KL(self, other):
"""Compute the KL divergence to another NormalPosterior Object. This only holds, if the two NormalPosterior objects have the same shape, as we do computational tricks for the multivariate normal KL divergence.
"""
return .5*(
np.sum(self.variance/other.variance)
+ ((other.mean-self.mean)**2/other.variance).sum()
- self.num_data * self.input_dim
+ np.sum(np.log(other.variance)) - np.sum(np.log(self.variance))
)
class SpikeAndSlabPosterior(VariationalPosterior):
'''
The SpikeAndSlab distribution for variational approximations.

View file

@ -69,7 +69,7 @@ from .expectation_propagation_dtc import EPDTC
from .dtc import DTC
from .fitc import FITC
from .var_dtc_parallel import VarDTC_minibatch
from .svgp import SVGP
#from .svgp import SVGP
# class FullLatentFunctionData(object):
#

View file

@ -72,10 +72,10 @@ class SparseGPStochastics(StochasticStorage):
bdict = {}
for d in self.d:
inan = np.isnan(self.Y[:, d])
arr_str = np.array2string(inan,
arr_str = int(np.array2string(inan,
np.inf, 0,
True, '',
formatter={'bool':lambda x: '1' if x else '0'})
formatter={'bool':lambda x: '1' if x else '0'}), 2)
try:
bdict[arr_str][0].append(d)
except:

View file

@ -43,7 +43,7 @@ def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)):
if i < means.shape[1] - 1:
a.set_xticklabels('')
pb.draw()
fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
a.figure.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
return dict(lines=lines, fills=fills, bg_lines=bg_lines)
def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_side=True):