mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
remove the automatic importing mpi4py
This commit is contained in:
parent
a98bafb5b4
commit
cb1f6f1486
4 changed files with 18 additions and 24 deletions
|
|
@ -9,10 +9,6 @@ import numpy as np
|
|||
from . import LatentFunctionInference
|
||||
log_2_pi = np.log(2*np.pi)
|
||||
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
except:
|
||||
pass
|
||||
|
||||
class VarDTC_minibatch(LatentFunctionInference):
|
||||
"""
|
||||
|
|
@ -123,6 +119,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
YRY_full = trYYT*beta
|
||||
|
||||
if self.mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
psi0_all = np.array(psi0_full)
|
||||
psi1Y_all = psi1Y_full.copy()
|
||||
psi2_all = psi2_full.copy()
|
||||
|
|
@ -146,6 +143,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
num_data, output_dim = Y.shape
|
||||
input_dim = Z.shape[0]
|
||||
if self.mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
num_data_all = np.array(num_data,dtype=np.int32)
|
||||
self.mpi_comm.Allreduce([np.int32(num_data), MPI.INT], [num_data_all, MPI.INT])
|
||||
num_data = num_data_all
|
||||
|
|
@ -387,6 +385,7 @@ def update_gradients(model, mpi_comm=None):
|
|||
|
||||
# Gather the gradients from multiple MPI nodes
|
||||
if mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
if het_noise:
|
||||
raise "het_noise not implemented!"
|
||||
kern_grad_all = kern_grad.copy()
|
||||
|
|
@ -409,6 +408,7 @@ def update_gradients(model, mpi_comm=None):
|
|||
model.variational_prior.update_gradients_KL(X)
|
||||
|
||||
if mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
KL_div_all = np.array(KL_div)
|
||||
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
||||
KL_div = KL_div_all
|
||||
|
|
@ -468,6 +468,7 @@ def update_gradients_sparsegp(model, mpi_comm=None):
|
|||
|
||||
# Gather the gradients from multiple MPI nodes
|
||||
if mpi_comm != None:
|
||||
from mpi4py import MPI
|
||||
if het_noise:
|
||||
raise "het_noise not implemented!"
|
||||
kern_grad_all = kern_grad.copy()
|
||||
|
|
|
|||
|
|
@ -148,7 +148,11 @@ def x_frame1D(X,plot_limits=None,resolution=None):
|
|||
"""
|
||||
assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs"
|
||||
if plot_limits is None:
|
||||
xmin,xmax = X.min(0),X.max(0)
|
||||
from ...core.parameterization.variational import VariationalPosterior
|
||||
if isinstance(X, VariationalPosterior):
|
||||
xmin,xmax = X.mean.min(0),X.mean.max(0)
|
||||
else:
|
||||
xmin,xmax = X.min(0),X.max(0)
|
||||
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
|
||||
elif len(plot_limits)==2:
|
||||
xmin, xmax = plot_limits
|
||||
|
|
|
|||
|
|
@ -2,25 +2,14 @@
|
|||
The module of tools for parallelization (MPI)
|
||||
"""
|
||||
import numpy as np
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
def get_id_within_node(comm=MPI.COMM_WORLD):
|
||||
rank = comm.rank
|
||||
nodename = MPI.Get_processor_name()
|
||||
nodelist = comm.allgather(nodename)
|
||||
return len([i for i in nodelist[:rank] if i==nodename])
|
||||
|
||||
numpy_to_MPI_typemap = {
|
||||
np.dtype(np.float64) : MPI.DOUBLE,
|
||||
np.dtype(np.float32) : MPI.FLOAT,
|
||||
np.dtype(np.int) : MPI.INT,
|
||||
np.dtype(np.int8) : MPI.CHAR,
|
||||
np.dtype(np.uint8) : MPI.UNSIGNED_CHAR,
|
||||
np.dtype(np.int32) : MPI.INT,
|
||||
np.dtype(np.uint32) : MPI.UNSIGNED_INT,
|
||||
}
|
||||
except:
|
||||
pass
|
||||
def get_id_within_node(comm=None):
|
||||
from mpi4py import MPI
|
||||
if comm is None: comm = MPI.COMM_WORLD
|
||||
rank = comm.rank
|
||||
nodename = MPI.Get_processor_name()
|
||||
nodelist = comm.allgather(nodename)
|
||||
return len([i for i in nodelist[:rank] if i==nodename])
|
||||
|
||||
def divide_data(datanum, rank, size):
|
||||
assert rank<size and datanum>0
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -88,7 +88,7 @@ setup(name = 'GPy',
|
|||
test_suite = 'GPy.testing',
|
||||
long_description=read_to_rst('README.md'),
|
||||
install_requires=['numpy>=1.7', 'scipy>=0.16', 'six'],
|
||||
extras_require = {'docs':['matplotlib >=1.3','Sphinx','IPython']},
|
||||
extras_require = {'docs':['matplotlib >=1.3','Sphinx','IPython'],'optional':['mpi4py']},
|
||||
classifiers=['License :: OSI Approved :: BSD License',
|
||||
'Natural Language :: English',
|
||||
'Operating System :: MacOS :: MacOS X',
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue