mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
remove the automatic importing mpi4py
This commit is contained in:
parent
a98bafb5b4
commit
cb1f6f1486
4 changed files with 18 additions and 24 deletions
|
|
@ -9,10 +9,6 @@ import numpy as np
|
||||||
from . import LatentFunctionInference
|
from . import LatentFunctionInference
|
||||||
log_2_pi = np.log(2*np.pi)
|
log_2_pi = np.log(2*np.pi)
|
||||||
|
|
||||||
try:
|
|
||||||
from mpi4py import MPI
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class VarDTC_minibatch(LatentFunctionInference):
|
class VarDTC_minibatch(LatentFunctionInference):
|
||||||
"""
|
"""
|
||||||
|
|
@ -123,6 +119,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
YRY_full = trYYT*beta
|
YRY_full = trYYT*beta
|
||||||
|
|
||||||
if self.mpi_comm != None:
|
if self.mpi_comm != None:
|
||||||
|
from mpi4py import MPI
|
||||||
psi0_all = np.array(psi0_full)
|
psi0_all = np.array(psi0_full)
|
||||||
psi1Y_all = psi1Y_full.copy()
|
psi1Y_all = psi1Y_full.copy()
|
||||||
psi2_all = psi2_full.copy()
|
psi2_all = psi2_full.copy()
|
||||||
|
|
@ -146,6 +143,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
num_data, output_dim = Y.shape
|
num_data, output_dim = Y.shape
|
||||||
input_dim = Z.shape[0]
|
input_dim = Z.shape[0]
|
||||||
if self.mpi_comm != None:
|
if self.mpi_comm != None:
|
||||||
|
from mpi4py import MPI
|
||||||
num_data_all = np.array(num_data,dtype=np.int32)
|
num_data_all = np.array(num_data,dtype=np.int32)
|
||||||
self.mpi_comm.Allreduce([np.int32(num_data), MPI.INT], [num_data_all, MPI.INT])
|
self.mpi_comm.Allreduce([np.int32(num_data), MPI.INT], [num_data_all, MPI.INT])
|
||||||
num_data = num_data_all
|
num_data = num_data_all
|
||||||
|
|
@ -387,6 +385,7 @@ def update_gradients(model, mpi_comm=None):
|
||||||
|
|
||||||
# Gather the gradients from multiple MPI nodes
|
# Gather the gradients from multiple MPI nodes
|
||||||
if mpi_comm != None:
|
if mpi_comm != None:
|
||||||
|
from mpi4py import MPI
|
||||||
if het_noise:
|
if het_noise:
|
||||||
raise "het_noise not implemented!"
|
raise "het_noise not implemented!"
|
||||||
kern_grad_all = kern_grad.copy()
|
kern_grad_all = kern_grad.copy()
|
||||||
|
|
@ -409,6 +408,7 @@ def update_gradients(model, mpi_comm=None):
|
||||||
model.variational_prior.update_gradients_KL(X)
|
model.variational_prior.update_gradients_KL(X)
|
||||||
|
|
||||||
if mpi_comm != None:
|
if mpi_comm != None:
|
||||||
|
from mpi4py import MPI
|
||||||
KL_div_all = np.array(KL_div)
|
KL_div_all = np.array(KL_div)
|
||||||
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
|
||||||
KL_div = KL_div_all
|
KL_div = KL_div_all
|
||||||
|
|
@ -468,6 +468,7 @@ def update_gradients_sparsegp(model, mpi_comm=None):
|
||||||
|
|
||||||
# Gather the gradients from multiple MPI nodes
|
# Gather the gradients from multiple MPI nodes
|
||||||
if mpi_comm != None:
|
if mpi_comm != None:
|
||||||
|
from mpi4py import MPI
|
||||||
if het_noise:
|
if het_noise:
|
||||||
raise "het_noise not implemented!"
|
raise "het_noise not implemented!"
|
||||||
kern_grad_all = kern_grad.copy()
|
kern_grad_all = kern_grad.copy()
|
||||||
|
|
|
||||||
|
|
@ -148,6 +148,10 @@ def x_frame1D(X,plot_limits=None,resolution=None):
|
||||||
"""
|
"""
|
||||||
assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs"
|
assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs"
|
||||||
if plot_limits is None:
|
if plot_limits is None:
|
||||||
|
from ...core.parameterization.variational import VariationalPosterior
|
||||||
|
if isinstance(X, VariationalPosterior):
|
||||||
|
xmin,xmax = X.mean.min(0),X.mean.max(0)
|
||||||
|
else:
|
||||||
xmin,xmax = X.min(0),X.max(0)
|
xmin,xmax = X.min(0),X.max(0)
|
||||||
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
|
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
|
||||||
elif len(plot_limits)==2:
|
elif len(plot_limits)==2:
|
||||||
|
|
|
||||||
|
|
@ -2,26 +2,15 @@
|
||||||
The module of tools for parallelization (MPI)
|
The module of tools for parallelization (MPI)
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
try:
|
|
||||||
|
def get_id_within_node(comm=None):
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
def get_id_within_node(comm=MPI.COMM_WORLD):
|
if comm is None: comm = MPI.COMM_WORLD
|
||||||
rank = comm.rank
|
rank = comm.rank
|
||||||
nodename = MPI.Get_processor_name()
|
nodename = MPI.Get_processor_name()
|
||||||
nodelist = comm.allgather(nodename)
|
nodelist = comm.allgather(nodename)
|
||||||
return len([i for i in nodelist[:rank] if i==nodename])
|
return len([i for i in nodelist[:rank] if i==nodename])
|
||||||
|
|
||||||
numpy_to_MPI_typemap = {
|
|
||||||
np.dtype(np.float64) : MPI.DOUBLE,
|
|
||||||
np.dtype(np.float32) : MPI.FLOAT,
|
|
||||||
np.dtype(np.int) : MPI.INT,
|
|
||||||
np.dtype(np.int8) : MPI.CHAR,
|
|
||||||
np.dtype(np.uint8) : MPI.UNSIGNED_CHAR,
|
|
||||||
np.dtype(np.int32) : MPI.INT,
|
|
||||||
np.dtype(np.uint32) : MPI.UNSIGNED_INT,
|
|
||||||
}
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def divide_data(datanum, rank, size):
|
def divide_data(datanum, rank, size):
|
||||||
assert rank<size and datanum>0
|
assert rank<size and datanum>0
|
||||||
|
|
||||||
|
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -88,7 +88,7 @@ setup(name = 'GPy',
|
||||||
test_suite = 'GPy.testing',
|
test_suite = 'GPy.testing',
|
||||||
long_description=read_to_rst('README.md'),
|
long_description=read_to_rst('README.md'),
|
||||||
install_requires=['numpy>=1.7', 'scipy>=0.16', 'six'],
|
install_requires=['numpy>=1.7', 'scipy>=0.16', 'six'],
|
||||||
extras_require = {'docs':['matplotlib >=1.3','Sphinx','IPython']},
|
extras_require = {'docs':['matplotlib >=1.3','Sphinx','IPython'],'optional':['mpi4py']},
|
||||||
classifiers=['License :: OSI Approved :: BSD License',
|
classifiers=['License :: OSI Approved :: BSD License',
|
||||||
'Natural Language :: English',
|
'Natural Language :: English',
|
||||||
'Operating System :: MacOS :: MacOS X',
|
'Operating System :: MacOS :: MacOS X',
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue