mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
enable the mpi capability of ssmrd
This commit is contained in:
parent
3971e68b9c
commit
eed647bec3
1 changed files with 37 additions and 0 deletions
|
|
@ -8,6 +8,7 @@ from .ss_gplvm import SSGPLVM
|
||||||
from ..core.parameterization.variational import SpikeAndSlabPrior
|
from ..core.parameterization.variational import SpikeAndSlabPrior
|
||||||
from ..util.misc import param_to_array
|
from ..util.misc import param_to_array
|
||||||
from ..kern import RBF
|
from ..kern import RBF
|
||||||
|
from numpy.linalg.linalg import LinAlgError
|
||||||
|
|
||||||
class SSMRD(Model):
|
class SSMRD(Model):
|
||||||
|
|
||||||
|
|
@ -15,6 +16,7 @@ class SSMRD(Model):
|
||||||
num_inducing=10, Zs=None, kernel=None, inference_methods=None, likelihoods=None,
|
num_inducing=10, Zs=None, kernel=None, inference_methods=None, likelihoods=None,
|
||||||
pi=0.5, name='ss_mrd', Ynames=None, mpi_comm=None):
|
pi=0.5, name='ss_mrd', Ynames=None, mpi_comm=None):
|
||||||
super(SSMRD, self).__init__(name)
|
super(SSMRD, self).__init__(name)
|
||||||
|
self.mpi_comm = mpi_comm
|
||||||
|
|
||||||
# initialize X for individual models
|
# initialize X for individual models
|
||||||
X, X_variance, Gammas, fracs = self._init_X(Ylist, input_dim, X, X_variance, Gammas, initx)
|
X, X_variance, Gammas, fracs = self._init_X(Ylist, input_dim, X, X_variance, Gammas, initx)
|
||||||
|
|
@ -87,6 +89,41 @@ class SSMRD(Model):
|
||||||
Gammas.append(gamma)
|
Gammas.append(gamma)
|
||||||
return X, X_variance, Gammas, fracs
|
return X, X_variance, Gammas, fracs
|
||||||
|
|
||||||
|
@Model.optimizer_array.setter
|
||||||
|
def optimizer_array(self, p):
|
||||||
|
if self.mpi_comm != None:
|
||||||
|
if self._IN_OPTIMIZATION_ and self.mpi_comm.rank==0:
|
||||||
|
self.mpi_comm.Bcast(np.int32(1),root=0)
|
||||||
|
self.mpi_comm.Bcast(p, root=0)
|
||||||
|
Model.optimizer_array.fset(self,p)
|
||||||
|
|
||||||
|
def optimize(self, optimizer=None, start=None, **kwargs):
|
||||||
|
self._IN_OPTIMIZATION_ = True
|
||||||
|
if self.mpi_comm==None:
|
||||||
|
super(SSMRD, self).optimize(optimizer,start,**kwargs)
|
||||||
|
elif self.mpi_comm.rank==0:
|
||||||
|
super(SSMRD, self).optimize(optimizer,start,**kwargs)
|
||||||
|
self.mpi_comm.Bcast(np.int32(-1),root=0)
|
||||||
|
elif self.mpi_comm.rank>0:
|
||||||
|
x = self.optimizer_array.copy()
|
||||||
|
flag = np.empty(1,dtype=np.int32)
|
||||||
|
while True:
|
||||||
|
self.mpi_comm.Bcast(flag,root=0)
|
||||||
|
if flag==1:
|
||||||
|
try:
|
||||||
|
self.optimizer_array = x
|
||||||
|
self._fail_count = 0
|
||||||
|
except (LinAlgError, ZeroDivisionError, ValueError):
|
||||||
|
if self._fail_count >= self._allowed_failures:
|
||||||
|
raise
|
||||||
|
self._fail_count += 1
|
||||||
|
elif flag==-1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self._IN_OPTIMIZATION_ = False
|
||||||
|
raise Exception("Unrecognizable flag for synchronization!")
|
||||||
|
self._IN_OPTIMIZATION_ = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue