mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
[ssgplvm] linear kernel
This commit is contained in:
parent
b65da11df5
commit
dad476faf6
7 changed files with 162 additions and 104 deletions
|
|
@ -169,7 +169,7 @@ class Pickleable(object):
|
|||
else:
|
||||
pickle.dump(self, f, protocol)
|
||||
|
||||
#===========================================================================
|
||||
#===========================================================================
|
||||
# copy and pickling
|
||||
#===========================================================================
|
||||
def copy(self):
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ class SpikeAndSlabPosterior(VariationalPosterior):
|
|||
else:
|
||||
return super(VariationalPrior, self).__getitem__(s)
|
||||
|
||||
def plot(self, *args):
|
||||
def plot(self, *args, **kwargs):
|
||||
"""
|
||||
Plot latent space X in 1D:
|
||||
|
||||
|
|
@ -169,4 +169,4 @@ class SpikeAndSlabPosterior(VariationalPosterior):
|
|||
import sys
|
||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||
from ...plotting.matplot_dep import variational_plots
|
||||
return variational_plots.plot_SpikeSlab(self,*args)
|
||||
return variational_plots.plot_SpikeSlab(self,*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class VarDTC_minibatch(object):
|
|||
|
||||
self.batchsize = batchsize
|
||||
self.mpi_comm = mpi_comm
|
||||
self.limit = limit
|
||||
|
||||
# Cache functions
|
||||
from ...util.caching import Cacher
|
||||
|
|
@ -37,6 +38,20 @@ class VarDTC_minibatch(object):
|
|||
|
||||
self.midRes = {}
|
||||
self.batch_pos = 0 # the starting position of the current mini-batch
|
||||
|
||||
def __getstate__(self):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
return self.batchsize, self.limit
|
||||
|
||||
def __setstate__(self, state):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
self.batchsize, self.limit = state
|
||||
self.mpi_comm = None
|
||||
self.midRes = {}
|
||||
self.batch_pos = 0
|
||||
from ...util.caching import Cacher
|
||||
self.get_trYYT = Cacher(self._get_trYYT, self.limit)
|
||||
self.get_YYTfactor = Cacher(self._get_YYTfactor, self.limit)
|
||||
|
||||
def set_limit(self, limit):
|
||||
self.get_trYYT.limit = limit
|
||||
|
|
@ -334,7 +349,10 @@ def update_gradients(model, mpi_comm=None):
|
|||
while not isEnd:
|
||||
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, X, model.Z, model.likelihood, Y)
|
||||
if isinstance(model.X, VariationalPosterior):
|
||||
X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
|
||||
if mpi_comm ==None:
|
||||
X_slice = model.X[n_range[0]:n_range[1]]
|
||||
else:
|
||||
X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
|
||||
|
||||
#gradients w.r.t. kernel
|
||||
model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ class Linear(Kern):
|
|||
|
||||
self.variances = Param('variances', variances, Logexp())
|
||||
self.add_parameter(self.variances)
|
||||
|
||||
def set_for_SpikeAndSlab(self):
|
||||
self.psicomp = linear_psi_comp.PSICOMP_SSLinear()
|
||||
|
||||
@Cache_this(limit=2)
|
||||
def K(self, X, X2=None):
|
||||
|
|
@ -107,35 +110,20 @@ class Linear(Kern):
|
|||
|
||||
def psi0(self, Z, variational_posterior):
|
||||
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||
gamma = variational_posterior.binary_prob
|
||||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
|
||||
return np.einsum('q,nq,nq->n',self.variances,gamma,np.square(mu)+S)
|
||||
# return (self.variances*gamma*(np.square(mu)+S)).sum(axis=1)
|
||||
return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[0]
|
||||
else:
|
||||
return np.sum(self.variances * self._mu2S(variational_posterior), 1)
|
||||
|
||||
def psi1(self, Z, variational_posterior):
|
||||
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||
gamma = variational_posterior.binary_prob
|
||||
mu = variational_posterior.mean
|
||||
return np.einsum('nq,q,mq,nq->nm',gamma,self.variances,Z,mu)
|
||||
# return (self.variances*gamma*mu).sum(axis=1)
|
||||
return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[1]
|
||||
else:
|
||||
return self.K(variational_posterior.mean, Z) #the variance, it does nothing
|
||||
|
||||
@Cache_this(limit=1)
|
||||
def psi2(self, Z, variational_posterior):
|
||||
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||
gamma = variational_posterior.binary_prob
|
||||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
mu2 = np.square(mu)
|
||||
variances2 = np.square(self.variances)
|
||||
tmp = np.einsum('nq,q,mq,nq->nm',gamma,self.variances,Z,mu)
|
||||
return np.einsum('nq,q,mq,oq,nq->nmo',gamma,variances2,Z,Z,mu2+S)+\
|
||||
np.einsum('nm,no->nmo',tmp,tmp) - np.einsum('nq,q,mq,oq,nq->nmo',np.square(gamma),variances2,Z,Z,mu2)
|
||||
return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[2]
|
||||
else:
|
||||
ZA = Z * self.variances
|
||||
ZAinner = self._ZAinner(variational_posterior, Z)
|
||||
|
|
@ -143,17 +131,11 @@ class Linear(Kern):
|
|||
|
||||
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||
gamma = variational_posterior.binary_prob
|
||||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
mu2S = np.square(mu)+S
|
||||
_dpsi2_dvariance, _, _, _, _ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
|
||||
grad = np.einsum('n,nq,nq->q',dL_dpsi0,gamma,mu2S) + np.einsum('nm,nq,mq,nq->q',dL_dpsi1,gamma,Z,mu) +\
|
||||
np.einsum('nmo,nmoq->q',dL_dpsi2,_dpsi2_dvariance)
|
||||
dL_dvar,_,_,_,_ = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
|
||||
if self.ARD:
|
||||
self.variances.gradient = grad
|
||||
self.variances.gradient = dL_dvar
|
||||
else:
|
||||
self.variances.gradient = grad.sum()
|
||||
self.variances.gradient = dL_dvar.sum()
|
||||
else:
|
||||
#psi1
|
||||
self.update_gradients_full(dL_dpsi1, variational_posterior.mean, Z)
|
||||
|
|
@ -170,15 +152,8 @@ class Linear(Kern):
|
|||
|
||||
def gradients_Z_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||
gamma = variational_posterior.binary_prob
|
||||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
_, _, _, _, _dpsi2_dZ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
|
||||
|
||||
grad = np.einsum('nm,nq,q,nq->mq',dL_dpsi1,gamma, self.variances,mu) +\
|
||||
np.einsum('nmo,noq->mq',dL_dpsi2,_dpsi2_dZ)
|
||||
|
||||
return grad
|
||||
_,dL_dZ,_,_,_ = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
|
||||
return dL_dZ
|
||||
else:
|
||||
#psi1
|
||||
grad = self.gradients_X(dL_dpsi1.T, Z, variational_posterior.mean)
|
||||
|
|
@ -188,19 +163,8 @@ class Linear(Kern):
|
|||
|
||||
def gradients_qX_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||
gamma = variational_posterior.binary_prob
|
||||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
mu2S = np.square(mu)+S
|
||||
_, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
|
||||
|
||||
grad_gamma = np.einsum('n,q,nq->nq',dL_dpsi0,self.variances,mu2S) + np.einsum('nm,q,mq,nq->nq',dL_dpsi1,self.variances,Z,mu) +\
|
||||
np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dgamma)
|
||||
grad_mu = np.einsum('n,nq,q,nq->nq',dL_dpsi0,gamma,2.*self.variances,mu) + np.einsum('nm,nq,q,mq->nq',dL_dpsi1,gamma,self.variances,Z) +\
|
||||
np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dmu)
|
||||
grad_S = np.einsum('n,nq,q->nq',dL_dpsi0,gamma,self.variances) + np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dS)
|
||||
|
||||
return grad_mu, grad_S, grad_gamma
|
||||
_,_,dL_dmu, dL_dS, dL_dgamma = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
|
||||
return dL_dmu, dL_dS, dL_dgamma
|
||||
else:
|
||||
grad_mu, grad_S = np.zeros(variational_posterior.mean.shape), np.zeros(variational_posterior.mean.shape)
|
||||
# psi0
|
||||
|
|
|
|||
|
|
@ -8,44 +8,100 @@ The package for the Psi statistics computation of the linear kernel for SSGPLVM
|
|||
import numpy as np
|
||||
from GPy.util.caching import Cache_this
|
||||
|
||||
#@Cache_this(limit=1)
|
||||
def _psi2computations(variance, Z, mu, S, gamma):
|
||||
"""
|
||||
Z - MxQ
|
||||
mu - NxQ
|
||||
S - NxQ
|
||||
gamma - NxQ
|
||||
"""
|
||||
# here are the "statistics" for psi1 and psi2
|
||||
# Produced intermediate results:
|
||||
# _psi2 NxMxM
|
||||
# _psi2_dvariance NxMxMxQ
|
||||
# _psi2_dZ NxMxQ
|
||||
# _psi2_dgamma NxMxMxQ
|
||||
# _psi2_dmu NxMxMxQ
|
||||
# _psi2_dS NxMxMxQ
|
||||
|
||||
mu2 = np.square(mu)
|
||||
gamma2 = np.square(gamma)
|
||||
variance2 = np.square(variance)
|
||||
mu2S = mu2+S # NxQ
|
||||
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
|
||||
|
||||
_dpsi2_dvariance = np.einsum('nq,q,mq,oq->nmoq',2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
|
||||
np.einsum('nq,mq,nq,no->nmoq',gamma,Z,mu,common_sum)+\
|
||||
np.einsum('nq,oq,nq,nm->nmoq',gamma,Z,mu,common_sum)
|
||||
|
||||
_dpsi2_dgamma = np.einsum('q,mq,oq,nq->nmoq',variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
|
||||
np.einsum('q,mq,nq,no->nmoq',variance,Z,mu,common_sum)+\
|
||||
np.einsum('q,oq,nq,nm->nmoq',variance,Z,mu,common_sum)
|
||||
|
||||
_dpsi2_dmu = np.einsum('q,mq,oq,nq,nq->nmoq',variance2,Z,Z,mu,2.*(gamma-gamma2))+\
|
||||
np.einsum('nq,q,mq,no->nmoq',gamma,variance,Z,common_sum)+\
|
||||
np.einsum('nq,q,oq,nm->nmoq',gamma,variance,Z,common_sum)
|
||||
|
||||
_dpsi2_dS = np.einsum('nq,q,mq,oq->nmoq',gamma,variance2,Z,Z)
|
||||
|
||||
_dpsi2_dZ = 2.*(np.einsum('nq,q,mq,nq->nmq',gamma,variance2,Z,mu2S)+np.einsum('nq,q,nq,nm->nmq',gamma,variance,mu,common_sum)
|
||||
-np.einsum('nq,q,mq,nq->nmq',gamma2,variance2,Z,mu2))
|
||||
class PSICOMP_SSLinear(object):
|
||||
#@Cache_this(limit=1, ignore_args=(0,))
|
||||
def psicomputations(self, variance, Z, mu, S, gamma):
|
||||
"""
|
||||
Compute psi-statistics for ss-linear kernel
|
||||
"""
|
||||
# here are the "statistics" for psi0, psi1 and psi2
|
||||
# Produced intermediate results:
|
||||
# psi0 N
|
||||
# psi1 NxM
|
||||
# psi2 MxM
|
||||
|
||||
return _dpsi2_dvariance, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _dpsi2_dZ
|
||||
psi0 = np.einsum('q,nq,nq->n',variance,gamma,np.square(mu)+S)
|
||||
psi1 = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu)
|
||||
mu2 = np.square(mu)
|
||||
variances2 = np.square(variance)
|
||||
tmp = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu)
|
||||
psi2 = np.einsum('nq,q,mq,oq,nq->mo',gamma,variances2,Z,Z,mu2+S)+\
|
||||
np.einsum('nm,no->mo',tmp,tmp) - np.einsum('nq,q,mq,oq,nq->mo',np.square(gamma),variances2,Z,Z,mu2)
|
||||
|
||||
return psi0, psi1, psi2
|
||||
|
||||
#@Cache_this(limit=1, ignore_args=(0,1,2,3))
|
||||
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior):
|
||||
mu = variational_posterior.mean
|
||||
S = variational_posterior.variance
|
||||
gamma = variational_posterior.binary_prob
|
||||
|
||||
dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ = self._psi2computations(dL_dpsi2, variance, Z, mu, S, gamma)
|
||||
|
||||
# Compute for psi0 and psi1
|
||||
mu2S = np.square(mu)+S
|
||||
dL_dvar += np.einsum('n,nq,nq->q',dL_dpsi0,gamma,mu2S) + np.einsum('nm,nq,mq,nq->q',dL_dpsi1,gamma,Z,mu)
|
||||
dL_dgamma += np.einsum('n,q,nq->nq',dL_dpsi0,variance,mu2S) + np.einsum('nm,q,mq,nq->nq',dL_dpsi1,variance,Z,mu)
|
||||
dL_dmu += np.einsum('n,nq,q,nq->nq',dL_dpsi0,gamma,2.*variance,mu) + np.einsum('nm,nq,q,mq->nq',dL_dpsi1,gamma,variance,Z)
|
||||
dL_dS += np.einsum('n,nq,q->nq',dL_dpsi0,gamma,variance)
|
||||
dL_dZ += np.einsum('nm,nq,q,nq->mq',dL_dpsi1,gamma, variance,mu)
|
||||
|
||||
return dL_dvar, dL_dZ, dL_dmu, dL_dS, dL_dgamma
|
||||
|
||||
def _psi2computations(self, dL_dpsi2, variance, Z, mu, S, gamma):
|
||||
"""
|
||||
Z - MxQ
|
||||
mu - NxQ
|
||||
S - NxQ
|
||||
gamma - NxQ
|
||||
"""
|
||||
# here are the "statistics" for psi1 and psi2
|
||||
# Produced intermediate results:
|
||||
# _psi2_dvariance Q
|
||||
# _psi2_dZ MxQ
|
||||
# _psi2_dgamma NxQ
|
||||
# _psi2_dmu NxQ
|
||||
# _psi2_dS NxQ
|
||||
|
||||
mu2 = np.square(mu)
|
||||
gamma2 = np.square(gamma)
|
||||
variance2 = np.square(variance)
|
||||
mu2S = mu2+S # NxQ
|
||||
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
|
||||
|
||||
# _dpsi2_dvariance = np.einsum('nq,q,mq,oq->nmoq',2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
|
||||
# np.einsum('nq,mq,nq,no->nmoq',gamma,Z,mu,common_sum)+\
|
||||
# np.einsum('nq,oq,nq,nm->nmoq',gamma,Z,mu,common_sum)
|
||||
#
|
||||
# _dpsi2_dgamma = np.einsum('q,mq,oq,nq->nmoq',variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
|
||||
# np.einsum('q,mq,nq,no->nmoq',variance,Z,mu,common_sum)+\
|
||||
# np.einsum('q,oq,nq,nm->nmoq',variance,Z,mu,common_sum)
|
||||
#
|
||||
_dpsi2_dmu = np.einsum('q,mq,oq,nq,nq->nmoq',variance2,Z,Z,mu,2.*(gamma-gamma2))+\
|
||||
np.einsum('nq,q,mq,no->nmoq',gamma,variance,Z,common_sum)+\
|
||||
np.einsum('nq,q,oq,nm->nmoq',gamma,variance,Z,common_sum)
|
||||
#
|
||||
# _dpsi2_dS = np.einsum('nq,q,mq,oq->nmoq',gamma,variance2,Z,Z)
|
||||
#
|
||||
# _dpsi2_dZ = 2.*(np.einsum('nq,q,mq,nq->nmq',gamma,variance2,Z,mu2S)+np.einsum('nq,q,nq,nm->nmq',gamma,variance,mu,common_sum)
|
||||
# -np.einsum('nq,q,mq,nq->nmq',gamma2,variance2,Z,mu2))
|
||||
dL_dmu = np.einsum('mo,nmoq->nq', dL_dpsi2, _dpsi2_dmu)
|
||||
|
||||
dL_dvar = np.einsum('mo,nq,q,mq,oq->q',dL_dpsi2,2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
|
||||
np.einsum('mo,nq,mq,nq,no->q',dL_dpsi2,gamma,Z,mu,common_sum)+\
|
||||
np.einsum('mo,nq,oq,nq,nm->q',dL_dpsi2,gamma,Z,mu,common_sum)
|
||||
|
||||
dL_dgamma = np.einsum('mo,q,mq,oq,nq->nq',dL_dpsi2,variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
|
||||
np.einsum('mo,q,mq,nq,no->nq',dL_dpsi2,variance,Z,mu,common_sum)+\
|
||||
np.einsum('mo,q,oq,nq,nm->nq',dL_dpsi2,variance,Z,mu,common_sum)
|
||||
|
||||
# dL_dmu = np.einsum('mo,q,mq,oq,nq,nq->nq',dL_dpsi2,variance2,Z,Z,mu,2.*(gamma-gamma2))+\
|
||||
# np.einsum('mo,nq,q,mq,no->nq',dL_dpsi2,gamma,variance,Z,common_sum)+\
|
||||
# np.einsum('mo,nq,q,oq,nm->nq',dL_dpsi2,gamma,variance,Z,common_sum)
|
||||
|
||||
dL_dS = np.einsum('mo,nq,q,mq,oq->nq',dL_dpsi2,gamma,variance2,Z,Z)
|
||||
|
||||
dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,mu2S)+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)
|
||||
-np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma2,variance2,Z,mu2))
|
||||
|
||||
return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from stationary import Stationary
|
|||
from GPy.util.caching import Cache_this
|
||||
from ...core.parameterization import variational
|
||||
from psi_comp import ssrbf_psi_comp
|
||||
from psi_comp.ssrbf_psi_gpucomp import PSICOMP_SSRBF
|
||||
from psi_comp import ssrbf_psi_gpucomp
|
||||
|
||||
class RBF(Stationary):
|
||||
"""
|
||||
|
|
@ -26,8 +26,11 @@ class RBF(Stationary):
|
|||
self.weave_options = {}
|
||||
self.group_spike_prob = False
|
||||
|
||||
def set_for_SpikeAndSlab(self):
|
||||
if self.useGPU:
|
||||
self.psicomp = PSICOMP_SSRBF()
|
||||
self.psicomp = ssrbf_psi_gpucomp.PSICOMP_SSRBF()
|
||||
else:
|
||||
self.psicomp = ssrbf_psi_comp
|
||||
|
||||
|
||||
def K_of_r(self, r):
|
||||
|
|
|
|||
|
|
@ -44,8 +44,10 @@ class SSGPLVM(SparseGP):
|
|||
X_variance = np.random.uniform(0,.1,X.shape)
|
||||
|
||||
gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
|
||||
#gamma[:] = 0.5 + 0.01 * np.random.randn(X.shape[0], input_dim)
|
||||
gamma[:] = 0.5
|
||||
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
|
||||
gamma[gamma>=1. - 1e-9] = 1e-9
|
||||
gamma[gamma<1e-9] = 1e-9
|
||||
#gamma[:] = 0.5
|
||||
|
||||
if group_spike:
|
||||
gamma[:] = gamma.mean(axis=0)
|
||||
|
|
@ -57,19 +59,20 @@ class SSGPLVM(SparseGP):
|
|||
pi = np.empty((input_dim))
|
||||
pi[:] = 0.5
|
||||
|
||||
if mpi_comm != None:
|
||||
mpi_comm.Bcast(X, root=0)
|
||||
mpi_comm.Bcast(fracs, root=0)
|
||||
mpi_comm.Bcast(X_variance, root=0)
|
||||
mpi_comm.Bcast(gamma, root=0)
|
||||
mpi_comm.Bcast(Z, root=0)
|
||||
mpi_comm.Bcast(pi, root=0)
|
||||
# if mpi_comm != None:
|
||||
# mpi_comm.Bcast(X, root=0)
|
||||
# mpi_comm.Bcast(fracs, root=0)
|
||||
# mpi_comm.Bcast(X_variance, root=0)
|
||||
# mpi_comm.Bcast(gamma, root=0)
|
||||
# mpi_comm.Bcast(Z, root=0)
|
||||
# mpi_comm.Bcast(pi, root=0)
|
||||
|
||||
if likelihood is None:
|
||||
likelihood = Gaussian()
|
||||
|
||||
if kernel is None:
|
||||
kernel = kern.RBF(input_dim, lengthscale=fracs, ARD=True) # + kern.white(input_dim)
|
||||
kernel.set_for_SpikeAndSlab()
|
||||
|
||||
self.variational_prior = SpikeAndSlabPrior(pi=pi) # the prior probability of the latent binary variable b
|
||||
|
||||
|
|
@ -90,6 +93,7 @@ class SSGPLVM(SparseGP):
|
|||
self.X_local = self.X[Y_start:Y_end]
|
||||
self.Y_range = (Y_start, Y_end)
|
||||
self.Y_list = np.array(Y_list)
|
||||
[mpi_comm.Bcast(p, root=0) for p in self.flattened_parameters]
|
||||
|
||||
def set_X_gradients(self, X, X_grad):
|
||||
"""Set the gradients of the posterior distribution of X in its specific form."""
|
||||
|
|
@ -125,3 +129,16 @@ class SSGPLVM(SparseGP):
|
|||
|
||||
return dim_reduction_plots.plot_latent(self, plot_inducing=plot_inducing, *args, **kwargs)
|
||||
|
||||
def __getstate__(self):
|
||||
dc = super(SSGPLVM, self).__getstate__()
|
||||
del dc['mpi_comm']
|
||||
del dc['Y_local']
|
||||
del dc['X_local']
|
||||
return dc
|
||||
|
||||
def __setstate__(self, state):
|
||||
state['mpi_comm'] = None
|
||||
Y_range = state['Y_range']
|
||||
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
|
||||
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
|
||||
return super(SSGPLVM, self).__setstate__(state)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue