From c568bad4fbef7a0e91079af0c14bb808770a8be8 Mon Sep 17 00:00:00 2001
From: Zhenwen Dai
Date: Fri, 16 May 2014 10:40:45 +0100
Subject: [PATCH] bug fix for mpi SSGPLVM

---
 .../var_dtc_parallel.py                       |  4 ++++
 GPy/models/ss_gplvm.py                        | 22 +++++++++++++++++--
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py
index 4123e5b2..44babf25 100644
--- a/GPy/inference/latent_function_inference/var_dtc_parallel.py
+++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py
@@ -400,6 +400,10 @@ def update_gradients(model, mpi_comm=None):
         mpi_comm.Allreduce([np.float64(KL_div), MPI.DOUBLE], [KL_div_all, MPI.DOUBLE])
         KL_div = KL_div_all
         [mpi_comm.Allgatherv([pp.copy(), MPI.DOUBLE], [pa, (model.Y_list*pa.shape[-1], None), MPI.DOUBLE]) for pp,pa in zip(model.get_X_gradients(X),model.get_X_gradients(model.X))]
+        from ...models import SSGPLVM
+        if isinstance(model, SSGPLVM):
+            grad_pi = np.array(model.variational_prior.pi.gradient)
+            mpi_comm.Allreduce([grad_pi.copy(), MPI.DOUBLE], [model.variational_prior.pi.gradient, MPI.DOUBLE])
     model._log_marginal_likelihood -= KL_div

     # dL_dthetaL
diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py
index 76e2f0ef..27d5158f 100644
--- a/GPy/models/ss_gplvm.py
+++ b/GPy/models/ss_gplvm.py
@@ -45,9 +45,9 @@ class SSGPLVM(SparseGP):
             gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
             gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
-            gamma[gamma>=1.-1e-9] = 1e-9
+            gamma[gamma>1.-1e-9] = 1.-1e-9
             gamma[gamma<1e-9] = 1e-9
-            #gamma[:] = 0.5
+            gamma[:] = 0.5
             if group_spike:
                 gamma[:] = gamma.mean(axis=0)
@@ -142,3 +142,21 @@ class SSGPLVM(SparseGP):
         state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
         state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
         return super(SSGPLVM, self).__setstate__(state)
+
+    def _grads(self, x):
+        if self.mpi_comm != None:
+            self.mpi_comm.Bcast(x, root=0)
+        obj_grads = super(SSGPLVM, self)._grads(x)
+        return obj_grads
+
+    def _objective(self, x):
+        if self.mpi_comm != None:
+            self.mpi_comm.Bcast(x, root=0)
+        obj = super(SSGPLVM, self)._objective(x)
+        return obj
+
+    def _objective_grads(self, x):
+        if self.mpi_comm != None:
+            self.mpi_comm.Bcast(x, root=0)
+        obj_f, obj_grads = super(SSGPLVM, self)._objective_grads(x)
+        return obj_f, obj_grads