diff --git a/GPy/core/sparse_gp_mpi.py b/GPy/core/sparse_gp_mpi.py index e9bd770d..e8779f51 100644 --- a/GPy/core/sparse_gp_mpi.py +++ b/GPy/core/sparse_gp_mpi.py @@ -47,7 +47,7 @@ class SparseGP_MPI(SparseGP): if variational_prior is not None: self.link_parameter(variational_prior) - + self.mpi_comm = mpi_comm # Manage the data (Y) division if mpi_comm != None: @@ -60,7 +60,6 @@ class SparseGP_MPI(SparseGP): mpi_comm.Bcast(self.param_array, root=0) self.update_model(True) - def __getstate__(self): dc = super(SparseGP_MPI, self).__getstate__() dc['mpi_comm'] = None diff --git a/GPy/models/bayesian_gplvm.py b/GPy/models/bayesian_gplvm.py index d363fb7a..fca97e96 100644 --- a/GPy/models/bayesian_gplvm.py +++ b/GPy/models/bayesian_gplvm.py @@ -25,8 +25,6 @@ class BayesianGPLVM(SparseGP_MPI): Z=None, kernel=None, inference_method=None, likelihood=None, name='bayesian gplvm', mpi_comm=None, normalizer=None, missing_data=False, stochastic=False, batchsize=1): - self.mpi_comm = mpi_comm - self.__IN_OPTIMIZATION__ = False self.logger = logging.getLogger(self.__class__.__name__) if X is None: diff --git a/GPy/models/sparse_gp_regression.py b/GPy/models/sparse_gp_regression.py index 5a56bb7d..49c3914c 100644 --- a/GPy/models/sparse_gp_regression.py +++ b/GPy/models/sparse_gp_regression.py @@ -9,6 +9,7 @@ from .. import likelihoods from .. import kern from ..inference.latent_function_inference import VarDTC from ..core.parameterization.variational import NormalPosterior +from GPy.inference.latent_function_inference.var_dtc_parallel import VarDTC_minibatch class SparseGPRegression(SparseGP_MPI): """ @@ -30,7 +31,7 @@ class SparseGPRegression(SparseGP_MPI): """ - def __init__(self, X, Y, kernel=None, Z=None, num_inducing=10, X_variance=None, normalizer=None): + def __init__(self, X, Y, kernel=None, Z=None, num_inducing=10, X_variance=None, normalizer=None, mpi_comm=None): num_data, input_dim = X.shape # kern defaults to rbf (plus white for stability) @@ -48,8 +49,14 @@ class SparseGPRegression(SparseGP_MPI): if not (X_variance is None): X = NormalPosterior(X,X_variance) + + if mpi_comm is not None: + from ..inference.latent_function_inference.var_dtc_parallel import VarDTC_minibatch + infr = VarDTC_minibatch(mpi_comm=mpi_comm) + else: + infr = VarDTC() - SparseGP_MPI.__init__(self, X, Y, Z, kernel, likelihood, inference_method=VarDTC(), normalizer=normalizer) + SparseGP_MPI.__init__(self, X, Y, Z, kernel, likelihood, inference_method=infr, normalizer=normalizer, mpi_comm=mpi_comm) def parameters_changed(self): from ..inference.latent_function_inference.var_dtc_parallel import update_gradients_sparsegp,VarDTC_minibatch diff --git a/GPy/testing/mpi_tests.py b/GPy/testing/mpi_tests.py index 4848a6ec..45777eb1 100644 --- a/GPy/testing/mpi_tests.py +++ b/GPy/testing/mpi_tests.py @@ -21,6 +21,7 @@ comm = MPI.COMM_WORLD N = 100 x = np.linspace(-6., 6., N) y = np.sin(x) + np.random.randn(N) * 0.05 +comm.Bcast(y) data = np.vstack([x,y]) infr = GPy.inference.latent_function_inference.VarDTC_minibatch(mpi_comm=comm) m = GPy.models.BayesianGPLVM(data.T,1,mpi_comm=comm) @@ -39,9 +40,43 @@ if comm.rank==0: (stdout, stderr) = p.communicate() L1 = float(stdout.splitlines()[-2]) L2 = float(stdout.splitlines()[-1]) - self.assertAlmostEqual(L1, L2) + self.assertTrue(np.allclose(L1,L2)) import os os.remove('mpi_test__.py') + + def test_SparseGPRegression_MPI(self): + code = """ +import numpy as np +import GPy +from mpi4py import MPI +np.random.seed(123456) +comm = MPI.COMM_WORLD +N = 100 +x = np.linspace(-6., 6., N) +y = np.sin(x) + np.random.randn(N) * 0.05 +comm.Bcast(y) +data = np.vstack([x,y]) +#infr = GPy.inference.latent_function_inference.VarDTC_minibatch(mpi_comm=comm) +m = GPy.models.SparseGPRegression(data[:1].T,data[1:2].T,mpi_comm=comm) +m.optimize(max_iters=10) +if comm.rank==0: + print float(m.objective_function()) + m.inference_method.mpi_comm=None + m.mpi_comm=None + m._trigger_params_changed() + print float(m.objective_function()) + """ + with open('mpi_test__.py','w') as f: + f.write(code) + f.close() + p = subprocess.Popen('mpirun -n 4 python mpi_test__.py',stdout=subprocess.PIPE,shell=True) + (stdout, stderr) = p.communicate() + L1 = float(stdout.splitlines()[-2]) + L2 = float(stdout.splitlines()[-1]) + self.assertTrue(np.allclose(L1,L2)) + import os + os.remove('mpi_test__.py') + except: pass