diagonal add kmm

This commit is contained in:
Max Zwiessele 2014-03-12 12:03:47 +00:00
parent 54239555a1
commit 5027e8e312

View file

@ -3,6 +3,7 @@
from posterior import Posterior from posterior import Posterior
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri, dpotri, dpotrs, symmetrify from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri, dpotri, dpotrs, symmetrify
from ...util import diag
from ...core.parameterization.variational import VariationalPosterior from ...core.parameterization.variational import VariationalPosterior
import numpy as np import numpy as np
from ...util.misc import param_to_array from ...util.misc import param_to_array
@ -28,7 +29,7 @@ class VarDTC(object):
def set_limit(self, limit): def set_limit(self, limit):
self.get_trYYT.limit = limit self.get_trYYT.limit = limit
self.get_YYTfactor.limit = limit self.get_YYTfactor.limit = limit
def _get_trYYT(self, Y): def _get_trYYT(self, Y):
return param_to_array(np.sum(np.square(Y))) return param_to_array(np.sum(np.square(Y)))
@ -77,10 +78,10 @@ class VarDTC(object):
num_inducing = Z.shape[0] num_inducing = Z.shape[0]
num_data = Y.shape[0] num_data = Y.shape[0]
# kernel computations, using BGPLVM notation # kernel computations, using BGPLVM notation
Kmm = kern.K(Z) +np.eye(Z.shape[0]) * self.const_jitter
Lm = jitchol(Kmm+np.eye(Z.shape[0])*self.const_jitter) Kmm = kern.K(Z).copy()
diag.add(Kmm, self.const_jitter)
Lm = jitchol(Kmm)
# The rather complex computations of A # The rather complex computations of A
if uncertain_inputs: if uncertain_inputs:
@ -169,7 +170,6 @@ class VarDTC(object):
Bi, _ = dpotri(LB, lower=1) Bi, _ = dpotri(LB, lower=1)
symmetrify(Bi) symmetrify(Bi)
Bi = -dpotri(LB, lower=1)[0] Bi = -dpotri(LB, lower=1)[0]
from ...util import diag
diag.add(Bi, 1) diag.add(Bi, 1)
woodbury_inv = backsub_both_sides(Lm, Bi) woodbury_inv = backsub_both_sides(Lm, Bi)
@ -238,7 +238,8 @@ class VarDTCMissingData(object):
dL_dKmm = 0 dL_dKmm = 0
log_marginal = 0 log_marginal = 0
Kmm = kern.K(Z) Kmm = kern.K(Z).copy()
diag.add(Kmm, self.const_jitter)
#factor Kmm #factor Kmm
Lm = jitchol(Kmm) Lm = jitchol(Kmm)
if uncertain_inputs: LmInv = dtrtri(Lm) if uncertain_inputs: LmInv = dtrtri(Lm)
@ -324,7 +325,6 @@ class VarDTCMissingData(object):
Bi, _ = dpotri(LB, lower=1) Bi, _ = dpotri(LB, lower=1)
symmetrify(Bi) symmetrify(Bi)
Bi = -dpotri(LB, lower=1)[0] Bi = -dpotri(LB, lower=1)[0]
from ...util import diag
diag.add(Bi, 1) diag.add(Bi, 1)
woodbury_inv_all[:, :, ind] = backsub_both_sides(Lm, Bi)[:,:,None] woodbury_inv_all[:, :, ind] = backsub_both_sides(Lm, Bi)[:,:,None]