fix pickle

This commit is contained in:
Zhenwen Dai 2014-05-16 11:23:44 +01:00
parent 17b6b94db3
commit e6d07ad5ac
7 changed files with 230 additions and 275 deletions

View file

@ -59,20 +59,15 @@ class SSGPLVM(SparseGP):
pi = np.empty((input_dim))
pi[:] = 0.5
# if mpi_comm != None:
# mpi_comm.Bcast(X, root=0)
# mpi_comm.Bcast(fracs, root=0)
# mpi_comm.Bcast(X_variance, root=0)
# mpi_comm.Bcast(gamma, root=0)
# mpi_comm.Bcast(Z, root=0)
# mpi_comm.Bcast(pi, root=0)
if likelihood is None:
likelihood = Gaussian()
if kernel is None:
kernel = kern.RBF(input_dim, lengthscale=fracs, ARD=True) # + kern.white(input_dim)
kernel.set_for_SpikeAndSlab()
if inference_method is None:
inference_method = VarDTC_minibatch(mpi_comm=mpi_comm)
self.variational_prior = SpikeAndSlabPrior(pi=pi) # the prior probability of the latent binary variable b
@ -131,16 +126,14 @@ class SSGPLVM(SparseGP):
def __getstate__(self):
dc = super(SSGPLVM, self).__getstate__()
del dc['mpi_comm']
del dc['Y_local']
del dc['X_local']
dc['mpi_comm'] = None
if self.mpi_comm != None:
del dc['Y_local']
del dc['X_local']
del dc['Y_range']
return dc
def __setstate__(self, state):
state['mpi_comm'] = None
Y_range = state['Y_range']
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
return super(SSGPLVM, self).__setstate__(state)
def _grads(self, x):