diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index ef7d03d0..457ede66 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -9,10 +9,6 @@ import numpy as np from . import LatentFunctionInference log_2_pi = np.log(2*np.pi) -try: - from mpi4py import MPI -except: - pass class VarDTC_minibatch(LatentFunctionInference): """ @@ -123,6 +119,7 @@ class VarDTC_minibatch(LatentFunctionInference): YRY_full = trYYT*beta if self.mpi_comm != None: + from mpi4py import MPI psi0_all = np.array(psi0_full) psi1Y_all = psi1Y_full.copy() psi2_all = psi2_full.copy() @@ -146,6 +143,7 @@ class VarDTC_minibatch(LatentFunctionInference): num_data, output_dim = Y.shape input_dim = Z.shape[0] if self.mpi_comm != None: + from mpi4py import MPI num_data_all = np.array(num_data,dtype=np.int32) self.mpi_comm.Allreduce([np.int32(num_data), MPI.INT], [num_data_all, MPI.INT]) num_data = num_data_all @@ -387,6 +385,7 @@ def update_gradients(model, mpi_comm=None): # Gather the gradients from multiple MPI nodes if mpi_comm != None: + from mpi4py import MPI if het_noise: raise "het_noise not implemented!" kern_grad_all = kern_grad.copy() @@ -409,6 +408,7 @@ def update_gradients(model, mpi_comm=None): model.variational_prior.update_gradients_KL(X) if mpi_comm != None: + from mpi4py import MPI KL_div_all = np.array(KL_div) mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE]) KL_div = KL_div_all @@ -468,6 +468,7 @@ def update_gradients_sparsegp(model, mpi_comm=None): # Gather the gradients from multiple MPI nodes if mpi_comm != None: + from mpi4py import MPI if het_noise: raise "het_noise not implemented!" kern_grad_all = kern_grad.copy() diff --git a/GPy/plotting/matplot_dep/base_plots.py b/GPy/plotting/matplot_dep/base_plots.py index 7ee0bd37..5e513ec2 100644 --- a/GPy/plotting/matplot_dep/base_plots.py +++ b/GPy/plotting/matplot_dep/base_plots.py @@ -148,7 +148,11 @@ def x_frame1D(X,plot_limits=None,resolution=None): """ assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs" if plot_limits is None: - xmin,xmax = X.min(0),X.max(0) + from ...core.parameterization.variational import VariationalPosterior + if isinstance(X, VariationalPosterior): + xmin,xmax = X.mean.min(0),X.mean.max(0) + else: + xmin,xmax = X.min(0),X.max(0) xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin) elif len(plot_limits)==2: xmin, xmax = plot_limits diff --git a/GPy/util/parallel.py b/GPy/util/parallel.py index a2211945..0c99287c 100644 --- a/GPy/util/parallel.py +++ b/GPy/util/parallel.py @@ -2,25 +2,14 @@ The module of tools for parallelization (MPI) """ import numpy as np -try: - from mpi4py import MPI - def get_id_within_node(comm=MPI.COMM_WORLD): - rank = comm.rank - nodename = MPI.Get_processor_name() - nodelist = comm.allgather(nodename) - return len([i for i in nodelist[:rank] if i==nodename]) - numpy_to_MPI_typemap = { - np.dtype(np.float64) : MPI.DOUBLE, - np.dtype(np.float32) : MPI.FLOAT, - np.dtype(np.int) : MPI.INT, - np.dtype(np.int8) : MPI.CHAR, - np.dtype(np.uint8) : MPI.UNSIGNED_CHAR, - np.dtype(np.int32) : MPI.INT, - np.dtype(np.uint32) : MPI.UNSIGNED_INT, - } -except: - pass +def get_id_within_node(comm=None): + from mpi4py import MPI + if comm is None: comm = MPI.COMM_WORLD + rank = comm.rank + nodename = MPI.Get_processor_name() + nodelist = comm.allgather(nodename) + return len([i for i in nodelist[:rank] if i==nodename]) def divide_data(datanum, rank, size): assert rank0 diff --git a/setup.py b/setup.py index 7edd7066..2ac587ac 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,7 @@ setup(name = 'GPy', test_suite = 'GPy.testing', long_description=read_to_rst('README.md'), install_requires=['numpy>=1.7', 'scipy>=0.16', 'six'], - extras_require = {'docs':['matplotlib >=1.3','Sphinx','IPython']}, + extras_require = {'docs':['matplotlib >=1.3','Sphinx','IPython'],'optional':['mpi4py']}, classifiers=['License :: OSI Approved :: BSD License', 'Natural Language :: English', 'Operating System :: MacOS :: MacOS X',