mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
[mpi] enable checkgrad
This commit is contained in:
parent
e7177b6d37
commit
001db6b089
1 changed files with 6 additions and 2 deletions
|
|
@ -30,6 +30,7 @@ class SSGPLVM(SparseGP):
|
||||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike-and-Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs):
|
Z=None, kernel=None, inference_method=None, likelihood=None, name='Spike-and-Slab GPLVM', group_spike=False, mpi_comm=None, **kwargs):
|
||||||
|
|
||||||
self.mpi_comm = mpi_comm
|
self.mpi_comm = mpi_comm
|
||||||
|
self.__IN_OPTIMIZATION__ = False
|
||||||
|
|
||||||
if X == None:
|
if X == None:
|
||||||
from ..util.initialization import initialize_latent
|
from ..util.initialization import initialize_latent
|
||||||
|
|
@ -143,12 +144,13 @@ class SSGPLVM(SparseGP):
|
||||||
|
|
||||||
def _set_params_transformed(self, p):
|
def _set_params_transformed(self, p):
|
||||||
if self.mpi_comm != None:
|
if self.mpi_comm != None:
|
||||||
if self.mpi_comm.rank==0:
|
if self.__IN_OPTIMIZATION__ and self.mpi_comm.rank==0:
|
||||||
self.mpi_comm.Bcast(np.int32(1),root=0)
|
self.mpi_comm.Bcast(np.int32(1),root=0)
|
||||||
self.mpi_comm.Bcast(p, root=0)
|
self.mpi_comm.Bcast(p, root=0)
|
||||||
super(SSGPLVM, self)._set_params_transformed(p)
|
super(SSGPLVM, self)._set_params_transformed(p)
|
||||||
|
|
||||||
def optimize(self, optimizer=None, start=None, **kwargs):
|
def optimize(self, optimizer=None, start=None, **kwargs):
|
||||||
|
self.__IN_OPTIMIZATION__ = True
|
||||||
if self.mpi_comm==None:
|
if self.mpi_comm==None:
|
||||||
super(SSGPLVM, self).optimize(optimizer,start,**kwargs)
|
super(SSGPLVM, self).optimize(optimizer,start,**kwargs)
|
||||||
elif self.mpi_comm.rank==0:
|
elif self.mpi_comm.rank==0:
|
||||||
|
|
@ -164,4 +166,6 @@ class SSGPLVM(SparseGP):
|
||||||
elif flag==-1:
|
elif flag==-1:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
self.__IN_OPTIMIZATION__ = False
|
||||||
raise Exception("Unrecognizable flag for synchronization!")
|
raise Exception("Unrecognizable flag for synchronization!")
|
||||||
|
self.__IN_OPTIMIZATION__ = False
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue