diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py
index 7b11085e..dc4f45d5 100644
--- a/GPy/inference/latent_function_inference/var_dtc_parallel.py
+++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py
@@ -78,7 +78,7 @@ class VarDTC_minibatch(LatentFunctionInference):
         num_inducing = Z.shape[0]
         num_data, output_dim = Y.shape
 
-        if self.batchsize == None or self.batchsize>num_data:
+        if self.batchsize == None:
             self.batchsize = num_data
 
         trYYT = self.get_trYYT(Y)
diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py
index 8581bf7c..08dc6d95 100644
--- a/GPy/models/ss_gplvm.py
+++ b/GPy/models/ss_gplvm.py
@@ -14,7 +14,6 @@ from ..core.parameterization.variational import SpikeAndSlabPrior, SpikeAndSlabP
 from ..inference.latent_function_inference.var_dtc_parallel import update_gradients, VarDTC_minibatch
 from ..inference.latent_function_inference.var_dtc_gpu import VarDTC_GPU
 
-
 class SSGPLVM(SparseGP):
     """
     Spike-and-Slab Gaussian Process Latent Variable Model
@@ -88,7 +87,8 @@ class SSGPLVM(SparseGP):
             self.X_local = self.X[Y_start:Y_end]
             self.Y_range = (Y_start, Y_end)
             self.Y_list = np.array(Y_list)
-            [mpi_comm.Bcast(p, root=0) for p in self.flattened_parameters]
+            print self.mpi_comm.rank, self.Y_range
+            mpi_comm.Bcast(self.param_array, root=0)
 
     def set_X_gradients(self, X, X_grad):
         """Set the gradients of the posterior distribution of X in its specific form."""
@@ -136,20 +136,32 @@ class SSGPLVM(SparseGP):
     def __setstate__(self, state):
         return super(SSGPLVM, self).__setstate__(state)
 
-    def _grads(self, x):
+    #=====================================================
+    # The MPI parallelization
+    # - can move to model at some point
+    #=====================================================
+
+    def _set_params_transformed(self, p):
         if self.mpi_comm != None:
-            self.mpi_comm.Bcast(x, root=0)
-        obj_grads = super(SSGPLVM, self)._grads(x)
-        return obj_grads
-
-    def _objective(self, x):
-        if self.mpi_comm != None:
-            self.mpi_comm.Bcast(x, root=0)
-        obj = super(SSGPLVM, self)._objective(x)
-        return obj
-
-    def _objective_grads(self, x):
-        if self.mpi_comm != None:
-            self.mpi_comm.Bcast(x, root=0)
-        obj_f, obj_grads = super(SSGPLVM, self)._objective_grads(x)
-        return obj_f, obj_grads
+            if self.mpi_comm.rank==0:
+                self.mpi_comm.Bcast(np.int32(1),root=0)
+            self.mpi_comm.Bcast(p, root=0)
+        super(SSGPLVM, self)._set_params_transformed(p)
+
+    def optimize(self, optimizer=None, start=None, **kwargs):
+        if self.mpi_comm==None:
+            super(SSGPLVM, self).optimize(optimizer,start,**kwargs)
+        elif self.mpi_comm.rank==0:
+            super(SSGPLVM, self).optimize(optimizer,start,**kwargs)
+            self.mpi_comm.Bcast(np.int32(-1),root=0)
+        elif self.mpi_comm.rank>0:
+            x = self._get_params_transformed().copy()
+            flag = np.empty(1,dtype=np.int32)
+            while True:
+                self.mpi_comm.Bcast(flag,root=0)
+                if flag==1:
+                    self._set_params_transformed(x)
+                elif flag==-1:
+                    break
+                else:
+                    raise Exception("Unrecognizable flag for synchronization!")
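
Note on the protocol introduced by the new `optimize`/`_set_params_transformed` pair: rank 0 runs the optimizer and, before every parameter update, broadcasts a flag of `1` followed by the parameter vector; all other ranks block in a receive loop inside `optimize` until rank 0 broadcasts `-1` to release them. Below is a minimal, self-contained sketch of this flag-based synchronization, assuming `mpi4py` and NumPy; the parameter payload, step count, and script name are illustrative, not GPy API. One detail the sketch changes deliberately: the patch broadcasts `np.int32(1)` (a 0-d scalar), whereas a length-1 array is the safer buffer to hand to `Bcast`.

```python
# Sketch of the flag-based MPI synchronization used in the patch above.
# Run with e.g.: mpiexec -n 4 python sync_sketch.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
params = np.zeros(8)  # stand-in for the model's transformed parameter vector

if comm.rank == 0:
    # Root plays the optimizer: every "step" pushes new parameters to workers.
    for step in range(5):
        params[:] = np.random.randn(params.size)           # pretend optimizer update
        comm.Bcast(np.array([1], dtype=np.int32), root=0)  # flag: parameters follow
        comm.Bcast(params, root=0)                         # ship the parameter vector
    comm.Bcast(np.array([-1], dtype=np.int32), root=0)     # flag: optimization finished
else:
    # Workers block on the flag broadcast and either sync parameters or exit.
    flag = np.empty(1, dtype=np.int32)
    while True:
        comm.Bcast(flag, root=0)
        if flag[0] == 1:
            comm.Bcast(params, root=0)  # receive the updated parameters in-place
        elif flag[0] == -1:
            break
        else:
            raise RuntimeError("Unrecognizable flag for synchronization!")
```

A caveat worth keeping in mind when reviewing: every rank must reach the two `Bcast` calls in the same order, so any path on rank 0 that raises between the flag broadcast and the parameter broadcast leaves the workers deadlocked in their loop.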