[paramz] fully integrated all tests running

This commit is contained in:
mzwiessele 2015-10-15 14:59:57 +01:00
parent e49c75ce2e
commit dce82847a7
78 changed files with 1581 additions and 1222 deletions

View file

@ -2,7 +2,7 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from ...util.linalg import jitchol, DSYR, dtrtrs, dtrtri
from ...core.parameterization.observable_array import ObsAr
from paramz import ObsAr
from . import ExactGaussianInference, VarDTC
from ...util import diag

View file

@ -2,10 +2,9 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from ...core import Model
from ...core.parameterization import variational
from ...core import ProbabilisticModel
from ...core import variational
from ...util.linalg import tdot
from GPy.core.parameterization.variational import VariationalPosterior
def infer_newX(model, Y_new, optimize=True, init='L2'):
"""
@ -27,7 +26,7 @@ def infer_newX(model, Y_new, optimize=True, init='L2'):
return infr_m.X, infr_m
class InferenceX(Model):
class InferenceX(ProbabilisticModel):
"""
The model class for inference of new X with given new Y. (replacing the "do_test_latent" in Bayesian GPLVM)
It is a tiny inference model created from the original GP model. The kernel, likelihood (only Gaussian is supported at the moment)
@ -62,14 +61,12 @@ class InferenceX(Model):
# self.kern.GPU(True)
from copy import deepcopy
self.posterior = deepcopy(model.posterior)
from ...core.parameterization.variational import VariationalPosterior
if isinstance(model.X, VariationalPosterior):
if isinstance(model.X, variational.VariationalPosterior):
self.uncertain_input = True
from ...models.ss_gplvm import IBPPrior
from ...models.ss_mrd import IBPPrior_SSMRD
if isinstance(model.variational_prior, IBPPrior) or isinstance(model.variational_prior, IBPPrior_SSMRD):
from ...core.parameterization.variational import SpikeAndSlabPrior
self.variational_prior = SpikeAndSlabPrior(pi=0.5, learnPi=False, group_spike=False)
self.variational_prior = variational.SpikeAndSlabPrior(pi=0.5, learnPi=False, group_spike=False)
else:
self.variational_prior = model.variational_prior.copy()
else:
@ -105,17 +102,16 @@ class InferenceX(Model):
idx = dist.argmin(axis=1)
from ...models import SSGPLVM
from ...util.misc import param_to_array
if isinstance(model, SSGPLVM):
X = variational.SpikeAndSlabPosterior(param_to_array(model.X.mean[idx]), param_to_array(model.X.variance[idx]), param_to_array(model.X.gamma[idx]))
X = variational.SpikeAndSlabPosterior((model.X.mean[idx].values), (model.X.variance[idx].values), (model.X.gamma[idx].values))
if model.group_spike:
X.gamma.fix()
else:
if self.uncertain_input and self.sparse_gp:
X = variational.NormalPosterior(param_to_array(model.X.mean[idx]), param_to_array(model.X.variance[idx]))
X = variational.NormalPosterior((model.X.mean[idx].values), (model.X.variance[idx].values))
else:
from ...core import Param
X = Param('latent mean',param_to_array(model.X[idx]).copy())
X = Param('latent mean',(model.X[idx].values).copy())
return X
@ -160,8 +156,7 @@ class InferenceX(Model):
self.X.gradient = X_grad
if self.uncertain_input:
from ...core.parameterization.variational import SpikeAndSlabPrior
if isinstance(self.variational_prior, SpikeAndSlabPrior):
if isinstance(self.variational_prior, variational.SpikeAndSlabPrior):
# Update Log-likelihood
KL_div = self.variational_prior.KL_divergence(self.X)
# update for the KL divergence

View file

@ -4,7 +4,7 @@
from .posterior import Posterior
from ...util.linalg import mdot, jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri, dpotri, dpotrs, symmetrify
from ...util import diag
from ...core.parameterization.variational import VariationalPosterior
from ...core.variational import VariationalPosterior
import numpy as np
from . import LatentFunctionInference
log_2_pi = np.log(2*np.pi)
@ -23,8 +23,7 @@ class VarDTC(LatentFunctionInference):
"""
const_jitter = 1e-8
def __init__(self, limit=1):
#self._YYTfactor_cache = caching.cache()
from ...util.caching import Cacher
from paramz.caching import Cacher
self.limit = limit
self.get_trYYT = Cacher(self._get_trYYT, limit)
self.get_YYTfactor = Cacher(self._get_YYTfactor, limit)
@ -45,7 +44,7 @@ class VarDTC(LatentFunctionInference):
def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled.
self.limit = state
from ...util.caching import Cacher
from paramz.caching import Cacher
self.get_trYYT = Cacher(self._get_trYYT, self.limit)
self.get_YYTfactor = Cacher(self._get_YYTfactor, self.limit)

View file

@ -4,7 +4,7 @@
from .posterior import Posterior
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri,pdinv
from ...util import diag
from ...core.parameterization.variational import VariationalPosterior
from ...core.variational import VariationalPosterior
import numpy as np
from . import LatentFunctionInference
log_2_pi = np.log(2*np.pi)