[mpi] enable checkgrad

This commit is contained in:
Zhenwen Dai 2014-05-18 10:05:35 +01:00
parent e7177b6d37
commit 001db6b089

View file

@ -30,6 +30,7 @@ class SSGPLVM(SparseGP):
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike-and-Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs): Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike-and-Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs):
self.mpi_comm = mpi_comm self.mpi_comm = mpi_comm
self.__IN_OPTIMIZATION__ = False
if X == None: if X == None:
from ..util.initialization import initialize_latent from ..util.initialization import initialize_latent
@ -143,12 +144,13 @@ class SSGPLVM(SparseGP):
def _set_params_transformed(self, p): def _set_params_transformed(self, p):
if self.mpi_comm != None: if self.mpi_comm != None:
if self.mpi_comm.rank==0: if self.__IN_OPTIMIZATION__ and self.mpi_comm.rank==0:
self.mpi_comm.Bcast(np.int32(1),root=0) self.mpi_comm.Bcast(np.int32(1),root=0)
self.mpi_comm.Bcast(p, root=0) self.mpi_comm.Bcast(p, root=0)
super(SSGPLVM, self)._set_params_transformed(p) super(SSGPLVM, self)._set_params_transformed(p)
def optimize(self, optimizer=None, start=None, **kwargs): def optimize(self, optimizer=None, start=None, **kwargs):
self.__IN_OPTIMIZATION__ = True
if self.mpi_comm==None: if self.mpi_comm==None:
super(SSGPLVM, self).optimize(optimizer,start,**kwargs) super(SSGPLVM, self).optimize(optimizer,start,**kwargs)
elif self.mpi_comm.rank==0: elif self.mpi_comm.rank==0:
@ -164,4 +166,6 @@ class SSGPLVM(SparseGP):
elif flag==-1: elif flag==-1:
break break
else: else:
self.__IN_OPTIMIZATION__ = False
raise Exception("Unrecognizable flag for synchronization!") raise Exception("Unrecognizable flag for synchronization!")
self.__IN_OPTIMIZATION__ = False