mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 19:12:40 +02:00
Merge branch 'reorder_choleskies' into devel
This commit is contained in:
commit
0651c933db
7 changed files with 451 additions and 463 deletions
|
|
@ -46,7 +46,7 @@ class SVGP(SparseGP):
|
||||||
num_latent_functions = Y.shape[1]
|
num_latent_functions = Y.shape[1]
|
||||||
|
|
||||||
self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
|
self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
|
||||||
chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,num_latent_functions)))
|
chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[None,:,:], (num_latent_functions, 1,1)))
|
||||||
self.chol = Param('q_u_chol', chol)
|
self.chol = Param('q_u_chol', chol)
|
||||||
self.link_parameter(self.chol)
|
self.link_parameter(self.chol)
|
||||||
self.link_parameter(self.m)
|
self.link_parameter(self.m)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from ...util import linalg
|
||||||
from ...util import choleskies
|
from ...util import choleskies
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .posterior import Posterior
|
from .posterior import Posterior
|
||||||
|
from scipy.linalg.blas import dgemm, dsymm, dtrmm
|
||||||
|
|
||||||
class SVGP(LatentFunctionInference):
|
class SVGP(LatentFunctionInference):
|
||||||
|
|
||||||
|
|
@ -16,16 +17,13 @@ class SVGP(LatentFunctionInference):
|
||||||
|
|
||||||
|
|
||||||
S = np.empty((num_outputs, num_inducing, num_inducing))
|
S = np.empty((num_outputs, num_inducing, num_inducing))
|
||||||
[np.dot(L[:,:,i], L[:,:,i].T, S[i,:,:]) for i in range(num_outputs)]
|
[np.dot(L[i,:,:], L[i,:,:].T, S[i,:,:]) for i in range(num_outputs)]
|
||||||
S = S.swapaxes(0,2)
|
|
||||||
#Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
|
#Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
|
||||||
Si = choleskies.multiple_dpotri(L)
|
Si = choleskies.multiple_dpotri(L)
|
||||||
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[:,:,i])))) for i in range(L.shape[-1])])
|
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[i,:,:])))) for i in range(L.shape[0])])
|
||||||
|
|
||||||
if np.any(np.isinf(Si)):
|
if np.any(np.isinf(Si)):
|
||||||
raise ValueError("Cholesky representation unstable")
|
raise ValueError("Cholesky representation unstable")
|
||||||
#S = S + np.eye(S.shape[0])*1e-5*np.max(np.max(S))
|
|
||||||
#Si, Lnew, _,_ = linalg.pdinv(S)
|
|
||||||
|
|
||||||
#compute mean function stuff
|
#compute mean function stuff
|
||||||
if mean_function is not None:
|
if mean_function is not None:
|
||||||
|
|
@ -35,27 +33,29 @@ class SVGP(LatentFunctionInference):
|
||||||
prior_mean_u = np.zeros((num_inducing, num_outputs))
|
prior_mean_u = np.zeros((num_inducing, num_outputs))
|
||||||
prior_mean_f = np.zeros((num_data, num_outputs))
|
prior_mean_f = np.zeros((num_data, num_outputs))
|
||||||
|
|
||||||
|
|
||||||
#compute kernel related stuff
|
#compute kernel related stuff
|
||||||
Kmm = kern.K(Z)
|
Kmm = kern.K(Z)
|
||||||
Knm = kern.K(X, Z)
|
Kmn = kern.K(Z, X)
|
||||||
Knn_diag = kern.Kdiag(X)
|
Knn_diag = kern.Kdiag(X)
|
||||||
Kmmi, Lm, Lmi, logdetKmm = linalg.pdinv(Kmm)
|
Lm = linalg.jitchol(Kmm)
|
||||||
|
logdetKmm = 2.*np.sum(np.log(np.diag(Lm)))
|
||||||
|
Kmmi, _ = linalg.dpotri(Lm)
|
||||||
|
|
||||||
#compute the marginal means and variances of q(f)
|
#compute the marginal means and variances of q(f)
|
||||||
A = np.dot(Knm, Kmmi)
|
A, _ = linalg.dpotrs(Lm, Kmn)
|
||||||
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
|
mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
|
||||||
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
|
LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
|
||||||
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
|
#TODO? possibly use dtrmm for the above line?
|
||||||
|
v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T
|
||||||
|
|
||||||
#compute the KL term
|
#compute the KL term
|
||||||
Kmmim = np.dot(Kmmi, q_u_mean)
|
Kmmim = np.dot(Kmmi, q_u_mean)
|
||||||
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[:,:,None]*S,0).sum(0) + 0.5*np.sum(q_u_mean*Kmmim,0)
|
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
|
||||||
KL = KLs.sum()
|
KL = KLs.sum()
|
||||||
#gradient of the KL term (assuming zero mean function)
|
#gradient of the KL term (assuming zero mean function)
|
||||||
dKL_dm = Kmmim.copy()
|
dKL_dm = Kmmim.copy()
|
||||||
dKL_dS = 0.5*(Kmmi[:,:,None] - Si)
|
dKL_dS = 0.5*(Kmmi[None,:,:] - Si)
|
||||||
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(-1)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
|
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(0)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
|
||||||
|
|
||||||
if mean_function is not None:
|
if mean_function is not None:
|
||||||
#adjust KL term for mean function
|
#adjust KL term for mean function
|
||||||
|
|
@ -80,17 +80,22 @@ class SVGP(LatentFunctionInference):
|
||||||
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
|
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
|
||||||
|
|
||||||
#derivatives of expected likelihood, assuming zero mean function
|
#derivatives of expected likelihood, assuming zero mean function
|
||||||
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
|
Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
|
||||||
Admu = A.T.dot(dF_dmu)
|
Admu = A.dot(dF_dmu)
|
||||||
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
|
Adv = np.ascontiguousarray(Adv) # makes for faster operations later...(inc dsymm)
|
||||||
#tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
|
AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing )
|
||||||
tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi)
|
tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
|
||||||
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
|
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
|
||||||
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
|
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
|
||||||
#tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
|
tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing )
|
||||||
tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None])
|
tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:])
|
||||||
#dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
|
|
||||||
dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
|
dF_dKnm = Kmmim.dot(dF_dmu.T).T
|
||||||
|
assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dsymm in place call:
|
||||||
|
for a,b in zip(tmp, Adv):
|
||||||
|
dsymm(1.0, a.T, b.T, beta=1., side=1, c=dF_dKnm, overwrite_c=True)
|
||||||
|
dF_dKmn = dF_dKnm.T
|
||||||
|
|
||||||
dF_dm = Admu
|
dF_dm = Admu
|
||||||
dF_dS = AdvA
|
dF_dS = AdvA
|
||||||
|
|
||||||
|
|
@ -106,11 +111,11 @@ class SVGP(LatentFunctionInference):
|
||||||
log_marginal = F.sum() - KL
|
log_marginal = F.sum() - KL
|
||||||
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
|
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
|
||||||
|
|
||||||
dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)])
|
dL_dchol = 2.*np.array([np.dot(a,b) for a, b in zip(dL_dS, L) ])
|
||||||
dL_dchol = choleskies.triang_to_flat(dL_dchol)
|
dL_dchol = choleskies.triang_to_flat(dL_dchol)
|
||||||
|
|
||||||
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
|
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
|
||||||
if mean_function is not None:
|
if mean_function is not None:
|
||||||
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
|
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
|
||||||
grad_dict['dL_dmfX'] = dF_dmfX
|
grad_dict['dL_dmfX'] = dF_dmfX
|
||||||
return Posterior(mean=q_u_mean, cov=S, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
|
return Posterior(mean=q_u_mean, cov=S.T, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,22 @@ for(nd=0;nd<(D*N);nd++){
|
||||||
} //grad_X
|
} //grad_X
|
||||||
|
|
||||||
|
|
||||||
|
void _lengthscale_grads_unsafe(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
|
||||||
|
int n,m,nm,q,nQ,mQ;
|
||||||
|
double dist;
|
||||||
|
#pragma omp parallel for private(n,m,nm,q,nQ,mQ,dist)
|
||||||
|
for(nm=0; nm<(N*M); nm++){
|
||||||
|
n = nm/M;
|
||||||
|
m = nm%M;
|
||||||
|
nQ = n*Q;
|
||||||
|
mQ = m*Q;
|
||||||
|
for(q=0; q<Q; q++){
|
||||||
|
dist = X[nQ+q]-X2[mQ+q];
|
||||||
|
grad[q] += tmp[nm]*dist*dist;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} //lengthscale_grads
|
||||||
|
|
||||||
|
|
||||||
void _lengthscale_grads(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
|
void _lengthscale_grads(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
|
||||||
int n,m,q;
|
int n,m,q;
|
||||||
|
|
@ -34,3 +50,5 @@ for(q=0; q<Q; q++){
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ These tests make sure that the opure python and cython codes work the same
|
||||||
|
|
||||||
class CythonTestChols(np.testing.TestCase):
|
class CythonTestChols(np.testing.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.flat = np.random.randn(45, 5)
|
self.flat = np.random.randn(5,45)
|
||||||
self.triang = np.dstack([np.eye(20)[:,:,None] for i in range(3)])
|
self.triang = np.array([np.eye(20) for i in range(3)])
|
||||||
def test_flat_to_triang(self):
|
def test_flat_to_triang(self):
|
||||||
L1 = choleskies._flat_to_triang_pure(self.flat)
|
L1 = choleskies._flat_to_triang_pure(self.flat)
|
||||||
L2 = choleskies._flat_to_triang_cython(self.flat)
|
L2 = choleskies._flat_to_triang_cython(self.flat)
|
||||||
|
|
|
||||||
|
|
@ -17,12 +17,12 @@ def safe_root(N):
|
||||||
def _flat_to_triang_pure(flat_mat):
|
def _flat_to_triang_pure(flat_mat):
|
||||||
N, D = flat_mat.shape
|
N, D = flat_mat.shape
|
||||||
M = (-1 + safe_root(8*N+1))//2
|
M = (-1 + safe_root(8*N+1))//2
|
||||||
ret = np.zeros((M, M, D))
|
ret = np.zeros((D, M, M))
|
||||||
|
for d in range(D):
|
||||||
count = 0
|
count = 0
|
||||||
for m in range(M):
|
for m in range(M):
|
||||||
for mm in range(m+1):
|
for mm in range(m+1):
|
||||||
for d in range(D):
|
ret[d,m, mm] = flat_mat[count, d];
|
||||||
ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count];
|
|
||||||
count = count+1
|
count = count+1
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
@ -33,15 +33,15 @@ def _flat_to_triang_cython(flat_mat):
|
||||||
|
|
||||||
|
|
||||||
def _triang_to_flat_pure(L):
|
def _triang_to_flat_pure(L):
|
||||||
M, _, D = L.shape
|
D, _, M = L.shape
|
||||||
|
|
||||||
N = M*(M+1)//2
|
N = M*(M+1)//2
|
||||||
flat = np.empty((N, D))
|
flat = np.empty((N, D))
|
||||||
|
for d in range(D):
|
||||||
count = 0;
|
count = 0;
|
||||||
for m in range(M):
|
for m in range(M):
|
||||||
for mm in range(m+1):
|
for mm in range(m+1):
|
||||||
for d in range(D):
|
flat[count,d] = L[d, m, mm]
|
||||||
flat.flat[count] = L.flat[d + m*D*M + mm*D];
|
|
||||||
count = count +1
|
count = count +1
|
||||||
return flat
|
return flat
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ def triang_to_cov(L):
|
||||||
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])
|
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])
|
||||||
|
|
||||||
def multiple_dpotri(Ls):
|
def multiple_dpotri(Ls):
|
||||||
return np.dstack([linalg.dpotri(np.asfortranarray(Ls[:,:,i]), lower=1)[0] for i in range(Ls.shape[-1])])
|
return np.array([linalg.dpotri(np.asfortranarray(Ls[i]), lower=1)[0] for i in range(Ls.shape[0])])
|
||||||
|
|
||||||
def indexes_to_fix_for_low_rank(rank, size):
|
def indexes_to_fix_for_low_rank(rank, size):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -8,28 +8,28 @@ import numpy as np
|
||||||
cimport numpy as np
|
cimport numpy as np
|
||||||
|
|
||||||
def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
|
def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
|
||||||
"""take a matrix N x D and return a M X M x D array where
|
"""take a matrix N x D and return a D X M x M array where
|
||||||
|
|
||||||
N = M(M+1)/2
|
N = M(M+1)/2
|
||||||
|
|
||||||
the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
|
the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
|
||||||
"""
|
"""
|
||||||
cdef int N = flat.shape[0]
|
|
||||||
cdef int D = flat.shape[1]
|
cdef int D = flat.shape[1]
|
||||||
|
cdef int N = flat.shape[0]
|
||||||
cdef int count = 0
|
cdef int count = 0
|
||||||
cdef np.ndarray[double, ndim=3] ret = np.zeros((M, M, D))
|
cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
|
||||||
cdef int d, m, mm
|
cdef int d, m, mm
|
||||||
for d in range(D):
|
for d in range(D):
|
||||||
count = 0
|
count = 0
|
||||||
for m in range(M):
|
for m in range(M):
|
||||||
for mm in range(m+1):
|
for mm in range(m+1):
|
||||||
ret[m, mm, d] = flat[count,d]
|
ret[d, m, mm] = flat[count,d]
|
||||||
count += 1
|
count += 1
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def triang_to_flat(np.ndarray[double, ndim=3] L):
|
def triang_to_flat(np.ndarray[double, ndim=3] L):
|
||||||
cdef int M = L.shape[0]
|
cdef int D = L.shape[0]
|
||||||
cdef int D = L.shape[2]
|
cdef int M = L.shape[1]
|
||||||
cdef int N = M*(M+1)/2
|
cdef int N = M*(M+1)/2
|
||||||
cdef int count = 0
|
cdef int count = 0
|
||||||
cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
|
cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
|
||||||
|
|
@ -38,7 +38,7 @@ def triang_to_flat(np.ndarray[double, ndim=3] L):
|
||||||
count = 0
|
count = 0
|
||||||
for m in range(M):
|
for m in range(M):
|
||||||
for mm in range(m+1):
|
for mm in range(m+1):
|
||||||
flat[count,d] = L[m, mm, d]
|
flat[count,d] = L[d, m, mm]
|
||||||
count += 1
|
count += 1
|
||||||
return flat
|
return flat
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue