adapt the numerical stability strategy from VarDTC to VarDTC_minibatch

This commit is contained in:
Zhenwen Dai 2014-09-09 11:46:19 +01:00
parent 4e83501ea5
commit 0f47a6b35f

View file

@ -2,7 +2,7 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
from posterior import Posterior
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri
from ...util.linalg import jitchol, backsub_both_sides, tdot, dtrtrs, dtrtri,pdinv
from ...util import diag
from ...core.parameterization.variational import VariationalPosterior
import numpy as np
@ -144,6 +144,7 @@ class VarDTC_minibatch(LatentFunctionInference):
"""
num_data, output_dim = Y.shape
input_dim = Z.shape[0]
if self.mpi_comm != None:
num_data_all = np.array(num_data,dtype=np.int32)
self.mpi_comm.Allreduce([np.int32(num_data), MPI.INT], [num_data_all, MPI.INT])
@ -171,28 +172,19 @@ class VarDTC_minibatch(LatentFunctionInference):
Kmm = kern.K(Z).copy()
diag.add(Kmm, self.const_jitter)
r1 = checkFullRank(Kmm,name='Kmm')
Lm = jitchol(Kmm)
LmInv = dtrtri(Lm)
KmmInv,Lm,LmInv,_ = pdinv(Kmm)
#LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T)
LmInvPsi2LmInvT = backsub_both_sides(Lm,psi2_full,transpose='right')
LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T)
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
r2 = checkFullRank(Lambda,name='Lambda')
if (not r1) or (not r2):
raise
LL = jitchol(Lambda)
LL = np.dot(Lm,LL)
b,_ = dtrtrs(LL, psi1Y_full.T)
# if (not r1) or (not r2):
# raise
LInv,LL,LLInv,logdet_L = pdinv(Lambda)
b = LLInv.dot(LmInv.dot(psi1Y_full.T))
bbt = np.square(b).sum()
v,_ = dtrtrs(LL.T,b,lower=False)
vvt = np.einsum('md,od->mo',v,v)
v = LmInv.T.dot(LLInv.T.dot(b))
Psi2LLInvT = dtrtrs(LL,psi2_full)[0].T
LmInvPsi2LLInvT= dtrtrs(Lm,Psi2LLInvT)[0]
KmmInvPsi2LLInvT = dtrtrs(Lm,LmInvPsi2LLInvT,trans=True)[0]
KmmInvPsi2P = dtrtrs(LL,KmmInvPsi2LLInvT.T, trans=True)[0].T
dL_dpsi2R = (output_dim*KmmInvPsi2P - vvt)/2. # dL_dpsi2 with R inside psi2
dL_dpsi2R = LmInv.T.dot(-LLInv.T.dot(tdot(b)+output_dim*np.eye(input_dim)).dot(LLInv)+output_dim*np.eye(input_dim)).dot(LmInv)/2.
# Cache intermediate results
self.midRes['dL_dpsi2R'] = dL_dpsi2R
@ -205,20 +197,21 @@ class VarDTC_minibatch(LatentFunctionInference):
logL_R = -np.log(beta).sum()
else:
logL_R = -num_data*np.log(beta)
logL = -(output_dim*(num_data*log_2_pi+logL_R+psi0_full-np.trace(LmInvPsi2LmInvT))+YRY_full-bbt)/2.-output_dim*(-np.log(np.diag(Lm)).sum()+np.log(np.diag(LL)).sum())
logL = -(output_dim*(num_data*log_2_pi+logL_R+psi0_full-np.trace(LmInvPsi2LmInvT))+YRY_full-bbt)/2.-output_dim*logdet_L/2.
#======================================================================
# Compute dL_dKmm
#======================================================================
dL_dKmm = -(output_dim*np.einsum('md,od->mo',KmmInvPsi2LLInvT,KmmInvPsi2LLInvT) + vvt)/2.
# dL_dKmm = -(output_dim*np.einsum('md,od->mo',KmmInvPsi2LLInvT,KmmInvPsi2LLInvT) + vvt)/2.
dL_dKmm = dL_dpsi2R - KmmInv.dot(psi2_full).dot(KmmInv)/2.
#======================================================================
# Compute the Posterior distribution of inducing points p(u|Y)
#======================================================================
if not self.Y_speedup or het_noise:
post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
post = Posterior(woodbury_inv=LmInv.T.dot(np.eye(input_dim)-LInv).dot(LmInv), woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
else:
post = None