improve numerical stability of vardtc_parallel

This commit is contained in:
Zhenwen Dai 2014-08-26 16:38:00 +01:00
parent 80adaed616
commit 77a96efeba
3 changed files with 8 additions and 7 deletions

View file

@ -169,14 +169,15 @@ class VarDTC_minibatch(LatentFunctionInference):
Kmm = kern.K(Z).copy()
diag.add(Kmm, self.const_jitter)
Lm = jitchol(Kmm)
Lambda = Kmm+psi2_full
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
LL = jitchol(Lambda)
LL = np.dot(Lm,LL)
b,_ = dtrtrs(LL, psi1Y_full.T)
bbt = np.square(b).sum()
v,_ = dtrtrs(LL.T,b,lower=False)
vvt = np.einsum('md,od->mo',v,v)
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
Psi2LLInvT = dtrtrs(LL,psi2_full)[0].T
LmInvPsi2LLInvT= dtrtrs(Lm,Psi2LLInvT)[0]