From eed647bec3ae1671fa3ec78c76ace49491381228 Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Fri, 17 Oct 2014 15:53:33 +0100 Subject: [PATCH] enable the mpi capability of ssmrd --- GPy/models/ss_mrd.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/GPy/models/ss_mrd.py b/GPy/models/ss_mrd.py index 497215ef..adaf6260 100644 --- a/GPy/models/ss_mrd.py +++ b/GPy/models/ss_mrd.py @@ -8,6 +8,7 @@ from .ss_gplvm import SSGPLVM from ..core.parameterization.variational import SpikeAndSlabPrior from ..util.misc import param_to_array from ..kern import RBF +from numpy.linalg.linalg import LinAlgError class SSMRD(Model): @@ -15,6 +16,7 @@ class SSMRD(Model): num_inducing=10, Zs=None, kernel=None, inference_methods=None, likelihoods=None, pi=0.5, name='ss_mrd', Ynames=None, mpi_comm=None): super(SSMRD, self).__init__(name) + self.mpi_comm = mpi_comm # initialize X for individual models X, X_variance, Gammas, fracs = self._init_X(Ylist, input_dim, X, X_variance, Gammas, initx) @@ -87,6 +89,41 @@ class SSMRD(Model): Gammas.append(gamma) return X, X_variance, Gammas, fracs + @Model.optimizer_array.setter + def optimizer_array(self, p): + if self.mpi_comm != None: + if self._IN_OPTIMIZATION_ and self.mpi_comm.rank==0: + self.mpi_comm.Bcast(np.int32(1),root=0) + self.mpi_comm.Bcast(p, root=0) + Model.optimizer_array.fset(self,p) + + def optimize(self, optimizer=None, start=None, **kwargs): + self._IN_OPTIMIZATION_ = True + if self.mpi_comm==None: + super(SSMRD, self).optimize(optimizer,start,**kwargs) + elif self.mpi_comm.rank==0: + super(SSMRD, self).optimize(optimizer,start,**kwargs) + self.mpi_comm.Bcast(np.int32(-1),root=0) + elif self.mpi_comm.rank>0: + x = self.optimizer_array.copy() + flag = np.empty(1,dtype=np.int32) + while True: + self.mpi_comm.Bcast(flag,root=0) + if flag==1: + try: + self.optimizer_array = x + self._fail_count = 0 + except (LinAlgError, ZeroDivisionError, ValueError): + if self._fail_count >= self._allowed_failures: + raise + self._fail_count += 1 + elif flag==-1: + break + else: + self._IN_OPTIMIZATION_ = False + raise Exception("Unrecognizable flag for synchronization!") + self._IN_OPTIMIZATION_ = False +