merged simple conflict£

This commit is contained in:
James Hensman 2014-02-05 16:26:10 +00:00
commit 432d9668a6
7 changed files with 381 additions and 14 deletions

View file

@ -4,6 +4,7 @@
from posterior import Posterior
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dpotri, symmetrify
import numpy as np
from GPy.util.linalg import dtrtri
log_2_pi = np.log(2*np.pi)
class VarDTC(object):
@ -69,19 +70,24 @@ class VarDTC(object):
psi2_beta = (psi2 * (beta.flatten().reshape(num_data, 1, 1))).sum(0)
else:
psi2_beta = psi2.sum(0) * beta
evals, evecs = linalg.eigh(psi2_beta)
clipped_evals = np.clip(evals, 0., 1e6) # TODO: make clipping configurable
if not np.array_equal(evals, clipped_evals):
pass # print evals
tmp = evecs * np.sqrt(clipped_evals)
tmp = tmp.T
if 0:
evals, evecs = linalg.eigh(psi2_beta)
clipped_evals = np.clip(evals, 0., 1e6) # TODO: make clipping configurable
if not np.array_equal(evals, clipped_evals):
pass # print evals
tmp = evecs * np.sqrt(clipped_evals)
tmp = tmp.T
# no backsubstitution because of bound explosion on tr(A) if not...
LmInv, _ = dtrtri(Lm, lower=1)
A = LmInv.T.dot(psi2_beta.dot(LmInv))
print A.sum()
else:
if het_noise:
tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1)))
else:
tmp = psi1 * (np.sqrt(beta))
tmp, _ = dtrtrs(Lm, np.asfortranarray(tmp.T), lower=1)
A = tdot(tmp)
tmp, _ = dtrtrs(Lm, np.asfortranarray(tmp.T), lower=1)
A = tdot(tmp)
# factor B
B = np.eye(num_inducing) + A