diff --git a/GPy/models/bayesian_gplvm.py b/GPy/models/bayesian_gplvm.py index 0838b684..d363fb7a 100644 --- a/GPy/models/bayesian_gplvm.py +++ b/GPy/models/bayesian_gplvm.py @@ -84,6 +84,8 @@ class BayesianGPLVM(SparseGP_MPI): def parameters_changed(self): super(BayesianGPLVM,self).parameters_changed() + if isinstance(self.inference_method, VarDTC_minibatch): + return kl_fctr = 1. self._log_marginal_likelihood -= kl_fctr*self.variational_prior.KL_divergence(self.X) @@ -98,9 +100,6 @@ class BayesianGPLVM(SparseGP_MPI): self.variational_prior.update_gradients_KL(self.X) - if isinstance(self.inference_method, VarDTC_minibatch): - return - #super(BayesianGPLVM, self).parameters_changed() #self._log_marginal_likelihood -= self.variational_prior.KL_divergence(self.X) diff --git a/GPy/testing/mpi_tests.py b/GPy/testing/mpi_tests.py new file mode 100644 index 00000000..4848a6ec --- /dev/null +++ b/GPy/testing/mpi_tests.py @@ -0,0 +1,57 @@ +# Copyright (c) 2012, GPy authors (see AUTHORS.txt). +# Licensed under the BSD 3-clause license (see LICENSE.txt) + +import unittest +import numpy as np +import GPy + +try: + from mpi4py import MPI + import subprocess + + class MPITests(unittest.TestCase): + + def test_BayesianGPLVM_MPI(self): + code = """ +import numpy as np +import GPy +from mpi4py import MPI +np.random.seed(123456) +comm = MPI.COMM_WORLD +N = 100 +x = np.linspace(-6., 6., N) +y = np.sin(x) + np.random.randn(N) * 0.05 +data = np.vstack([x,y]) +infr = GPy.inference.latent_function_inference.VarDTC_minibatch(mpi_comm=comm) +m = GPy.models.BayesianGPLVM(data.T,1,mpi_comm=comm) +m.optimize(max_iters=10) +if comm.rank==0: + print float(m.objective_function()) + m.inference_method.mpi_comm=None + m.mpi_comm=None + m._trigger_params_changed() + print float(m.objective_function()) + """ + with open('mpi_test__.py','w') as f: + f.write(code) + f.close() + p = subprocess.Popen('mpirun -n 4 python mpi_test__.py',stdout=subprocess.PIPE,shell=True) + (stdout, stderr) = p.communicate() + L1 = float(stdout.splitlines()[-2]) + L2 = float(stdout.splitlines()[-1]) + self.assertAlmostEqual(L1, L2) + import os + os.remove('mpi_test__.py') + +except: + pass + + + +if __name__ == "__main__": + print "Running unit tests, please be (very) patient..." + try: + import mpi4py + unittest.main() + except: + pass \ No newline at end of file