Add MPI support for SparseGPRegression

This commit is contained in:
Zhenwen Dai 2014-11-05 17:50:23 +00:00
parent 1dbe3e34b0
commit f5377aa441
4 changed files with 46 additions and 7 deletions

View file

@@ -47,7 +47,7 @@ class SparseGP_MPI(SparseGP):
if variational_prior is not None: if variational_prior is not None:
self.link_parameter(variational_prior) self.link_parameter(variational_prior)
self.mpi_comm = mpi_comm self.mpi_comm = mpi_comm
# Manage the data (Y) division # Manage the data (Y) division
if mpi_comm != None: if mpi_comm != None:
@@ -60,7 +60,6 @@ class SparseGP_MPI(SparseGP):
mpi_comm.Bcast(self.param_array, root=0) mpi_comm.Bcast(self.param_array, root=0)
self.update_model(True) self.update_model(True)
def __getstate__(self): def __getstate__(self):
dc = super(SparseGP_MPI, self).__getstate__() dc = super(SparseGP_MPI, self).__getstate__()
dc['mpi_comm'] = None dc['mpi_comm'] = None

View file

@@ -25,8 +25,6 @@ class BayesianGPLVM(SparseGP_MPI):
Z=None, kernel=None, inference_method=None, likelihood=None, Z=None, kernel=None, inference_method=None, likelihood=None,
name='bayesian gplvm', mpi_comm=None, normalizer=None, name='bayesian gplvm', mpi_comm=None, normalizer=None,
missing_data=False, stochastic=False, batchsize=1): missing_data=False, stochastic=False, batchsize=1):
self.mpi_comm = mpi_comm
self.__IN_OPTIMIZATION__ = False
self.logger = logging.getLogger(self.__class__.__name__) self.logger = logging.getLogger(self.__class__.__name__)
if X is None: if X is None:

View file

@@ -9,6 +9,7 @@ from .. import likelihoods
from .. import kern from .. import kern
from ..inference.latent_function_inference import VarDTC from ..inference.latent_function_inference import VarDTC
from ..core.parameterization.variational import NormalPosterior from ..core.parameterization.variational import NormalPosterior
from GPy.inference.latent_function_inference.var_dtc_parallel import VarDTC_minibatch
class SparseGPRegression(SparseGP_MPI): class SparseGPRegression(SparseGP_MPI):
""" """
@@ -30,7 +31,7 @@ class SparseGPRegression(SparseGP_MPI):
""" """
def __init__(self, X, Y, kernel=None, Z=None, num_inducing=10, X_variance=None, normalizer=None): def __init__(self, X, Y, kernel=None, Z=None, num_inducing=10, X_variance=None, normalizer=None, mpi_comm=None):
num_data, input_dim = X.shape num_data, input_dim = X.shape
# kern defaults to rbf (plus white for stability) # kern defaults to rbf (plus white for stability)
@@ -48,8 +49,14 @@ class SparseGPRegression(SparseGP_MPI):
if not (X_variance is None): if not (X_variance is None):
X = NormalPosterior(X,X_variance) X = NormalPosterior(X,X_variance)
if mpi_comm is not None:
from ..inference.latent_function_inference.var_dtc_parallel import VarDTC_minibatch
infr = VarDTC_minibatch(mpi_comm=mpi_comm)
else:
infr = VarDTC()
SparseGP_MPI.__init__(self, X, Y, Z, kernel, likelihood, inference_method=VarDTC(), normalizer=normalizer) SparseGP_MPI.__init__(self, X, Y, Z, kernel, likelihood, inference_method=infr, normalizer=normalizer, mpi_comm=mpi_comm)
def parameters_changed(self): def parameters_changed(self):
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients_sparsegp,VarDTC_minibatch from ..inference.latent_function_inference.var_dtc_parallel import update_gradients_sparsegp,VarDTC_minibatch

View file

@@ -21,6 +21,7 @@ comm = MPI.COMM_WORLD
N = 100 N = 100
x = np.linspace(-6., 6., N) x = np.linspace(-6., 6., N)
y = np.sin(x) + np.random.randn(N) * 0.05 y = np.sin(x) + np.random.randn(N) * 0.05
comm.Bcast(y)
data = np.vstack([x,y]) data = np.vstack([x,y])
infr = GPy.inference.latent_function_inference.VarDTC_minibatch(mpi_comm=comm) infr = GPy.inference.latent_function_inference.VarDTC_minibatch(mpi_comm=comm)
m = GPy.models.BayesianGPLVM(data.T,1,mpi_comm=comm) m = GPy.models.BayesianGPLVM(data.T,1,mpi_comm=comm)
@@ -39,9 +40,43 @@ if comm.rank==0:
(stdout, stderr) = p.communicate() (stdout, stderr) = p.communicate()
L1 = float(stdout.splitlines()[-2]) L1 = float(stdout.splitlines()[-2])
L2 = float(stdout.splitlines()[-1]) L2 = float(stdout.splitlines()[-1])
self.assertAlmostEqual(L1, L2) self.assertTrue(np.allclose(L1,L2))
import os import os
os.remove('mpi_test__.py') os.remove('mpi_test__.py')
def test_SparseGPRegression_MPI(self):
    """Check SparseGPRegression gives the same objective with and without MPI.

    Writes a small driver script to disk, runs it under ``mpirun -n 4``,
    and reads two objective values printed by rank 0: one computed with
    the parallel VarDTC_minibatch inference, and one recomputed after
    detaching the communicator (serial path). The two must agree.

    Raises:
        AssertionError: if the parallel and serial objectives differ.
    """
    # NOTE(review): the driver uses Python-2 `print` statements and requires
    # mpirun + mpi4py + GPy on the target interpreter — this test only runs
    # in a matching environment. The string is executed, not imported, so it
    # is kept verbatim as program data.
    code = """
import numpy as np
import GPy
from mpi4py import MPI
np.random.seed(123456)
comm = MPI.COMM_WORLD
N = 100
x = np.linspace(-6., 6., N)
y = np.sin(x) + np.random.randn(N) * 0.05
comm.Bcast(y)
data = np.vstack([x,y])
#infr = GPy.inference.latent_function_inference.VarDTC_minibatch(mpi_comm=comm)
m = GPy.models.SparseGPRegression(data[:1].T,data[1:2].T,mpi_comm=comm)
m.optimize(max_iters=10)
if comm.rank==0:
    print float(m.objective_function())
    m.inference_method.mpi_comm=None
    m.mpi_comm=None
    m._trigger_params_changed()
    print float(m.objective_function())
"""
    # `with` already closes the file; the explicit f.close() in the original
    # was redundant and has been dropped.
    with open('mpi_test__.py', 'w') as f:
        f.write(code)
    try:
        p = subprocess.Popen('mpirun -n 4 python mpi_test__.py',
                             stdout=subprocess.PIPE, shell=True)
        (stdout, stderr) = p.communicate()
        # Rank 0 prints exactly two floats; they are the last two lines.
        lines = stdout.splitlines()
        L1 = float(lines[-2])
        L2 = float(lines[-1])
        self.assertTrue(np.allclose(L1, L2))
    finally:
        # Clean up the temp script even when the assertion (or parsing) fails,
        # so repeated runs don't leave mpi_test__.py behind.
        import os
        os.remove('mpi_test__.py')
except: except:
pass pass