mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
[mpi] enable checkgrad
This commit is contained in:
parent
e7177b6d37
commit
001db6b089
1 changed files with 6 additions and 2 deletions
|
|
@ -30,6 +30,7 @@ class SSGPLVM(SparseGP):
|
|||
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike-and-Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs):
|
||||
|
||||
self.mpi_comm = mpi_comm
|
||||
self.__IN_OPTIMIZATION__ = False
|
||||
|
||||
if X == None:
|
||||
from ..util.initialization import initialize_latent
|
||||
|
|
@ -143,12 +144,13 @@ class SSGPLVM(SparseGP):
|
|||
|
||||
def _set_params_transformed(self, p):
|
||||
if self.mpi_comm != None:
|
||||
if self.mpi_comm.rank==0:
|
||||
if self.__IN_OPTIMIZATION__ and self.mpi_comm.rank==0:
|
||||
self.mpi_comm.Bcast(np.int32(1),root=0)
|
||||
self.mpi_comm.Bcast(p, root=0)
|
||||
super(SSGPLVM, self)._set_params_transformed(p)
|
||||
|
||||
def optimize(self, optimizer=None, start=None, **kwargs):
|
||||
self.__IN_OPTIMIZATION__ = True
|
||||
if self.mpi_comm==None:
|
||||
super(SSGPLVM, self).optimize(optimizer,start,**kwargs)
|
||||
elif self.mpi_comm.rank==0:
|
||||
|
|
@ -164,4 +166,6 @@ class SSGPLVM(SparseGP):
|
|||
elif flag==-1:
|
||||
break
|
||||
else:
|
||||
self.__IN_OPTIMIZATION__ = False
|
||||
raise Exception("Unrecognizable flag for synchronization!")
|
||||
self.__IN_OPTIMIZATION__ = False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue