Merge branch 'reorder_choleskies' into devel

This commit is contained in:
James Hensman 2015-05-15 09:00:18 +01:00
commit 0651c933db
7 changed files with 451 additions and 463 deletions

View file

@ -46,7 +46,7 @@ class SVGP(SparseGP):
num_latent_functions = Y.shape[1]
self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,num_latent_functions)))
chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[None,:,:], (num_latent_functions, 1,1)))
self.chol = Param('q_u_chol', chol)
self.link_parameter(self.chol)
self.link_parameter(self.m)

View file

@ -3,6 +3,7 @@ from ...util import linalg
from ...util import choleskies
import numpy as np
from .posterior import Posterior
from scipy.linalg.blas import dgemm, dsymm, dtrmm
class SVGP(LatentFunctionInference):
@ -16,16 +17,13 @@ class SVGP(LatentFunctionInference):
S = np.empty((num_outputs, num_inducing, num_inducing))
[np.dot(L[:,:,i], L[:,:,i].T, S[i,:,:]) for i in range(num_outputs)]
S = S.swapaxes(0,2)
[np.dot(L[i,:,:], L[i,:,:].T, S[i,:,:]) for i in range(num_outputs)]
#Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
Si = choleskies.multiple_dpotri(L)
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[:,:,i])))) for i in range(L.shape[-1])])
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[i,:,:])))) for i in range(L.shape[0])])
if np.any(np.isinf(Si)):
raise ValueError("Cholesky representation unstable")
#S = S + np.eye(S.shape[0])*1e-5*np.max(np.max(S))
#Si, Lnew, _,_ = linalg.pdinv(S)
#compute mean function stuff
if mean_function is not None:
@ -35,27 +33,29 @@ class SVGP(LatentFunctionInference):
prior_mean_u = np.zeros((num_inducing, num_outputs))
prior_mean_f = np.zeros((num_data, num_outputs))
#compute kernel related stuff
Kmm = kern.K(Z)
Knm = kern.K(X, Z)
Kmn = kern.K(Z, X)
Knn_diag = kern.Kdiag(X)
Kmmi, Lm, Lmi, logdetKmm = linalg.pdinv(Kmm)
Lm = linalg.jitchol(Kmm)
logdetKmm = 2.*np.sum(np.log(np.diag(Lm)))
Kmmi, _ = linalg.dpotri(Lm)
#compute the marginal means and variances of q(f)
A = np.dot(Knm, Kmmi)
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
A, _ = linalg.dpotrs(Lm, Kmn)
mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
#TODO? possibly use dtrmm for the above line?
v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T
#compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean)
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[:,:,None]*S,0).sum(0) + 0.5*np.sum(q_u_mean*Kmmim,0)
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
KL = KLs.sum()
#gradient of the KL term (assuming zero mean function)
dKL_dm = Kmmim.copy()
dKL_dS = 0.5*(Kmmi[:,:,None] - Si)
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(-1)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
dKL_dS = 0.5*(Kmmi[None,:,:] - Si)
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(0)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
if mean_function is not None:
#adjust KL term for mean function
@ -80,17 +80,22 @@ class SVGP(LatentFunctionInference):
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
#derivatives of expected likelihood, assuming zero mean function
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
Admu = A.T.dot(dF_dmu)
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
#tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi)
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
Admu = A.dot(dF_dmu)
Adv = np.ascontiguousarray(Adv) # makes for faster operations later...(inc dsymm)
AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing )
tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
#tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None])
#dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing )
tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:])
dF_dKnm = Kmmim.dot(dF_dmu.T).T
assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dsymm in place call:
for a,b in zip(tmp, Adv):
dsymm(1.0, a.T, b.T, beta=1., side=1, c=dF_dKnm, overwrite_c=True)
dF_dKmn = dF_dKnm.T
dF_dm = Admu
dF_dS = AdvA
@ -106,11 +111,11 @@ class SVGP(LatentFunctionInference):
log_marginal = F.sum() - KL
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)])
dL_dchol = 2.*np.array([np.dot(a,b) for a, b in zip(dL_dS, L) ])
dL_dchol = choleskies.triang_to_flat(dL_dchol)
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
if mean_function is not None:
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
grad_dict['dL_dmfX'] = dF_dmfX
return Posterior(mean=q_u_mean, cov=S, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
return Posterior(mean=q_u_mean, cov=S.T, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict

View file

@ -14,6 +14,22 @@ for(nd=0;nd<(D*N);nd++){
} //grad_X
void _lengthscale_grads_unsafe(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
int n,m,nm,q,nQ,mQ;
double dist;
#pragma omp parallel for private(n,m,nm,q,nQ,mQ,dist)
for(nm=0; nm<(N*M); nm++){
n = nm/M;
m = nm%M;
nQ = n*Q;
mQ = m*Q;
for(q=0; q<Q; q++){
dist = X[nQ+q]-X2[mQ+q];
grad[q] += tmp[nm]*dist*dist;
}
}
} //lengthscale_grads
void _lengthscale_grads(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
int n,m,q;
@ -34,3 +50,5 @@ for(q=0; q<Q; q++){

View file

@ -9,8 +9,8 @@ These tests make sure that the opure python and cython codes work the same
class CythonTestChols(np.testing.TestCase):
def setUp(self):
self.flat = np.random.randn(45, 5)
self.triang = np.dstack([np.eye(20)[:,:,None] for i in range(3)])
self.flat = np.random.randn(5,45)
self.triang = np.array([np.eye(20) for i in range(3)])
def test_flat_to_triang(self):
L1 = choleskies._flat_to_triang_pure(self.flat)
L2 = choleskies._flat_to_triang_cython(self.flat)

View file

@ -17,12 +17,12 @@ def safe_root(N):
def _flat_to_triang_pure(flat_mat):
N, D = flat_mat.shape
M = (-1 + safe_root(8*N+1))//2
ret = np.zeros((M, M, D))
count = 0
for m in range(M):
for mm in range(m+1):
for d in range(D):
ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count];
ret = np.zeros((D, M, M))
for d in range(D):
count = 0
for m in range(M):
for mm in range(m+1):
ret[d,m, mm] = flat_mat[count, d];
count = count+1
return ret
@ -33,15 +33,15 @@ def _flat_to_triang_cython(flat_mat):
def _triang_to_flat_pure(L):
M, _, D = L.shape
D, _, M = L.shape
N = M*(M+1)//2
flat = np.empty((N, D))
count = 0;
for m in range(M):
for mm in range(m+1):
for d in range(D):
flat.flat[count] = L.flat[d + m*D*M + mm*D];
for d in range(D):
count = 0;
for m in range(M):
for mm in range(m+1):
flat[count,d] = L[d, m, mm]
count = count +1
return flat
@ -74,7 +74,7 @@ def triang_to_cov(L):
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])
def multiple_dpotri(Ls):
return np.dstack([linalg.dpotri(np.asfortranarray(Ls[:,:,i]), lower=1)[0] for i in range(Ls.shape[-1])])
return np.array([linalg.dpotri(np.asfortranarray(Ls[i]), lower=1)[0] for i in range(Ls.shape[0])])
def indexes_to_fix_for_low_rank(rank, size):
"""

File diff suppressed because it is too large Load diff

View file

@ -8,28 +8,28 @@ import numpy as np
cimport numpy as np
def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
"""take a matrix N x D and return a M X M x D array where
"""take a matrix N x D and return a D X M x M array where
N = M(M+1)/2
the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
"""
cdef int N = flat.shape[0]
cdef int D = flat.shape[1]
cdef int N = flat.shape[0]
cdef int count = 0
cdef np.ndarray[double, ndim=3] ret = np.zeros((M, M, D))
cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
cdef int d, m, mm
for d in range(D):
count = 0
for m in range(M):
for mm in range(m+1):
ret[m, mm, d] = flat[count,d]
ret[d, m, mm] = flat[count,d]
count += 1
return ret
def triang_to_flat(np.ndarray[double, ndim=3] L):
cdef int M = L.shape[0]
cdef int D = L.shape[2]
cdef int D = L.shape[0]
cdef int M = L.shape[1]
cdef int N = M*(M+1)/2
cdef int count = 0
cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
@ -38,7 +38,7 @@ def triang_to_flat(np.ndarray[double, ndim=3] L):
count = 0
for m in range(M):
for mm in range(m+1):
flat[count,d] = L[m, mm, d]
flat[count,d] = L[d, m, mm]
count += 1
return flat