remove the automatic importing mpi4py

This commit is contained in:
Zhenwen Dai 2015-09-23 16:25:39 +01:00
parent a98bafb5b4
commit cb1f6f1486
4 changed files with 18 additions and 24 deletions

View file

@ -9,10 +9,6 @@ import numpy as np
from . import LatentFunctionInference
log_2_pi = np.log(2*np.pi)
try:
from mpi4py import MPI
except:
pass
class VarDTC_minibatch(LatentFunctionInference):
"""
@ -123,6 +119,7 @@ class VarDTC_minibatch(LatentFunctionInference):
YRY_full = trYYT*beta
if self.mpi_comm != None:
from mpi4py import MPI
psi0_all = np.array(psi0_full)
psi1Y_all = psi1Y_full.copy()
psi2_all = psi2_full.copy()
@ -146,6 +143,7 @@ class VarDTC_minibatch(LatentFunctionInference):
num_data, output_dim = Y.shape
input_dim = Z.shape[0]
if self.mpi_comm != None:
from mpi4py import MPI
num_data_all = np.array(num_data,dtype=np.int32)
self.mpi_comm.Allreduce([np.int32(num_data), MPI.INT], [num_data_all, MPI.INT])
num_data = num_data_all
@ -387,6 +385,7 @@ def update_gradients(model, mpi_comm=None):
# Gather the gradients from multiple MPI nodes
if mpi_comm != None:
from mpi4py import MPI
if het_noise:
raise "het_noise not implemented!"
kern_grad_all = kern_grad.copy()
@ -409,6 +408,7 @@ def update_gradients(model, mpi_comm=None):
model.variational_prior.update_gradients_KL(X)
if mpi_comm != None:
from mpi4py import MPI
KL_div_all = np.array(KL_div)
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
KL_div = KL_div_all
@ -468,6 +468,7 @@ def update_gradients_sparsegp(model, mpi_comm=None):
# Gather the gradients from multiple MPI nodes
if mpi_comm != None:
from mpi4py import MPI
if het_noise:
raise "het_noise not implemented!"
kern_grad_all = kern_grad.copy()