diff --git a/GPy/core/sparse_gp_mpi.py b/GPy/core/sparse_gp_mpi.py index cecbe667..e7faf7a8 100644 --- a/GPy/core/sparse_gp_mpi.py +++ b/GPy/core/sparse_gp_mpi.py @@ -3,6 +3,7 @@ import numpy as np from sparse_gp import SparseGP +from numpy.linalg.linalg import LinAlgError from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch import logging @@ -83,11 +84,7 @@ class SparseGP_MPI(SparseGP): if self.mpi_comm != None: if self._IN_OPTIMIZATION_ and self.mpi_comm.rank==0: self.mpi_comm.Bcast(np.int32(1),root=0) - self.mpi_comm.Bcast(p, root=0) - - from ..util.debug import checkFinite - checkFinite(p, 'optimizer_array') - + self.mpi_comm.Bcast(p, root=0) SparseGP.optimizer_array.fset(self,p) def optimize(self, optimizer=None, start=None, **kwargs): @@ -103,7 +100,13 @@ class SparseGP_MPI(SparseGP): while True: self.mpi_comm.Bcast(flag,root=0) if flag==1: - self.optimizer_array = x + try: + self.optimizer_array = x + self._fail_count = 0 + except (LinAlgError, ZeroDivisionError, ValueError): + if self._fail_count >= self._allowed_failures: + raise + self._fail_count += 1 elif flag==-1: break else: diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index 53b31dab..b9ecbb5c 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -167,18 +167,12 @@ class VarDTC_minibatch(LatentFunctionInference): # Compute Common Components #====================================================================== - from ...util.debug import checkFullRank - Kmm = kern.K(Z).copy() diag.add(Kmm, self.const_jitter) - r1 = checkFullRank(Kmm,name='Kmm') KmmInv,Lm,LmInv,_ = pdinv(Kmm) LmInvPsi2LmInvT = LmInv.dot(psi2_full).dot(LmInv.T) Lambda = np.eye(Kmm.shape[0])+LmInvPsi2LmInvT - r2 = checkFullRank(Lambda,name='Lambda') -# if (not r1) or (not r2): -# raise LInv,LL,LLInv,logdet_L = pdinv(Lambda) b = LLInv.dot(LmInv.dot(psi1Y_full.T)) bbt = np.square(b).sum() @@ -203,7 +197,6 @@ class VarDTC_minibatch(LatentFunctionInference): # Compute dL_dKmm #====================================================================== -# dL_dKmm = -(output_dim*np.einsum('md,od->mo',KmmInvPsi2LLInvT,KmmInvPsi2LLInvT) + vvt)/2. dL_dKmm = dL_dpsi2R - output_dim*KmmInv.dot(psi2_full).dot(KmmInv)/2. #====================================================================== @@ -336,13 +329,7 @@ def update_gradients(model, mpi_comm=None): Y = model.Y_local X = model.X[model.N_range[0]:model.N_range[1]] - try: - model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y) - except Exception: - if model.mpi_comm is None or model.mpi_comm.rank==0: - import time - model.pickle('model_'+str(int(time.time()))+'.pickle') - raise + model._log_marginal_likelihood, dL_dKmm, model.posterior = model.inference_method.inference_likelihood(model.kern, X, model.Z, model.likelihood, Y) het_noise = model.likelihood.variance.size > 1