mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
enable the mpi capability of ssmrd
This commit is contained in:
parent
3971e68b9c
commit
eed647bec3
1 changed files with 37 additions and 0 deletions
|
|
@ -8,6 +8,7 @@ from .ss_gplvm import SSGPLVM
|
|||
from ..core.parameterization.variational import SpikeAndSlabPrior
|
||||
from ..util.misc import param_to_array
|
||||
from ..kern import RBF
|
||||
from numpy.linalg.linalg import LinAlgError
|
||||
|
||||
class SSMRD(Model):
|
||||
|
||||
|
|
@ -15,6 +16,7 @@ class SSMRD(Model):
|
|||
num_inducing=10, Zs=None, kernel=None, inference_methods=None, likelihoods=None,
|
||||
pi=0.5, name='ss_mrd', Ynames=None, mpi_comm=None):
|
||||
super(SSMRD, self).__init__(name)
|
||||
self.mpi_comm = mpi_comm
|
||||
|
||||
# initialize X for individual models
|
||||
X, X_variance, Gammas, fracs = self._init_X(Ylist, input_dim, X, X_variance, Gammas, initx)
|
||||
|
|
@ -87,6 +89,41 @@ class SSMRD(Model):
|
|||
Gammas.append(gamma)
|
||||
return X, X_variance, Gammas, fracs
|
||||
|
||||
@Model.optimizer_array.setter
|
||||
def optimizer_array(self, p):
|
||||
if self.mpi_comm != None:
|
||||
if self._IN_OPTIMIZATION_ and self.mpi_comm.rank==0:
|
||||
self.mpi_comm.Bcast(np.int32(1),root=0)
|
||||
self.mpi_comm.Bcast(p, root=0)
|
||||
Model.optimizer_array.fset(self,p)
|
||||
|
||||
def optimize(self, optimizer=None, start=None, **kwargs):
|
||||
self._IN_OPTIMIZATION_ = True
|
||||
if self.mpi_comm==None:
|
||||
super(SSMRD, self).optimize(optimizer,start,**kwargs)
|
||||
elif self.mpi_comm.rank==0:
|
||||
super(SSMRD, self).optimize(optimizer,start,**kwargs)
|
||||
self.mpi_comm.Bcast(np.int32(-1),root=0)
|
||||
elif self.mpi_comm.rank>0:
|
||||
x = self.optimizer_array.copy()
|
||||
flag = np.empty(1,dtype=np.int32)
|
||||
while True:
|
||||
self.mpi_comm.Bcast(flag,root=0)
|
||||
if flag==1:
|
||||
try:
|
||||
self.optimizer_array = x
|
||||
self._fail_count = 0
|
||||
except (LinAlgError, ZeroDivisionError, ValueError):
|
||||
if self._fail_count >= self._allowed_failures:
|
||||
raise
|
||||
self._fail_count += 1
|
||||
elif flag==-1:
|
||||
break
|
||||
else:
|
||||
self._IN_OPTIMIZATION_ = False
|
||||
raise Exception("Unrecognizable flag for synchronization!")
|
||||
self._IN_OPTIMIZATION_ = False
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue