finish the debug of sparsegp_mpi

This commit is contained in:
Zhenwen Dai 2014-09-12 10:36:01 +01:00
parent 31f11eda40
commit d7eee6aa00
2 changed files with 10 additions and 20 deletions

View file

@ -167,18 +167,12 @@ class VarDTC_minibatch(LatentFunctionInference):
# Compute Common Components
#======================================================================
from ...util.debug import checkFullRank
Kmm = kern.K(Z).copy()
diag.add(Kmm, self.const_jitter)
r1 = checkFullRank(Kmm,name='Kmm')
KmmInv,Lm,LmInv,_ = pdinv(Kmm)
LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T)
Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT
r2 = checkFullRank(Lambda,name='Lambda')
# if (not r1) or (not r2):
# raise
LInv,LL,LLInv,logdet_L = pdinv(Lambda)
b = LLInv.dot(LmInv.dot(psi1Y_full.T))
bbt = np.square(b).sum()
@ -203,7 +197,6 @@ class VarDTC_minibatch(LatentFunctionInference):
# Compute dL_dKmm
#======================================================================
# dL_dKmm = -(output_dim*np.einsum('md,od->mo',KmmInvPsi2LLInvT,KmmInvPsi2LLInvT) + vvt)/2.
dL_dKmm = dL_dpsi2R - output_dim*KmmInv.dot(psi2_full).dot(KmmInv)/2.
#======================================================================
@ -336,13 +329,7 @@ def update_gradients(model, mpi_comm=None):
Y = model.Y_local
X = model.X[model.N_range[0]:model.N_range[1]]
try:
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y)
except Exception:
if model.mpi_comm is None or model.mpi_comm.rank==0:
import time
model.pickle('model_'+str(int(time.time()))+'.pickle')
raise
model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y)
het_noise = model.likelihood.variance.size > 1