diff --git a/GPy/models/bayesian_gplvm.py b/GPy/models/bayesian_gplvm.py index 8f61a661..a55f7e60 100644 --- a/GPy/models/bayesian_gplvm.py +++ b/GPy/models/bayesian_gplvm.py @@ -57,11 +57,13 @@ class BayesianGPLVM(SparseGP): if np.any(np.isnan(Y)): from ..inference.latent_function_inference.var_dtc import VarDTCMissingData inference_method = VarDTCMissingData() - elif mpi_comm != None: + elif mpi_comm is not None: inference_method = VarDTC_minibatch(mpi_comm=mpi_comm) else: from ..inference.latent_function_inference.var_dtc import VarDTC inference_method = VarDTC() + if isinstance(inference_method,VarDTC_minibatch): + inference_method.mpi_comm = mpi_comm if kernel.useGPU and isinstance(inference_method, VarDTC_GPU): kernel.psicomp.GPU_direct = True