mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-11 15:15:15 +02:00
BayersianGPLVM mpi support
This commit is contained in:
parent
52c0be1848
commit
04ab93a961
1 changed files with 63 additions and 2 deletions
|
|
@ -24,7 +24,9 @@ class BayesianGPLVM(SparseGP):
|
|||
|
||||
"""
|
||||
def __init__(self, Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
|
||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='bayesian gplvm', **kwargs):
|
||||
Z=None, kernel=None, inference_method=None, likelihood=None, name='bayesian gplvm', mpi_comm=None, **kwargs):
|
||||
self.mpi_comm = mpi_comm
|
||||
self.__IN_OPTIMIZATION__ = False
|
||||
if X == None:
|
||||
from ..util.initialization import initialize_latent
|
||||
X, fracs = initialize_latent(init, input_dim, Y)
|
||||
|
|
@ -55,6 +57,8 @@ class BayesianGPLVM(SparseGP):
|
|||
if np.any(np.isnan(Y)):
|
||||
from ..inference.latent_function_inference.var_dtc import VarDTCMissingData
|
||||
inference_method = VarDTCMissingData()
|
||||
elif mpi_comm != None:
|
||||
inference_method = VarDTC_minibatch(mpi_comm=mpi_comm)
|
||||
else:
|
||||
from ..inference.latent_function_inference.var_dtc import VarDTC
|
||||
inference_method = VarDTC()
|
||||
|
|
@ -62,13 +66,26 @@ class BayesianGPLVM(SparseGP):
|
|||
SparseGP.__init__(self, X, Y, Z, kernel, likelihood, inference_method, name, **kwargs)
|
||||
self.add_parameter(self.X, index=0)
|
||||
|
||||
if mpi_comm != None:
|
||||
from ..util.mpi import divide_data
|
||||
Y_start, Y_end, Y_list = divide_data(Y.shape[0], mpi_comm)
|
||||
self.Y_local = self.Y[Y_start:Y_end]
|
||||
self.X_local = self.X[Y_start:Y_end]
|
||||
self.Y_range = (Y_start, Y_end)
|
||||
self.Y_list = np.array(Y_list)
|
||||
mpi_comm.Bcast(self.param_array, root=0)
|
||||
|
||||
def set_X_gradients(self, X, X_grad):
|
||||
"""Set the gradients of the posterior distribution of X in its specific form."""
|
||||
X.mean.gradient, X.variance.gradient = X_grad
|
||||
|
||||
def get_X_gradients(self, X):
|
||||
"""Get the gradients of the posterior distribution of X in its specific form."""
|
||||
return X.mean.gradient, X.variance.gradient
|
||||
|
||||
def parameters_changed(self):
|
||||
if isinstance(self.inference_method, VarDTC_GPU) or isinstance(self.inference_method, VarDTC_minibatch):
|
||||
update_gradients(self)
|
||||
update_gradients(self, mpi_comm=self.mpi_comm)
|
||||
return
|
||||
|
||||
super(BayesianGPLVM, self).parameters_changed()
|
||||
|
|
@ -160,6 +177,50 @@ class BayesianGPLVM(SparseGP):
|
|||
|
||||
return dim_reduction_plots.plot_steepest_gradient_map(self,*args,**kwargs)
|
||||
|
||||
def __getstate__(self):
|
||||
dc = super(BayesianGPLVM, self).__getstate__()
|
||||
dc['mpi_comm'] = None
|
||||
if self.mpi_comm != None:
|
||||
del dc['Y_local']
|
||||
del dc['X_local']
|
||||
del dc['Y_range']
|
||||
return dc
|
||||
|
||||
def __setstate__(self, state):
|
||||
return super(BayesianGPLVM, self).__setstate__(state)
|
||||
|
||||
#=====================================================
|
||||
# The MPI parallelization
|
||||
# - can move to model at some point
|
||||
#=====================================================
|
||||
|
||||
def _set_params_transformed(self, p):
|
||||
if self.mpi_comm != None:
|
||||
if self.__IN_OPTIMIZATION__ and self.mpi_comm.rank==0:
|
||||
self.mpi_comm.Bcast(np.int32(1),root=0)
|
||||
self.mpi_comm.Bcast(p, root=0)
|
||||
super(BayesianGPLVM, self)._set_params_transformed(p)
|
||||
|
||||
def optimize(self, optimizer=None, start=None, **kwargs):
|
||||
self.__IN_OPTIMIZATION__ = True
|
||||
if self.mpi_comm==None:
|
||||
super(BayesianGPLVM, self).optimize(optimizer,start,**kwargs)
|
||||
elif self.mpi_comm.rank==0:
|
||||
super(BayesianGPLVM, self).optimize(optimizer,start,**kwargs)
|
||||
self.mpi_comm.Bcast(np.int32(-1),root=0)
|
||||
elif self.mpi_comm.rank>0:
|
||||
x = self._get_params_transformed().copy()
|
||||
flag = np.empty(1,dtype=np.int32)
|
||||
while True:
|
||||
self.mpi_comm.Bcast(flag,root=0)
|
||||
if flag==1:
|
||||
self._set_params_transformed(x)
|
||||
elif flag==-1:
|
||||
break
|
||||
else:
|
||||
self.__IN_OPTIMIZATION__ = False
|
||||
raise Exception("Unrecognizable flag for synchronization!")
|
||||
self.__IN_OPTIMIZATION__ = False
|
||||
|
||||
def latent_cost_and_grad(mu_S, kern, Z, dL_dpsi0, dL_dpsi1, dL_dpsi2):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue