add test case for mpi

This commit is contained in:
Zhenwen Dai 2014-11-05 16:33:02 +00:00
parent e37b58670c
commit ef84339e46
2 changed files with 59 additions and 3 deletions

View file

@ -84,6 +84,8 @@ class BayesianGPLVM(SparseGP_MPI):
def parameters_changed(self): def parameters_changed(self):
super(BayesianGPLVM,self).parameters_changed() super(BayesianGPLVM,self).parameters_changed()
if isinstance(self.inference_method, VarDTC_minibatch):
return
kl_fctr = 1. kl_fctr = 1.
self._log_marginal_likelihood -= kl_fctr*self.variational_prior.KL_divergence(self.X) self._log_marginal_likelihood -= kl_fctr*self.variational_prior.KL_divergence(self.X)
@ -98,9 +100,6 @@ class BayesianGPLVM(SparseGP_MPI):
self.variational_prior.update_gradients_KL(self.X) self.variational_prior.update_gradients_KL(self.X)
if isinstance(self.inference_method, VarDTC_minibatch):
return
#super(BayesianGPLVM, self).parameters_changed() #super(BayesianGPLVM, self).parameters_changed()
#self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X) #self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X)

57
GPy/testing/mpi_tests.py Normal file
View file

@ -0,0 +1,57 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import unittest
import numpy as np
import GPy
try:
from mpi4py import MPI
import subprocess
class MPITests(unittest.TestCase):
def test_BayesianGPLVM_MPI(self):
code = """
import numpy as np
import GPy
from mpi4py import MPI
np.random.seed(123456)
comm = MPI.COMM_WORLD
N = 100
x = np.linspace(-6., 6., N)
y = np.sin(x) + np.random.randn(N) * 0.05
data = np.vstack([x,y])
infr = GPy.inference.latent_function_inference.VarDTC_minibatch(mpi_comm=comm)
m = GPy.models.BayesianGPLVM(data.T,1,mpi_comm=comm)
m.optimize(max_iters=10)
if comm.rank==0:
print float(m.objective_function())
m.inference_method.mpi_comm=None
m.mpi_comm=None
m._trigger_params_changed()
print float(m.objective_function())
"""
with open('mpi_test__.py','w') as f:
f.write(code)
f.close()
p = subprocess.Popen('mpirun -n 4 python mpi_test__.py',stdout=subprocess.PIPE,shell=True)
(stdout, stderr) = p.communicate()
L1 = float(stdout.splitlines()[-2])
L2 = float(stdout.splitlines()[-1])
self.assertAlmostEqual(L1, L2)
import os
os.remove('mpi_test__.py')
except:
pass
if __name__ == "__main__":
print "Running unit tests, please be (very) patient..."
try:
import mpi4py
unittest.main()
except:
pass