fix pickle for ssgplvm and bgplvm with mpi

This commit is contained in:
Zhenwen Dai 2014-06-26 14:39:02 +01:00
parent 1512590a8c
commit 8a4a4e56a9
2 changed files with 7 additions and 8 deletions

View file

@ -71,11 +71,11 @@ class BayesianGPLVM(SparseGP):
if mpi_comm != None: if mpi_comm != None:
from ..util.mpi import divide_data from ..util.mpi import divide_data
Y_start, Y_end, Y_list = divide_data(Y.shape[0], mpi_comm) N_start, N_end, N_list = divide_data(Y.shape[0], mpi_comm)
self.Y_local = self.Y[Y_start:Y_end] self.N_range = (N_start, N_end)
self.X_local = self.X[Y_start:Y_end] self.N_list = np.array(N_list)
self.Y_range = (Y_start, Y_end) self.Y_local = self.Y[N_start:N_end]
self.Y_list = np.array(Y_list) print 'MPI RANK: '+str(self.mpi_comm.rank)+' with datasize: '+str(self.N_range)
mpi_comm.Bcast(self.param_array, root=0) mpi_comm.Bcast(self.param_array, root=0)
def set_X_gradients(self, X, X_grad): def set_X_gradients(self, X, X_grad):
@ -184,9 +184,9 @@ class BayesianGPLVM(SparseGP):
dc = super(BayesianGPLVM, self).__getstate__() dc = super(BayesianGPLVM, self).__getstate__()
dc['mpi_comm'] = None dc['mpi_comm'] = None
if self.mpi_comm != None: if self.mpi_comm != None:
del dc['N_range']
del dc['N_list']
del dc['Y_local'] del dc['Y_local']
del dc['X_local']
del dc['Y_range']
return dc return dc
def __setstate__(self, state): def __setstate__(self, state):

View file

@ -125,7 +125,6 @@ class SSGPLVM(SparseGP):
del dc['N_range'] del dc['N_range']
del dc['N_list'] del dc['N_list']
del dc['Y_local'] del dc['Y_local']
del dc['X_local']
return dc return dc
def __setstate__(self, state): def __setstate__(self, state):