mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-21 14:05:14 +02:00
[SSGPLVM] implemented linear kernel
This commit is contained in:
parent
16555aa11d
commit
99d6b8220c
6 changed files with 147 additions and 34 deletions
|
|
@ -6,10 +6,12 @@ import numpy as np
|
||||||
from scipy import weave
|
from scipy import weave
|
||||||
from kern import Kern
|
from kern import Kern
|
||||||
from ...util.linalg import tdot
|
from ...util.linalg import tdot
|
||||||
from ...util.misc import fast_array_equal, param_to_array
|
from ...util.misc import param_to_array
|
||||||
from ...core.parameterization import Param
|
from ...core.parameterization import Param
|
||||||
from ...core.parameterization.transformations import Logexp
|
from ...core.parameterization.transformations import Logexp
|
||||||
from ...util.caching import Cache_this
|
from ...util.caching import Cache_this
|
||||||
|
from ...core.parameterization import variational
|
||||||
|
from psi_comp import linear_psi_comp
|
||||||
|
|
||||||
class Linear(Kern):
|
class Linear(Kern):
|
||||||
"""
|
"""
|
||||||
|
|
@ -104,18 +106,52 @@ class Linear(Kern):
|
||||||
#---------------------------------------#
|
#---------------------------------------#
|
||||||
|
|
||||||
def psi0(self, Z, variational_posterior):
|
def psi0(self, Z, variational_posterior):
|
||||||
|
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||||
|
gamma = variational_posterior.binary_prob
|
||||||
|
mu = variational_posterior.mean
|
||||||
|
S = variational_posterior.variance
|
||||||
|
return np.einsum('q,nq,nq->n',self.variances,gamma,np.square(mu)+S)
|
||||||
|
# return (self.variances*gamma*(np.square(mu)+S)).sum(axis=1)
|
||||||
|
else:
|
||||||
return np.sum(self.variances * self._mu2S(variational_posterior), 1)
|
return np.sum(self.variances * self._mu2S(variational_posterior), 1)
|
||||||
|
|
||||||
def psi1(self, Z, variational_posterior):
|
def psi1(self, Z, variational_posterior):
|
||||||
|
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||||
|
gamma = variational_posterior.binary_prob
|
||||||
|
mu = variational_posterior.mean
|
||||||
|
return np.einsum('nq,q,mq,nq->nm',gamma,self.variances,Z,mu)
|
||||||
|
# return (self.variances*gamma*mu).sum(axis=1)
|
||||||
|
else:
|
||||||
return self.K(variational_posterior.mean, Z) #the variance, it does nothing
|
return self.K(variational_posterior.mean, Z) #the variance, it does nothing
|
||||||
|
|
||||||
@Cache_this(limit=1)
|
@Cache_this(limit=1)
|
||||||
def psi2(self, Z, variational_posterior):
|
def psi2(self, Z, variational_posterior):
|
||||||
|
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||||
|
gamma = variational_posterior.binary_prob
|
||||||
|
mu = variational_posterior.mean
|
||||||
|
S = variational_posterior.variance
|
||||||
|
mu2 = np.square(mu)
|
||||||
|
variances2 = np.square(self.variances)
|
||||||
|
tmp = np.einsum('nq,q,mq,nq->nm',gamma,self.variances,Z,mu)
|
||||||
|
return np.einsum('nq,q,mq,oq,nq->nmo',gamma,variances2,Z,Z,mu2+S)+\
|
||||||
|
np.einsum('nm,no->nmo',tmp,tmp) - np.einsum('nq,q,mq,oq,nq->nmo',np.square(gamma),variances2,Z,Z,mu2)
|
||||||
|
else:
|
||||||
ZA = Z * self.variances
|
ZA = Z * self.variances
|
||||||
ZAinner = self._ZAinner(variational_posterior, Z)
|
ZAinner = self._ZAinner(variational_posterior, Z)
|
||||||
return np.dot(ZAinner, ZA.T)
|
return np.dot(ZAinner, ZA.T)
|
||||||
|
|
||||||
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
def update_gradients_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||||
|
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||||
|
gamma = variational_posterior.binary_prob
|
||||||
|
mu = variational_posterior.mean
|
||||||
|
S = variational_posterior.variance
|
||||||
|
mu2S = np.square(mu)+S
|
||||||
|
|
||||||
|
_dpsi2_dvariance, _, _, _, _ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
|
||||||
|
grad = np.einsum('n,nq,nq->q',dL_dpsi0,gamma,mu2S) + np.einsum('nm,nq,mq,nq->q',dL_dpsi1,gamma,Z,mu) +\
|
||||||
|
np.einsum('nmo,nmoq->q',dL_dpsi2,_dpsi2_dvariance)
|
||||||
|
self.variances.gradient = grad
|
||||||
|
else:
|
||||||
#psi1
|
#psi1
|
||||||
self.update_gradients_full(dL_dpsi1, variational_posterior.mean, Z)
|
self.update_gradients_full(dL_dpsi1, variational_posterior.mean, Z)
|
||||||
# psi0:
|
# psi0:
|
||||||
|
|
@ -130,6 +166,17 @@ class Linear(Kern):
|
||||||
self.variances.gradient += 2.*np.sum(dL_dpsi2 * self.psi2(Z, variational_posterior))/self.variances
|
self.variances.gradient += 2.*np.sum(dL_dpsi2 * self.psi2(Z, variational_posterior))/self.variances
|
||||||
|
|
||||||
def gradients_Z_expectations(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
def gradients_Z_expectations(self, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||||
|
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||||
|
gamma = variational_posterior.binary_prob
|
||||||
|
mu = variational_posterior.mean
|
||||||
|
S = variational_posterior.variance
|
||||||
|
_, _, _, _, _dpsi2_dZ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
|
||||||
|
|
||||||
|
grad = np.einsum('nm,nq,q,nq->mq',dL_dpsi1,gamma, self.variances,mu) +\
|
||||||
|
np.einsum('nmo,noq->mq',dL_dpsi2,_dpsi2_dZ)
|
||||||
|
|
||||||
|
return grad
|
||||||
|
else:
|
||||||
#psi1
|
#psi1
|
||||||
grad = self.gradients_X(dL_dpsi1.T, Z, variational_posterior.mean)
|
grad = self.gradients_X(dL_dpsi1.T, Z, variational_posterior.mean)
|
||||||
#psi2
|
#psi2
|
||||||
|
|
@ -137,6 +184,21 @@ class Linear(Kern):
|
||||||
return grad
|
return grad
|
||||||
|
|
||||||
def gradients_qX_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
def gradients_qX_expectations(self, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||||
|
if isinstance(variational_posterior, variational.SpikeAndSlabPosterior):
|
||||||
|
gamma = variational_posterior.binary_prob
|
||||||
|
mu = variational_posterior.mean
|
||||||
|
S = variational_posterior.variance
|
||||||
|
mu2S = np.square(mu)+S
|
||||||
|
_, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _ = linear_psi_comp._psi2computations(self.variances, Z, mu, S, gamma)
|
||||||
|
|
||||||
|
grad_gamma = np.einsum('n,q,nq->nq',dL_dpsi0,self.variances,mu2S) + np.einsum('nm,q,mq,nq->nq',dL_dpsi1,self.variances,Z,mu) +\
|
||||||
|
np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dgamma)
|
||||||
|
grad_mu = np.einsum('n,nq,q,nq->nq',dL_dpsi0,gamma,2.*self.variances,mu) + np.einsum('nm,nq,q,mq->nq',dL_dpsi1,gamma,self.variances,Z) +\
|
||||||
|
np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dmu)
|
||||||
|
grad_S = np.einsum('n,nq,q->nq',dL_dpsi0,gamma,self.variances) + np.einsum('nmo,nmoq->nq',dL_dpsi2,_dpsi2_dS)
|
||||||
|
|
||||||
|
return grad_mu, grad_S, grad_gamma
|
||||||
|
else:
|
||||||
grad_mu, grad_S = np.zeros(variational_posterior.mean.shape), np.zeros(variational_posterior.mean.shape)
|
grad_mu, grad_S = np.zeros(variational_posterior.mean.shape), np.zeros(variational_posterior.mean.shape)
|
||||||
# psi0
|
# psi0
|
||||||
grad_mu += dL_dpsi0[:, None] * (2.0 * variational_posterior.mean * self.variances)
|
grad_mu += dL_dpsi0[:, None] * (2.0 * variational_posterior.mean * self.variances)
|
||||||
|
|
|
||||||
51
GPy/kern/_src/psi_comp/linear_psi_comp.py
Normal file
51
GPy/kern/_src/psi_comp/linear_psi_comp.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||||
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
|
"""
|
||||||
|
The package for the Psi statistics computation of the linear kernel for SSGPLVM
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from GPy.util.caching import Cache_this
|
||||||
|
|
||||||
|
#@Cache_this(limit=1)
|
||||||
|
def _psi2computations(variance, Z, mu, S, gamma):
|
||||||
|
"""
|
||||||
|
Z - MxQ
|
||||||
|
mu - NxQ
|
||||||
|
S - NxQ
|
||||||
|
gamma - NxQ
|
||||||
|
"""
|
||||||
|
# here are the "statistics" for psi1 and psi2
|
||||||
|
# Produced intermediate results:
|
||||||
|
# _psi2 NxMxM
|
||||||
|
# _psi2_dvariance NxMxMxQ
|
||||||
|
# _psi2_dZ NxMxQ
|
||||||
|
# _psi2_dgamma NxMxMxQ
|
||||||
|
# _psi2_dmu NxMxMxQ
|
||||||
|
# _psi2_dS NxMxMxQ
|
||||||
|
|
||||||
|
mu2 = np.square(mu)
|
||||||
|
gamma2 = np.square(gamma)
|
||||||
|
variance2 = np.square(variance)
|
||||||
|
mu2S = mu2+S # NxQ
|
||||||
|
common_sum = np.einsum('nq,q,mq,nq->nm',gamma,variance,Z,mu) # NxM
|
||||||
|
|
||||||
|
_dpsi2_dvariance = np.einsum('nq,q,mq,oq->nmoq',2.*(gamma*mu2S-gamma2*mu2),variance,Z,Z)+\
|
||||||
|
np.einsum('nq,mq,nq,no->nmoq',gamma,Z,mu,common_sum)+\
|
||||||
|
np.einsum('nq,oq,nq,nm->nmoq',gamma,Z,mu,common_sum)
|
||||||
|
|
||||||
|
_dpsi2_dgamma = np.einsum('q,mq,oq,nq->nmoq',variance2,Z,Z,(mu2S-2.*gamma*mu2))+\
|
||||||
|
np.einsum('q,mq,nq,no->nmoq',variance,Z,mu,common_sum)+\
|
||||||
|
np.einsum('q,oq,nq,nm->nmoq',variance,Z,mu,common_sum)
|
||||||
|
|
||||||
|
_dpsi2_dmu = np.einsum('q,mq,oq,nq,nq->nmoq',variance2,Z,Z,mu,2.*(gamma-gamma2))+\
|
||||||
|
np.einsum('nq,q,mq,no->nmoq',gamma,variance,Z,common_sum)+\
|
||||||
|
np.einsum('nq,q,oq,nm->nmoq',gamma,variance,Z,common_sum)
|
||||||
|
|
||||||
|
_dpsi2_dS = np.einsum('nq,q,mq,oq->nmoq',gamma,variance2,Z,Z)
|
||||||
|
|
||||||
|
_dpsi2_dZ = 2.*(np.einsum('nq,q,mq,nq->nmq',gamma,variance2,Z,mu2S)+np.einsum('nq,q,nq,nm->nmq',gamma,variance,mu,common_sum)
|
||||||
|
-np.einsum('nq,q,mq,nq->nmq',gamma2,variance2,Z,mu2))
|
||||||
|
|
||||||
|
return _dpsi2_dvariance, _dpsi2_dgamma, _dpsi2_dmu, _dpsi2_dS, _dpsi2_dZ
|
||||||
|
|
@ -8,7 +8,7 @@ from ...util.misc import param_to_array
|
||||||
from stationary import Stationary
|
from stationary import Stationary
|
||||||
from GPy.util.caching import Cache_this
|
from GPy.util.caching import Cache_this
|
||||||
from ...core.parameterization import variational
|
from ...core.parameterization import variational
|
||||||
from rbf_psi_comp import ssrbf_psi_comp
|
from psi_comp import ssrbf_psi_comp
|
||||||
|
|
||||||
class RBF(Stationary):
|
class RBF(Stationary):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import numpy as np
|
||||||
from ...util.linalg import tdot
|
from ...util.linalg import tdot
|
||||||
from ...util.config import *
|
from ...util.config import *
|
||||||
from stationary import Stationary
|
from stationary import Stationary
|
||||||
from rbf_psi_comp import ssrbf_psi_comp
|
from psi_comp import ssrbf_psi_comp
|
||||||
|
|
||||||
class SSRBF(Stationary):
|
class SSRBF(Stationary):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue