mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
remove the automatic importing mpi4py
This commit is contained in:
parent
a98bafb5b4
commit
cb1f6f1486
4 changed files with 18 additions and 24 deletions
|
|
@ -9,10 +9,6 @@ import numpy as np
|
|||
from . import LatentFunctionInference
|
||||
log_2_pi = np.log(2*np.pi)
|
||||
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
except:
|
||||
pass
|
||||
|
||||
class VarDTC_minibatch(LatentFunctionInference):
|
||||
"""
|
||||
|
|
@ -123,6 +119,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
YRY_full = trYYT*beta
|
||||
|
||||
if self.mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
psi0_all = np.array(psi0_full)
|
||||
psi1Y_all = psi1Y_full.copy()
|
||||
psi2_all = psi2_full.copy()
|
||||
|
|
@ -146,6 +143,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
num_data, output_dim = Y.shape
|
||||
input_dim = Z.shape[0]
|
||||
if self.mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
num_data_all = np.array(num_data,dtype=np.int32)
|
||||
self.mpi_comm.Allreduce([np.int32(num_data), MPI.INT], [num_data_all, MPI.INT])
|
||||
num_data = num_data_all
|
||||
|
|
@ -387,6 +385,7 @@ def update_gradients(model, mpi_comm=None):
|
|||
|
||||
# Gather the gradients from multiple MPI nodes
|
||||
if mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
if het_noise:
|
||||
raise "het_noise not implemented!"
|
||||
kern_grad_all = kern_grad.copy()
|
||||
|
|
@ -409,6 +408,7 @@ def update_gradients(model, mpi_comm=None):
|
|||
model.variational_prior.update_gradients_KL(X)
|
||||
|
||||
if mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
KL_div_all = np.array(KL_div)
|
||||
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
||||
KL_div = KL_div_all
|
||||
|
|
@ -468,6 +468,7 @@ def update_gradients_sparsegp(model, mpi_comm=None):
|
|||
|
||||
# Gather the gradients from multiple MPI nodes
|
||||
if mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
if het_noise:
|
||||
raise "het_noise not implemented!"
|
||||
kern_grad_all = kern_grad.copy()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue