mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
improve numerical stability of vardtc_parallel
This commit is contained in:
parent
80adaed616
commit
77a96efeba
3 changed files with 8 additions and 7 deletions
|
|
@ -169,14 +169,15 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
Kmm = kern.K(Z).copy()
|
||||
diag.add(Kmm, self.const_jitter)
|
||||
Lm = jitchol(Kmm)
|
||||
|
||||
Lambda = Kmm+psi2_full
|
||||
|
||||
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
||||
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
|
||||
LL = jitchol(Lambda)
|
||||
LL = np.dot(Lm,LL)
|
||||
b,_ = dtrtrs(LL, psi1Y_full.T)
|
||||
bbt = np.square(b).sum()
|
||||
v,_ = dtrtrs(LL.T,b,lower=False)
|
||||
vvt = np.einsum('md,od->mo',v,v)
|
||||
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
|
||||
|
||||
Psi2LLInvT = dtrtrs(LL,psi2_full)[0].T
|
||||
LmInvPsi2LLInvT= dtrtrs(Lm,Psi2LLInvT)[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue