mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-05 09:42:39 +02:00
fix pickle
This commit is contained in:
parent
17b6b94db3
commit
e6d07ad5ac
7 changed files with 230 additions and 275 deletions
|
|
@ -59,20 +59,15 @@ class SSGPLVM(SparseGP):
|
|||
pi = np.empty((input_dim))
|
||||
pi[:] = 0.5
|
||||
|
||||
# if mpi_comm != None:
|
||||
# mpi_comm.Bcast(X, root=0)
|
||||
# mpi_comm.Bcast(fracs, root=0)
|
||||
# mpi_comm.Bcast(X_variance, root=0)
|
||||
# mpi_comm.Bcast(gamma, root=0)
|
||||
# mpi_comm.Bcast(Z, root=0)
|
||||
# mpi_comm.Bcast(pi, root=0)
|
||||
|
||||
if likelihood is None:
|
||||
likelihood = Gaussian()
|
||||
|
||||
if kernel is None:
|
||||
kernel = kern.RBF(input_dim, lengthscale=fracs, ARD=True) # + kern.white(input_dim)
|
||||
kernel.set_for_SpikeAndSlab()
|
||||
|
||||
if inference_method is None:
|
||||
inference_method = VarDTC_minibatch(mpi_comm=mpi_comm)
|
||||
|
||||
self.variational_prior = SpikeAndSlabPrior(pi=pi) # the prior probability of the latent binary variable b
|
||||
|
||||
|
|
@ -131,16 +126,14 @@ class SSGPLVM(SparseGP):
|
|||
|
||||
def __getstate__(self):
|
||||
dc = super(SSGPLVM, self).__getstate__()
|
||||
del dc['mpi_comm']
|
||||
del dc['Y_local']
|
||||
del dc['X_local']
|
||||
dc['mpi_comm'] = None
|
||||
if self.mpi_comm != None:
|
||||
del dc['Y_local']
|
||||
del dc['X_local']
|
||||
del dc['Y_range']
|
||||
return dc
|
||||
|
||||
def __setstate__(self, state):
|
||||
state['mpi_comm'] = None
|
||||
Y_range = state['Y_range']
|
||||
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
|
||||
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
|
||||
return super(SSGPLVM, self).__setstate__(state)
|
||||
|
||||
def _grads(self, x):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue