[ssgplvm] linear kernel

This commit is contained in:
Zhenwen Dai 2014-05-15 11:43:29 +01:00
parent b65da11df5
commit dad476faf6
7 changed files with 162 additions and 104 deletions

View file

@ -169,7 +169,7 @@ class Pickleable(object):
else:
pickle.dump(self, f, protocol)
#===========================================================================
#===========================================================================
# copy and pickling
#===========================================================================
def copy(self):

View file

@ -160,7 +160,7 @@ class SpikeAndSlabPosterior(VariationalPosterior):
else:
return super(VariationalPrior, self).__getitem__(s)
def plot(self, *args):
def plot(self, *args, **kwargs):
"""
Plot latent space X in 1D:
@ -169,4 +169,4 @@ class SpikeAndSlabPosterior(VariationalPosterior):
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ...plotting.matplot_dep import variational_plots
return variational_plots.plot_SpikeSlab(self,*args)
return variational_plots.plot_SpikeSlab(self,*args, **kwargs)

View file

@ -29,6 +29,7 @@ class VarDTC_minibatch(object):
self.batchsize = batchsize
self.mpi_comm = mpi_comm
self.limit = limit
# Cache functions
from ...util.caching import Cacher
@ -37,6 +38,20 @@ class VarDTC_minibatch(object):
self.midRes = {}
self.batch_pos = 0 # the starting position of the current mini-batch
def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled.
return self.batchsize, self.limit
def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled.
self.batchsize, self.limit = state
self.mpi_comm = None
self.midRes = {}
self.batch_pos = 0
from ...util.caching import Cacher
self.get_trYYT = Cacher(self._get_trYYT, self.limit)
self.get_YYTfactor = Cacher(self._get_YYTfactor, self.limit)
def set_limit(self, limit):
self.get_trYYT.limit = limit
@ -334,7 +349,10 @@ def update_gradients(model, mpi_comm=None):
while not isEnd:
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, X, model.Z, model.likelihood, Y)
if isinstance(model.X, VariationalPosterior):
X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
if mpi_comm ==None:
X_slice = model.X[n_range[0]:n_range[1]]
else:
X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
#gradients w.r.t. kernel
model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])

View file

@ -52,6 +52,9 @@ class Linear(Kern):
self.variances = Param('variances', variances, Logexp())
self.add_parameter(self.variances)
def set_for_SpikeAndSlab(self):
self.psicomp = linear_psi_comp.PSICOMP_SSLinear()
@Cache_this(limit=2)
def K(self, X, X2=None):
@ -107,35 +110,20 @@ class Linear(Kern):
def psi0(self, Z, variational_posterior):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
gamma = variational_posterior.binary_prob
mu = variational_posterior.mean
S = variational_posterior.variance
return np.einsum('q,nq,nq->n',self.variances,gamma,np.square(mu)+S)
# return (self.variances*gamma*(np.square(mu)+S)).sum(axis=1)
return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[0]
else:
return np.sum(self.variances * self._mu2S(variational_posterior), 1)
def psi1(self, Z, variational_posterior):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
gamma = variational_posterior.binary_prob
mu = variational_posterior.mean
return np.einsum('nq,q,mq,nq->nm',gamma,self.variances,Z,mu)
# return (self.variances*gamma*mu).sum(axis=1)
return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[1]
else:
return self.K(variational_posterior.mean, Z) #the variance, it does nothing
@Cache_this(limit=1)
def psi2(self, Z, variational_posterior):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
gamma = variational_posterior.binary_prob
mu = variational_posterior.mean
S = variational_posterior.variance
mu2 = np.square(mu)
variances2 = np.square(self.variances)
tmp = np.einsum('nq,q,mq,nq->nm',gamma,self.variances,Z,mu)
return np.einsum('nq,q,mq,oq,nq->nmo',gamma,variances2,Z,Z,mu2+S)+\
np.einsum('nm,no->nmo',tmp,tmp) - np.einsum('nq,q,mq,oq,nq->nmo',np.square(gamma),variances2,Z,Z,mu2)
return self.psicomp.psicomputations(self.variances, Z, variational_posterior.mean, variational_posterior.variance, variational_posterior.binary_prob)[2]
else:
ZA = Z * self.variances
ZAinner = self._ZAinner(variational_posterior, Z)
@ -143,17 +131,11 @@ class Linear(Kern):
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
gamma = variational_posterior.binary_prob
mu = variational_posterior.mean
S = variational_posterior.variance
mu2S = np.square(mu)+S
_dpsi2_dvariance, _, _, _, _ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
grad = np.einsum('n,nq,nq->q',dL_dpsi0,gamma,mu2S) + np.einsum('nm,nq,mq,nq->q',dL_dpsi1,gamma,Z,mu) +\
np.einsum('nmo,nmoq->q',dL_dpsi2,_dpsi2_dvariance)
dL_dvar,_,_,_,_ = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
if self.ARD:
self.variances.gradient = grad
self.variances.gradient = dL_dvar
else:
self.variances.gradient = grad.sum()
self.variances.gradient = dL_dvar.sum()
else:
#psi1
self.update_gradients_full(dL_dpsi1, variational_posterior.mean, Z)
@ -170,15 +152,8 @@ class Linear(Kern):
def gradients_Z_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
gamma = variational_posterior.binary_prob
mu = variational_posterior.mean
S = variational_posterior.variance
_, _, _, _, _dpsi2_dZ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
grad = np.einsum('nm,nq,q,nq->mq',dL_dpsi1,gamma, self.variances,mu) +\
np.einsum('nmo,noq->mq',dL_dpsi2,_dpsi2_dZ)
return grad
_,dL_dZ,_,_,_ = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
return dL_dZ
else:
#psi1
grad = self.gradients_X(dL_dpsi1.T, Z, variational_posterior.mean)
@ -188,19 +163,8 @@ class Linear(Kern):
def gradients_qX_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
gamma = variational_posterior.binary_prob
mu = variational_posterior.mean
S = variational_posterior.variance
mu2S = np.square(mu)+S
_, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
grad_gamma = np.einsum('n,q,nq->nq',dL_dpsi0,self.variances,mu2S) + np.einsum('nm,q,mq,nq->nq',dL_dpsi1,self.variances,Z,mu) +\
np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dgamma)
grad_mu = np.einsum('n,nq,q,nq->nq',dL_dpsi0,gamma,2.*self.variances,mu) + np.einsum('nm,nq,q,mq->nq',dL_dpsi1,gamma,self.variances,Z) +\
np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dmu)
grad_S = np.einsum('n,nq,q->nq',dL_dpsi0,gamma,self.variances) + np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dS)
return grad_mu, grad_S, grad_gamma
_,_,dL_dmu, dL_dS, dL_dgamma = self.psicomp.psiDerivativecomputations(dL_dpsi0, dL_dpsi1, dL_dpsi2, self.variances, Z, variational_posterior)
return dL_dmu, dL_dS, dL_dgamma
else:
grad_mu, grad_S = np.zeros(variational_posterior.mean.shape), np.zeros(variational_posterior.mean.shape)
# psi0

View file

@ -8,44 +8,100 @@ The package for the Psi statistics computation of the linear kernel for SSGPLVM
import numpy as np
from GPy.util.caching import Cache_this
#@Cache_this(limit=1)
def _psi2computations(variance, Z, mu, S, gamma):
"""
Z - MxQ
mu - NxQ
S - NxQ
gamma - NxQ
"""
# here are the "statistics" for psi1 and psi2
# Produced intermediate results:
# _psi2 NxMxM
# _psi2_dvariance NxMxMxQ
# _psi2_dZ NxMxQ
# _psi2_dgamma NxMxMxQ
# _psi2_dmu NxMxMxQ
# _psi2_dS NxMxMxQ
mu2 = np.square(mu)
gamma2 = np.square(gamma)
variance2 = np.square(variance)
mu2S = mu2+S # NxQ
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
_dpsi2_dvariance = np.einsum('nq,q,mq,oq->nmoq',2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
np.einsum('nq,mq,nq,no->nmoq',gamma,Z,mu,common_sum)+\
np.einsum('nq,oq,nq,nm->nmoq',gamma,Z,mu,common_sum)
_dpsi2_dgamma = np.einsum('q,mq,oq,nq->nmoq',variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
np.einsum('q,mq,nq,no->nmoq',variance,Z,mu,common_sum)+\
np.einsum('q,oq,nq,nm->nmoq',variance,Z,mu,common_sum)
_dpsi2_dmu = np.einsum('q,mq,oq,nq,nq->nmoq',variance2,Z,Z,mu,2.*(gamma-gamma2))+\
np.einsum('nq,q,mq,no->nmoq',gamma,variance,Z,common_sum)+\
np.einsum('nq,q,oq,nm->nmoq',gamma,variance,Z,common_sum)
_dpsi2_dS = np.einsum('nq,q,mq,oq->nmoq',gamma,variance2,Z,Z)
_dpsi2_dZ = 2.*(np.einsum('nq,q,mq,nq->nmq',gamma,variance2,Z,mu2S)+np.einsum('nq,q,nq,nm->nmq',gamma,variance,mu,common_sum)
-np.einsum('nq,q,mq,nq->nmq',gamma2,variance2,Z,mu2))
class PSICOMP_SSLinear(object):
#@Cache_this(limit=1, ignore_args=(0,))
def psicomputations(self, variance, Z, mu, S, gamma):
"""
Compute psi-statistics for ss-linear kernel
"""
# here are the "statistics" for psi0, psi1 and psi2
# Produced intermediate results:
# psi0 N
# psi1 NxM
# psi2 MxM
return _dpsi2_dvariance, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _dpsi2_dZ
psi0 = np.einsum('q,nq,nq->n',variance,gamma,np.square(mu)+S)
psi1 = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu)
mu2 = np.square(mu)
variances2 = np.square(variance)
tmp = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu)
psi2 = np.einsum('nq,q,mq,oq,nq->mo',gamma,variances2,Z,Z,mu2+S)+\
np.einsum('nm,no->mo',tmp,tmp) - np.einsum('nq,q,mq,oq,nq->mo',np.square(gamma),variances2,Z,Z,mu2)
return psi0, psi1, psi2
#@Cache_this(limit=1, ignore_args=(0,1,2,3))
def psiDerivativecomputations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, variance, Z, variational_posterior):
mu = variational_posterior.mean
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ = self._psi2computations(dL_dpsi2, variance, Z, mu, S, gamma)
# Compute for psi0 and psi1
mu2S = np.square(mu)+S
dL_dvar += np.einsum('n,nq,nq->q',dL_dpsi0,gamma,mu2S) + np.einsum('nm,nq,mq,nq->q',dL_dpsi1,gamma,Z,mu)
dL_dgamma += np.einsum('n,q,nq->nq',dL_dpsi0,variance,mu2S) + np.einsum('nm,q,mq,nq->nq',dL_dpsi1,variance,Z,mu)
dL_dmu += np.einsum('n,nq,q,nq->nq',dL_dpsi0,gamma,2.*variance,mu) + np.einsum('nm,nq,q,mq->nq',dL_dpsi1,gamma,variance,Z)
dL_dS += np.einsum('n,nq,q->nq',dL_dpsi0,gamma,variance)
dL_dZ += np.einsum('nm,nq,q,nq->mq',dL_dpsi1,gamma, variance,mu)
return dL_dvar, dL_dZ, dL_dmu, dL_dS, dL_dgamma
def _psi2computations(self, dL_dpsi2, variance, Z, mu, S, gamma):
"""
Z - MxQ
mu - NxQ
S - NxQ
gamma - NxQ
"""
# here are the "statistics" for psi1 and psi2
# Produced intermediate results:
# _psi2_dvariance Q
# _psi2_dZ MxQ
# _psi2_dgamma NxQ
# _psi2_dmu NxQ
# _psi2_dS NxQ
mu2 = np.square(mu)
gamma2 = np.square(gamma)
variance2 = np.square(variance)
mu2S = mu2+S # NxQ
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
# _dpsi2_dvariance = np.einsum('nq,q,mq,oq->nmoq',2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
# np.einsum('nq,mq,nq,no->nmoq',gamma,Z,mu,common_sum)+\
# np.einsum('nq,oq,nq,nm->nmoq',gamma,Z,mu,common_sum)
#
# _dpsi2_dgamma = np.einsum('q,mq,oq,nq->nmoq',variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
# np.einsum('q,mq,nq,no->nmoq',variance,Z,mu,common_sum)+\
# np.einsum('q,oq,nq,nm->nmoq',variance,Z,mu,common_sum)
#
_dpsi2_dmu = np.einsum('q,mq,oq,nq,nq->nmoq',variance2,Z,Z,mu,2.*(gamma-gamma2))+\
np.einsum('nq,q,mq,no->nmoq',gamma,variance,Z,common_sum)+\
np.einsum('nq,q,oq,nm->nmoq',gamma,variance,Z,common_sum)
#
# _dpsi2_dS = np.einsum('nq,q,mq,oq->nmoq',gamma,variance2,Z,Z)
#
# _dpsi2_dZ = 2.*(np.einsum('nq,q,mq,nq->nmq',gamma,variance2,Z,mu2S)+np.einsum('nq,q,nq,nm->nmq',gamma,variance,mu,common_sum)
# -np.einsum('nq,q,mq,nq->nmq',gamma2,variance2,Z,mu2))
dL_dmu = np.einsum('mo,nmoq->nq', dL_dpsi2, _dpsi2_dmu)
dL_dvar = np.einsum('mo,nq,q,mq,oq->q',dL_dpsi2,2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
np.einsum('mo,nq,mq,nq,no->q',dL_dpsi2,gamma,Z,mu,common_sum)+\
np.einsum('mo,nq,oq,nq,nm->q',dL_dpsi2,gamma,Z,mu,common_sum)
dL_dgamma = np.einsum('mo,q,mq,oq,nq->nq',dL_dpsi2,variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
np.einsum('mo,q,mq,nq,no->nq',dL_dpsi2,variance,Z,mu,common_sum)+\
np.einsum('mo,q,oq,nq,nm->nq',dL_dpsi2,variance,Z,mu,common_sum)
# dL_dmu = np.einsum('mo,q,mq,oq,nq,nq->nq',dL_dpsi2,variance2,Z,Z,mu,2.*(gamma-gamma2))+\
# np.einsum('mo,nq,q,mq,no->nq',dL_dpsi2,gamma,variance,Z,common_sum)+\
# np.einsum('mo,nq,q,oq,nm->nq',dL_dpsi2,gamma,variance,Z,common_sum)
dL_dS = np.einsum('mo,nq,q,mq,oq->nq',dL_dpsi2,gamma,variance2,Z,Z)
dL_dZ = 2.*(np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma,variance2,Z,mu2S)+np.einsum('om,nq,q,nq,nm->oq',dL_dpsi2,gamma,variance,mu,common_sum)
-np.einsum('om,nq,q,mq,nq->oq',dL_dpsi2,gamma2,variance2,Z,mu2))
return dL_dvar, dL_dgamma, dL_dmu, dL_dS, dL_dZ

View file

@ -9,7 +9,7 @@ from stationary import Stationary
from GPy.util.caching import Cache_this
from ...core.parameterization import variational
from psi_comp import ssrbf_psi_comp
from psi_comp.ssrbf_psi_gpucomp import PSICOMP_SSRBF
from psi_comp import ssrbf_psi_gpucomp
class RBF(Stationary):
"""
@ -26,8 +26,11 @@ class RBF(Stationary):
self.weave_options = {}
self.group_spike_prob = False
def set_for_SpikeAndSlab(self):
if self.useGPU:
self.psicomp = PSICOMP_SSRBF()
self.psicomp = ssrbf_psi_gpucomp.PSICOMP_SSRBF()
else:
self.psicomp = ssrbf_psi_comp
def K_of_r(self, r):

View file

@ -44,8 +44,10 @@ class SSGPLVM(SparseGP):
X_variance = np.random.uniform(0,.1,X.shape)
gamma = np.empty_like(X, order='F') # The posterior probabilities of the binary variable in the variational approximation
#gamma[:] = 0.5 + 0.01 * np.random.randn(X.shape[0], input_dim)
gamma[:] = 0.5
gamma[:] = 0.5 + 0.1 * np.random.randn(X.shape[0], input_dim)
gamma[gamma>=1. - 1e-9] = 1e-9
gamma[gamma<1e-9] = 1e-9
#gamma[:] = 0.5
if group_spike:
gamma[:] = gamma.mean(axis=0)
@ -57,19 +59,20 @@ class SSGPLVM(SparseGP):
pi = np.empty((input_dim))
pi[:] = 0.5
if mpi_comm != None:
mpi_comm.Bcast(X, root=0)
mpi_comm.Bcast(fracs, root=0)
mpi_comm.Bcast(X_variance, root=0)
mpi_comm.Bcast(gamma, root=0)
mpi_comm.Bcast(Z, root=0)
mpi_comm.Bcast(pi, root=0)
# if mpi_comm != None:
# mpi_comm.Bcast(X, root=0)
# mpi_comm.Bcast(fracs, root=0)
# mpi_comm.Bcast(X_variance, root=0)
# mpi_comm.Bcast(gamma, root=0)
# mpi_comm.Bcast(Z, root=0)
# mpi_comm.Bcast(pi, root=0)
if likelihood is None:
likelihood = Gaussian()
if kernel is None:
kernel = kern.RBF(input_dim, lengthscale=fracs, ARD=True) # + kern.white(input_dim)
kernel.set_for_SpikeAndSlab()
self.variational_prior = SpikeAndSlabPrior(pi=pi) # the prior probability of the latent binary variable b
@ -90,6 +93,7 @@ class SSGPLVM(SparseGP):
self.X_local = self.X[Y_start:Y_end]
self.Y_range = (Y_start, Y_end)
self.Y_list = np.array(Y_list)
[mpi_comm.Bcast(p, root=0) for p in self.flattened_parameters]
def set_X_gradients(self, X, X_grad):
"""Set the gradients of the posterior distribution of X in its specific form."""
@ -125,3 +129,16 @@ class SSGPLVM(SparseGP):
return dim_reduction_plots.plot_latent(self, plot_inducing=plot_inducing, *args, **kwargs)
def __getstate__(self):
dc = super(SSGPLVM, self).__getstate__()
del dc['mpi_comm']
del dc['Y_local']
del dc['X_local']
return dc
def __setstate__(self, state):
state['mpi_comm'] = None
Y_range = state['Y_range']
state['Y_local'] = state['Y'][Y_range[0]:Y_range[1]]
state['X_local'] = state['X'][Y_range[0]:Y_range[1]]
return super(SSGPLVM, self).__setstate__(state)