mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 19:12:40 +02:00
[mpi] fix the bug of mpi
This commit is contained in:
parent
e6d07ad5ac
commit
e7177b6d37
2 changed files with 31 additions and 19 deletions
|
|
@ -78,7 +78,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
num_inducing = Z.shape[0]
|
num_inducing = Z.shape[0]
|
||||||
num_data, output_dim = Y.shape
|
num_data, output_dim = Y.shape
|
||||||
|
|
||||||
if self.batchsize == None or self.batchsize>num_data:
|
if self.batchsize == None:
|
||||||
self.batchsize = num_data
|
self.batchsize = num_data
|
||||||
|
|
||||||
trYYT = self.get_trYYT(Y)
|
trYYT = self.get_trYYT(Y)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ from ..core.parameterization.variational import SpikeAndSlabPrior, SpikeAndSlabP
|
||||||
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
|
from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
|
||||||
from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
|
from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
|
||||||
|
|
||||||
|
|
||||||
class SSGPLVM(SparseGP):
|
class SSGPLVM(SparseGP):
|
||||||
"""
|
"""
|
||||||
Spike-and-Slab Gaussian Process Latent Variable Model
|
Spike-and-Slab Gaussian Process Latent Variable Model
|
||||||
|
|
@ -88,7 +87,8 @@ class SSGPLVM(SparseGP):
|
||||||
self.X_local = self.X[Y_start:Y_end]
|
self.X_local = self.X[Y_start:Y_end]
|
||||||
self.Y_range = (Y_start, Y_end)
|
self.Y_range = (Y_start, Y_end)
|
||||||
self.Y_list = np.array(Y_list)
|
self.Y_list = np.array(Y_list)
|
||||||
[mpi_comm.Bcast(p, root=0) for p in self.flattened_parameters]
|
print self.mpi_comm.rank, self.Y_range
|
||||||
|
mpi_comm.Bcast(self.param_array, root=0)
|
||||||
|
|
||||||
def set_X_gradients(self, X, X_grad):
|
def set_X_gradients(self, X, X_grad):
|
||||||
"""Set the gradients of the posterior distribution of X in its specific form."""
|
"""Set the gradients of the posterior distribution of X in its specific form."""
|
||||||
|
|
@ -136,20 +136,32 @@ class SSGPLVM(SparseGP):
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
return super(SSGPLVM, self).__setstate__(state)
|
return super(SSGPLVM, self).__setstate__(state)
|
||||||
|
|
||||||
def _grads(self, x):
|
#=====================================================
|
||||||
|
# The MPI parallelization
|
||||||
|
# - can move to model at some point
|
||||||
|
#=====================================================
|
||||||
|
|
||||||
|
def _set_params_transformed(self, p):
|
||||||
if self.mpi_comm != None:
|
if self.mpi_comm != None:
|
||||||
self.mpi_comm.Bcast(x, root=0)
|
if self.mpi_comm.rank==0:
|
||||||
obj_grads = super(SSGPLVM, self)._grads(x)
|
self.mpi_comm.Bcast(np.int32(1),root=0)
|
||||||
return obj_grads
|
self.mpi_comm.Bcast(p, root=0)
|
||||||
|
super(SSGPLVM, self)._set_params_transformed(p)
|
||||||
def _objective(self, x):
|
|
||||||
if self.mpi_comm != None:
|
def optimize(self, optimizer=None, start=None, **kwargs):
|
||||||
self.mpi_comm.Bcast(x, root=0)
|
if self.mpi_comm==None:
|
||||||
obj = super(SSGPLVM, self)._objective(x)
|
super(SSGPLVM, self).optimize(optimizer,start,**kwargs)
|
||||||
return obj
|
elif self.mpi_comm.rank==0:
|
||||||
|
super(SSGPLVM, self).optimize(optimizer,start,**kwargs)
|
||||||
def _objective_grads(self, x):
|
self.mpi_comm.Bcast(np.int32(-1),root=0)
|
||||||
if self.mpi_comm != None:
|
elif self.mpi_comm.rank>0:
|
||||||
self.mpi_comm.Bcast(x, root=0)
|
x = self._get_params_transformed().copy()
|
||||||
obj_f, obj_grads = super(SSGPLVM, self)._objective_grads(x)
|
flag = np.empty(1,dtype=np.int32)
|
||||||
return obj_f, obj_grads
|
while True:
|
||||||
|
self.mpi_comm.Bcast(flag,root=0)
|
||||||
|
if flag==1:
|
||||||
|
self._set_params_transformed(x)
|
||||||
|
elif flag==-1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise Exception("Unrecognizable flag for synchronization!")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue