mirror of https://github.com/SheffieldML/GPy.git (synced 2026-05-08 11:32:39 +02:00)

Merge branch 'reorder_choleskies' into devel
commit 0651c933db

7 changed files with 451 additions and 463 deletions
@@ -46,7 +46,7 @@ class SVGP(SparseGP):
         num_latent_functions = Y.shape[1]
         self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
-        chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,num_latent_functions)))
+        chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[None,:,:], (num_latent_functions, 1,1)))
         self.chol = Param('q_u_chol', chol)
         self.link_parameter(self.chol)
         self.link_parameter(self.m)
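
Note on the change above: this branch reorders the stacked Cholesky factors from shape (M, M, D), output index last, to (D, M, M), output index first, so each factor is a contiguous slice. A minimal sketch of the two layouts (plain NumPy; M and D are placeholder sizes, not values from the commit):

    import numpy as np

    M, D = 4, 3  # hypothetical counts of inducing points and latent functions

    # New convention: output axis first, one contiguous (M, M) factor per output.
    triang_new = np.tile(np.eye(M)[None, :, :], (D, 1, 1))   # shape (D, M, M)

    # Old convention, for comparison: output axis last.
    triang_old = np.tile(np.eye(M)[:, :, None], (1, 1, D))   # shape (M, M, D)

    # The d'th factor is now a contiguous row-major block, which is what lets the
    # inference code batch per-output matrix products with a single reshape + dot.
    assert np.allclose(triang_new[1], triang_old[:, :, 1])
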
@@ -3,6 +3,7 @@ from ...util import linalg
 from ...util import choleskies
 import numpy as np
 from .posterior import Posterior
+from scipy.linalg.blas import dgemm, dsymm, dtrmm

 class SVGP(LatentFunctionInference):
@@ -16,16 +17,13 @@ class SVGP(LatentFunctionInference):

         S = np.empty((num_outputs, num_inducing, num_inducing))
-        [np.dot(L[:,:,i], L[:,:,i].T, S[i,:,:]) for i in range(num_outputs)]
-        S = S.swapaxes(0,2)
+        [np.dot(L[i,:,:], L[i,:,:].T, S[i,:,:]) for i in range(num_outputs)]
         #Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
         Si = choleskies.multiple_dpotri(L)
-        logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[:,:,i])))) for i in range(L.shape[-1])])
+        logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[i,:,:])))) for i in range(L.shape[0])])

         if np.any(np.isinf(Si)):
             raise ValueError("Cholesky representation unstable")
-        #S = S + np.eye(S.shape[0])*1e-5*np.max(np.max(S))
-        #Si, Lnew, _,_ = linalg.pdinv(S)

         #compute mean function stuff
         if mean_function is not None:
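
With the output axis first, each covariance S[i] = L[i] L[i].T and its log-determinant come straight off contiguous slices, as the hunk above does with a list comprehension. An equivalent batched sketch (NumPy only; shapes illustrative, not from the commit):

    import numpy as np

    D, M = 3, 5
    L = np.tril(np.random.randn(D, M, M))
    di = np.arange(M)
    L[:, di, di] = np.abs(L[:, di, di]) + 0.5      # keep the diagonals safely positive

    S = np.einsum('dij,dkj->dik', L, L)            # S[d] = L[d] @ L[d].T, batched
    logdetS = 2.*np.sum(np.log(np.abs(np.diagonal(L, axis1=1, axis2=2))), axis=1)

    for d in range(D):                             # cross-check against the per-slice loop
        assert np.allclose(S[d], L[d].dot(L[d].T))
        assert np.isclose(logdetS[d], np.linalg.slogdet(S[d])[1])
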
@@ -35,27 +33,29 @@ class SVGP(LatentFunctionInference):
         prior_mean_u = np.zeros((num_inducing, num_outputs))
         prior_mean_f = np.zeros((num_data, num_outputs))

         #compute kernel related stuff
         Kmm = kern.K(Z)
-        Knm = kern.K(X, Z)
+        Kmn = kern.K(Z, X)
         Knn_diag = kern.Kdiag(X)
-        Kmmi, Lm, Lmi, logdetKmm = linalg.pdinv(Kmm)
+        Lm = linalg.jitchol(Kmm)
+        logdetKmm = 2.*np.sum(np.log(np.diag(Lm)))
+        Kmmi, _ = linalg.dpotri(Lm)

         #compute the marginal means and variances of q(f)
-        A = np.dot(Knm, Kmmi)
-        mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
-        #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
-        v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
+        A, _ = linalg.dpotrs(Lm, Kmn)
+        mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
+        LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
+        #TODO? possibly use dtrmm for the above line?
+        v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T

         #compute the KL term
         Kmmim = np.dot(Kmmi, q_u_mean)
-        KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[:,:,None]*S,0).sum(0) + 0.5*np.sum(q_u_mean*Kmmim,0)
+        KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
         KL = KLs.sum()
         #gradient of the KL term (assuming zero mean function)
         dKL_dm = Kmmim.copy()
-        dKL_dS = 0.5*(Kmmi[:,:,None] - Si)
-        dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(-1)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
+        dKL_dS = 0.5*(Kmmi[None,:,:] - Si)
+        dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(0)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)

         if mean_function is not None:
             #adjust KL term for mean function
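
For reference (not part of the diff), these are the standard SVGP quantities: with A = K_{mm}^{-1} K_{mn} and q(u_d) = N(m_d, S_d), the code above forms

    \mu_d = A^\top m_d, \qquad
    v_d = \operatorname{diag}\!\left( K_{nn} - A^\top K_{mn} + A^\top S_d A \right),

where the last term is assembled from the Cholesky factor L_d of S_d, and the per-output KL term is the Gaussian KL divergence from the prior p(u_d) = N(0, K_{mm}), with M = num_inducing:

    \mathrm{KL}_d = \tfrac{1}{2}\left[ \operatorname{tr}(K_{mm}^{-1} S_d) + m_d^\top K_{mm}^{-1} m_d - M + \log|K_{mm}| - \log|S_d| \right],

which matches the KLs line above term by term.
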
@@ -80,17 +80,22 @@ class SVGP(LatentFunctionInference):
         dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale

         #derivatives of expected likelihood, assuming zero mean function
-        Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_dv is diagonal
-        Admu = A.T.dot(dF_dmu)
-        AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
-        #tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
-        tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi)
-        dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
+        Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_dv is diagonal, D, M, N
+        Admu = A.dot(dF_dmu)
+        Adv = np.ascontiguousarray(Adv) # makes for faster operations later...(inc dsymm)
+        AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing )
+        tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
+        dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
         dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
-        #tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
-        tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None])
-        #dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
-        dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
+        tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing )
+        tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:])
+
+        dF_dKnm = Kmmim.dot(dF_dmu.T).T
+        assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dsymm in place call:
+        for a,b in zip(tmp, Adv):
+            dsymm(1.0, a.T, b.T, beta=1., side=1, c=dF_dKnm, overwrite_c=True)
+        dF_dKmn = dF_dKnm.T

         dF_dm = Admu
         dF_dS = AdvA
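
The rewrite above drops the einsum helpers (ijk_jlk_to_il and friends) in favour of BLAS-backed calls: once the output axis leads, a batched product such as AdvA collapses into a single GEMM over a (D*M, N) view. A small equivalence sketch (NumPy; sizes made up):

    import numpy as np

    D, M, N = 3, 4, 6
    Adv = np.random.randn(D, M, N)
    A = np.random.randn(M, N)

    AdvA = np.dot(Adv.reshape(-1, N), A.T).reshape(D, M, M)   # one GEMM, as in the diff

    for d in range(D):                                        # the per-slice form it replaces
        assert np.allclose(AdvA[d], Adv[d].dot(A.T))

The dsymm loop exploits the symmetry of each slice of tmp and accumulates into dF_dKnm in place (beta=1.), which is why the F-contiguity assert guards the call.
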
@@ -106,11 +111,11 @@ class SVGP(LatentFunctionInference):
         log_marginal = F.sum() - KL
         dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn

-        dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)])
+        dL_dchol = 2.*np.array([np.dot(a,b) for a, b in zip(dL_dS, L) ])
         dL_dchol = choleskies.triang_to_flat(dL_dchol)

         grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
         if mean_function is not None:
             grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
             grad_dict['dL_dmfX'] = dF_dmfX
-        return Posterior(mean=q_u_mean, cov=S, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
+        return Posterior(mean=q_u_mean, cov=S.T, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
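
Two details in this last hunk: the zip loop is the chain rule through S_d = L_d L_d^T, which gives dL/dL_d = 2 (dL/dS_d) L_d when dL/dS_d is symmetric; and cov=S.T hands Posterior the old (M, M, D) layout, since .T on the (D, M, M) stack reverses all three axes and each symmetric slice is unchanged. A one-line check of the latter (a sketch, not from the commit):

    import numpy as np

    D, M = 3, 4
    S = np.random.randn(D, M, M)
    S = S + S.transpose(0, 2, 1)             # make each slice symmetric
    assert np.allclose(S.T[:, :, 1], S[1])   # .T recovers the (M, M, D) layout
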
@@ -14,6 +14,22 @@ for(nd=0;nd<(D*N);nd++){
 } //grad_X

+void _lengthscale_grads_unsafe(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
+  int n,m,nm,q,nQ,mQ;
+  double dist;
+  #pragma omp parallel for private(n,m,nm,q,nQ,mQ,dist)
+  for(nm=0; nm<(N*M); nm++){
+    n = nm/M;
+    m = nm%M;
+    nQ = n*Q;
+    mQ = m*Q;
+    for(q=0; q<Q; q++){
+      dist = X[nQ+q]-X2[mQ+q];
+      grad[q] += tmp[nm]*dist*dist;
+    }
+  }
+} //lengthscale_grads
+
 void _lengthscale_grads(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
   int n,m,q;
@@ -34,3 +50,5 @@ for(q=0; q<Q; q++){
+
+
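
The new _unsafe variant flattens the (n, m) double loop into one OpenMP-parallel loop over nm, but the accumulation grad[q] += ... is shared across threads with no reduction or atomic, so updates can race; presumably that is what the name warns about. What the loop computes, as a NumPy sketch (array names follow the C arguments; sizes made up):

    import numpy as np

    N, M, Q = 5, 4, 2
    tmp = np.random.randn(N, M)                  # per-pair weights
    X, X2 = np.random.randn(N, Q), np.random.randn(M, Q)

    # grad[q] = sum over n, m of tmp[n, m] * (X[n, q] - X2[m, q])**2
    grad = np.einsum('nm,nmq->q', tmp, (X[:, None, :] - X2[None, :, :])**2)
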
@@ -9,8 +9,8 @@ These tests make sure that the pure python and cython codes work the same

 class CythonTestChols(np.testing.TestCase):
     def setUp(self):
-        self.flat = np.random.randn(45, 5)
-        self.triang = np.dstack([np.eye(20)[:,:,None] for i in range(3)])
+        self.flat = np.random.randn(5,45)
+        self.triang = np.array([np.eye(20) for i in range(3)])
     def test_flat_to_triang(self):
         L1 = choleskies._flat_to_triang_pure(self.flat)
         L2 = choleskies._flat_to_triang_cython(self.flat)
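
The fixtures above pin the new output-first layout. A round-trip property in the same vein, using the pure-Python pair from the next file (a sketch, assuming GPy is importable):

    import numpy as np
    from GPy.util import choleskies

    D, M = 3, 20
    triang = np.tril(np.random.randn(D, M, M))        # stack of lower-triangular factors
    flat = choleskies._triang_to_flat_pure(triang)    # (M*(M+1)/2, D), per the diff below
    assert np.allclose(choleskies._flat_to_triang_pure(flat), triang)
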
@@ -17,12 +17,12 @@ def safe_root(N):

 def _flat_to_triang_pure(flat_mat):
     N, D = flat_mat.shape
     M = (-1 + safe_root(8*N+1))//2
-    ret = np.zeros((M, M, D))
-    count = 0
-    for m in range(M):
-        for mm in range(m+1):
-            for d in range(D):
-                ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count];
+    ret = np.zeros((D, M, M))
+    for d in range(D):
+        count = 0
+        for m in range(M):
+            for mm in range(m+1):
+                ret[d,m, mm] = flat_mat[count, d];
                 count = count+1
     return ret
@@ -33,15 +33,15 @@ def _flat_to_triang_cython(flat_mat):


 def _triang_to_flat_pure(L):
-    M, _, D = L.shape
+    D, _, M = L.shape

     N = M*(M+1)//2
     flat = np.empty((N, D))
-    count = 0;
-    for m in range(M):
-        for mm in range(m+1):
-            for d in range(D):
-                flat.flat[count] = L.flat[d + m*D*M + mm*D];
+    for d in range(D):
+        count = 0;
+        for m in range(M):
+            for mm in range(m+1):
+                flat[count,d] = L[d, m, mm]
+                count = count +1
     return flat
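
Both pure-Python loops walk the lower triangle in the same row-major order as np.tril_indices, so a vectorized form can serve as a cross-check (a sketch, not part of the commit):

    import numpy as np

    def flat_to_triang_vec(flat_mat):
        # (N, D) -> (D, M, M), with N = M*(M+1)/2
        N, D = flat_mat.shape
        M = (int(np.sqrt(8*N + 1)) - 1)//2
        rows, cols = np.tril_indices(M)        # row-major, matching `count`
        ret = np.zeros((D, M, M))
        ret[:, rows, cols] = flat_mat.T
        return ret

    def triang_to_flat_vec(L):
        D, M, _ = L.shape
        rows, cols = np.tril_indices(M)
        return L[:, rows, cols].T              # back to (N, D)
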
@@ -74,7 +74,7 @@ def triang_to_cov(L):
     return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])

 def multiple_dpotri(Ls):
-    return np.dstack([linalg.dpotri(np.asfortranarray(Ls[:,:,i]), lower=1)[0] for i in range(Ls.shape[-1])])
+    return np.array([linalg.dpotri(np.asfortranarray(Ls[i]), lower=1)[0] for i in range(Ls.shape[0])])

 def indexes_to_fix_for_low_rank(rank, size):
     """
File diff suppressed because it is too large
@@ -8,28 +8,28 @@ import numpy as np
 cimport numpy as np

 def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
-    """take a matrix N x D and return a M X M x D array where
+    """take a matrix N x D and return a D X M x M array where

     N = M(M+1)/2

     the lower triangular portion of the d'th slice of the result is filled by the d'th column of flat.
     """
-    cdef int N = flat.shape[0]
     cdef int D = flat.shape[1]
+    cdef int N = flat.shape[0]
     cdef int count = 0
-    cdef np.ndarray[double, ndim=3] ret = np.zeros((M, M, D))
+    cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
     cdef int d, m, mm
     for d in range(D):
         count = 0
         for m in range(M):
             for mm in range(m+1):
-                ret[m, mm, d] = flat[count,d]
+                ret[d, m, mm] = flat[count,d]
                 count += 1
     return ret

 def triang_to_flat(np.ndarray[double, ndim=3] L):
-    cdef int M = L.shape[0]
-    cdef int D = L.shape[2]
+    cdef int D = L.shape[0]
+    cdef int M = L.shape[1]
     cdef int N = M*(M+1)/2
     cdef int count = 0
     cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
@@ -38,7 +38,7 @@ def triang_to_flat(np.ndarray[double, ndim=3] L):
         count = 0
         for m in range(M):
             for mm in range(m+1):
-                flat[count,d] = L[m, mm, d]
+                flat[count,d] = L[d, m, mm]
                 count += 1
     return flat