From 001db6b089c3a332fc9fed1078ee362461df0625 Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Sun, 18 May 2014 10:05:35 +0100 Subject: [PATCH] [mpi] enable checkgrad --- GPy/models/ss_gplvm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py index 08dc6d95..b0fe48b1 100644 --- a/GPy/models/ss_gplvm.py +++ b/GPy/models/ss_gplvm.py @@ -30,6 +30,7 @@ class SSGPLVM(SparseGP): Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike-and-Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs): self.mpi_comm = mpi_comm + self.__IN_OPTIMIZATION__ = False if X == None: from ..util.initialization import initialize_latent @@ -143,12 +144,13 @@ class SSGPLVM(SparseGP): def _set_params_transformed(self, p): if self.mpi_comm != None: - if self.mpi_comm.rank==0: + if self.__IN_OPTIMIZATION__ and self.mpi_comm.rank==0: self.mpi_comm.Bcast(np.int32(1),root=0) self.mpi_comm.Bcast(p, root=0) super(SSGPLVM, self)._set_params_transformed(p) - def optimize(self, optimizer=None, start=None, **kwargs): + def optimize(self, optimizer=None, start=None, **kwargs): + self.__IN_OPTIMIZATION__ = True if self.mpi_comm==None: super(SSGPLVM, self).optimize(optimizer,start,**kwargs) elif self.mpi_comm.rank==0: @@ -164,4 +166,6 @@ class SSGPLVM(SparseGP): elif flag==-1: break else: + self.__IN_OPTIMIZATION__ = False raise Exception("Unrecognizable flag for synchronization!") + self.__IN_OPTIMIZATION__ = False