Merge branch 'reorder_choleskies' into devel

James Hensman 2015-05-15 09:00:18 +01:00
commit 0651c933db
7 changed files with 451 additions and 463 deletions


@@ -46,7 +46,7 @@ class SVGP(SparseGP):
         num_latent_functions = Y.shape[1]
         self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
-        chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,num_latent_functions)))
+        chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[None,:,:], (num_latent_functions,1,1)))
         self.chol = Param('q_u_chol', chol)
         self.link_parameter(self.chol)
         self.link_parameter(self.m)
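The reorder moves the stacking axis for the q(u) Cholesky factors from last to first, so factor d is now the contiguous slice chol[d] rather than the strided slice chol[:,:,d]. A throwaway check (illustrative sizes, not part of the commit) that the two initialisations build the same factors:

import numpy as np

M, D = 4, 3  # num_inducing, num_latent_functions (illustrative sizes only)

old = np.tile(np.eye(M)[:,:,None], (1, 1, D))  # (M, M, D): factor d is old[:,:,d]
new = np.tile(np.eye(M)[None,:,:], (D, 1, 1))  # (D, M, M): factor d is new[d]

# identical identity factors, only the stacking axis has moved
assert all(np.allclose(old[:,:,d], new[d]) for d in range(D))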


@@ -3,6 +3,7 @@ from ...util import linalg
 from ...util import choleskies
 import numpy as np
 from .posterior import Posterior
+from scipy.linalg.blas import dgemm, dsymm, dtrmm

 class SVGP(LatentFunctionInference):
@@ -16,16 +17,13 @@ class SVGP(LatentFunctionInference):
         S = np.empty((num_outputs, num_inducing, num_inducing))
-        [np.dot(L[:,:,i], L[:,:,i].T, S[i,:,:]) for i in range(num_outputs)]
-        S = S.swapaxes(0,2)
+        [np.dot(L[i,:,:], L[i,:,:].T, S[i,:,:]) for i in range(num_outputs)]
         #Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
         Si = choleskies.multiple_dpotri(L)
-        logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[:,:,i])))) for i in range(L.shape[-1])])
+        logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[i,:,:])))) for i in range(L.shape[0])])
         if np.any(np.isinf(Si)):
             raise ValueError("Cholesky representation unstable")
-        #S = S + np.eye(S.shape[0])*1e-5*np.max(np.max(S))
-        #Si, Lnew, _,_ = linalg.pdinv(S)

         #compute mean function stuff
         if mean_function is not None:
@@ -35,27 +33,29 @@ class SVGP(LatentFunctionInference):
         prior_mean_u = np.zeros((num_inducing, num_outputs))
         prior_mean_f = np.zeros((num_data, num_outputs))

         #compute kernel related stuff
         Kmm = kern.K(Z)
-        Knm = kern.K(X, Z)
+        Kmn = kern.K(Z, X)
         Knn_diag = kern.Kdiag(X)
-        Kmmi, Lm, Lmi, logdetKmm = linalg.pdinv(Kmm)
+        Lm = linalg.jitchol(Kmm)
+        logdetKmm = 2.*np.sum(np.log(np.diag(Lm)))
+        Kmmi, _ = linalg.dpotri(Lm)

         #compute the marginal means and variances of q(f)
-        A = np.dot(Knm, Kmmi)
-        mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
-        #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
-        v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
+        A, _ = linalg.dpotrs(Lm, Kmn)
+        mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
+        LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
+        #TODO? possibly use dtrmm for the above line?
+        v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T

         #compute the KL term
         Kmmim = np.dot(Kmmi, q_u_mean)
-        KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[:,:,None]*S,0).sum(0) + 0.5*np.sum(q_u_mean*Kmmim,0)
+        KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
         KL = KLs.sum()

         #gradient of the KL term (assuming zero mean function)
         dKL_dm = Kmmim.copy()
-        dKL_dS = 0.5*(Kmmi[:,:,None] - Si)
-        dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(-1)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
+        dKL_dS = 0.5*(Kmmi[None,:,:] - Si)
+        dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(0)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)

         if mean_function is not None:
             #adjust KL term for mean function
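The vectorised KLs line above is the standard per-output Gaussian KL, KL(N(m_d, S_d) || N(0, Kmm)) = 0.5*(tr(Kmm^{-1} S_d) + m_d^T Kmm^{-1} m_d - num_inducing + logdet Kmm - logdet S_d), with the trace entering as an elementwise product over the new leading output axis. A self-contained check with made-up sizes, not part of the commit:

import numpy as np

D, M = 3, 5                                    # num_outputs, num_inducing (made up)
rng = np.random.RandomState(0)
L = np.tril(rng.randn(D, M, M)) + 4*np.eye(M)  # stacked Choleskies, (D, M, M)
S = np.einsum('dij,dkj->dik', L, L)            # S_d = L_d L_d^T
m = rng.randn(M, D)
Kmm = np.eye(M) + 0.1*rng.randn(M, M); Kmm = Kmm.dot(Kmm.T)
Kmmi = np.linalg.inv(Kmm)
logdetKmm = np.linalg.slogdet(Kmm)[1]
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[d])))) for d in range(D)])

# the diff's vectorised form
Kmmim = Kmmi.dot(m)
KLs = -0.5*logdetS - 0.5*M + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S, 1).sum(1) + 0.5*np.sum(m*Kmmim, 0)

# per-output textbook form
ref = np.array([0.5*(np.trace(Kmmi.dot(S[d])) + m[:,d].dot(Kmmi).dot(m[:,d]) - M + logdetKmm - logdetS[d]) for d in range(D)])
assert np.allclose(KLs, ref)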
@@ -80,17 +80,22 @@ class SVGP(LatentFunctionInference):
             dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale

         #derivatives of expected likelihood, assuming zero mean function
-        Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
-        Admu = A.T.dot(dF_dmu)
-        AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
-        #tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
-        tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi)
-        dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
+        Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
+        Admu = A.dot(dF_dmu)
+        Adv = np.ascontiguousarray(Adv) # makes for faster operations later...(inc dsymm)
+        AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing )
+        tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
+        dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
         dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
-        #tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
-        tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None])
-        #dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
-        dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
+        tmp = S.reshape(-1, num_inducing).dot(Kmmi).reshape(num_outputs, num_inducing , num_inducing )
+        tmp = 2.*(tmp - np.eye(num_inducing)[None, :,:])
+        dF_dKnm = Kmmim.dot(dF_dmu.T).T
+        assert dF_dKnm.flags['F_CONTIGUOUS'] # needed for dsymm in place call:
+        for a,b in zip(tmp, Adv):
+            dsymm(1.0, a.T, b.T, beta=1., side=1, c=dF_dKnm, overwrite_c=True)
+        dF_dKmn = dF_dKnm.T

         dF_dm = Admu
         dF_dS = AdvA
@@ -106,11 +111,11 @@ class SVGP(LatentFunctionInference):
         log_marginal = F.sum() - KL

         dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
-        dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)])
+        dL_dchol = 2.*np.array([np.dot(a,b) for a, b in zip(dL_dS, L) ])
         dL_dchol = choleskies.triang_to_flat(dL_dchol)

         grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
         if mean_function is not None:
             grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
             grad_dict['dL_dmfX'] = dF_dmfX
-        return Posterior(mean=q_u_mean, cov=S, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
+        return Posterior(mean=q_u_mean, cov=S.T, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
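The reshape idiom used twice above (for LA and for tmp) flattens the leading output axis so that one large matmul replaces a Python loop over num_outputs small ones; this is the main payoff of the (D, M, M) ordering. A quick equivalence check with throwaway arrays:

import numpy as np

D, M, N = 3, 5, 7
T = np.random.randn(D, M, M)   # any stacked (D, M, M) array
A = np.random.randn(M, N)

batched = T.reshape(-1, M).dot(A).reshape(D, M, N)
looped = np.array([T[d].dot(A) for d in range(D)])
assert np.allclose(batched, looped)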


@@ -14,6 +14,22 @@ for(nd=0;nd<(D*N);nd++){
 } //grad_X

+void _lengthscale_grads_unsafe(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
+    int n,m,nm,q,nQ,mQ;
+    double dist;
+    #pragma omp parallel for private(n,m,nm,q,nQ,mQ,dist)
+    for(nm=0; nm<(N*M); nm++){
+        n = nm/M;
+        m = nm%M;
+        nQ = n*Q;
+        mQ = m*Q;
+        for(q=0; q<Q; q++){
+            dist = X[nQ+q]-X2[mQ+q];
+            grad[q] += tmp[nm]*dist*dist;
+        }
+    }
+} //lengthscale_grads
+
 void _lengthscale_grads(int N, int M, int Q, double* tmp, double* X, double* X2, double* grad){
     int n,m,q;
@@ -34,3 +50,5 @@ for(q=0; q<Q; q++){
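In numpy terms both C routines accumulate grad[q] = sum over n, m of tmp[n,m]*(X[n,q]-X2[m,q])^2; a reference sketch (hypothetical helper, not in the commit) is below. The _unsafe variant parallelises over the fused nm loop while threads add into the shared grad[q] without a reduction clause, which is presumably what earns it the suffix.

import numpy as np

def lengthscale_grads_ref(tmp, X, X2):
    # grad[q] = sum_{n,m} tmp[n,m] * (X[n,q] - X2[m,q])**2
    diff = X[:, None, :] - X2[None, :, :]   # (N, M, Q)
    return np.einsum('nm,nmq->q', tmp, np.square(diff))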


@@ -9,8 +9,8 @@ These tests make sure that the opure python and cython codes work the same
 class CythonTestChols(np.testing.TestCase):
     def setUp(self):
-        self.flat = np.random.randn(45, 5)
-        self.triang = np.dstack([np.eye(20)[:,:,None] for i in range(3)])
+        self.flat = np.random.randn(5,45)
+        self.triang = np.array([np.eye(20) for i in range(3)])

     def test_flat_to_triang(self):
         L1 = choleskies._flat_to_triang_pure(self.flat)
         L2 = choleskies._flat_to_triang_cython(self.flat)
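Under the new convention the flat representation still keeps one column per output (shape N x D with N = M(M+1)/2), while the triangular stack is (D, M, M). A round-trip sketch against the pure-python helpers changed in the next file, assuming the usual GPy.util import path:

import numpy as np
from GPy.util import choleskies

L = np.array([np.tril(np.random.randn(20, 20)) for _ in range(3)])  # (3, 20, 20)
flat = choleskies._triang_to_flat_pure(L)                           # (210, 3)
assert np.allclose(choleskies._flat_to_triang_pure(flat), L)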


@@ -17,12 +17,12 @@ def safe_root(N):

 def _flat_to_triang_pure(flat_mat):
     N, D = flat_mat.shape
     M = (-1 + safe_root(8*N+1))//2
-    ret = np.zeros((M, M, D))
-    count = 0
-    for m in range(M):
-        for mm in range(m+1):
-            for d in range(D):
-                ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count];
-                count = count+1
+    ret = np.zeros((D, M, M))
+    for d in range(D):
+        count = 0
+        for m in range(M):
+            for mm in range(m+1):
+                ret[d,m, mm] = flat_mat[count, d];
+                count = count+1
     return ret
@@ -33,15 +33,15 @@ def _flat_to_triang_cython(flat_mat):

 def _triang_to_flat_pure(L):
-    M, _, D = L.shape
+    D, _, M = L.shape
     N = M*(M+1)//2
     flat = np.empty((N, D))
-    count = 0;
-    for m in range(M):
-        for mm in range(m+1):
-            for d in range(D):
-                flat.flat[count] = L.flat[d + m*D*M + mm*D];
-                count = count +1
+    for d in range(D):
+        count = 0;
+        for m in range(M):
+            for mm in range(m+1):
+                flat[count,d] = L[d, m, mm]
+                count = count +1
     return flat
@@ -74,7 +74,7 @@ def triang_to_cov(L):
     return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])

 def multiple_dpotri(Ls):
-    return np.dstack([linalg.dpotri(np.asfortranarray(Ls[:,:,i]), lower=1)[0] for i in range(Ls.shape[-1])])
+    return np.array([linalg.dpotri(np.asfortranarray(Ls[i]), lower=1)[0] for i in range(Ls.shape[0])])

 def indexes_to_fix_for_low_rank(rank, size):
     """

File diff suppressed because it is too large


@@ -8,28 +8,28 @@ import numpy as np
 cimport numpy as np

 def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
-    """take a matrix N x D and return a M X M x D array where
+    """take a matrix N x D and return a D X M x M array where
     N = M(M+1)/2
     the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
     """
-    cdef int N = flat.shape[0]
     cdef int D = flat.shape[1]
+    cdef int N = flat.shape[0]
     cdef int count = 0
-    cdef np.ndarray[double, ndim=3] ret = np.zeros((M, M, D))
+    cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
     cdef int d, m, mm
     for d in range(D):
         count = 0
         for m in range(M):
             for mm in range(m+1):
-                ret[m, mm, d] = flat[count,d]
+                ret[d, m, mm] = flat[count,d]
                 count += 1
     return ret

 def triang_to_flat(np.ndarray[double, ndim=3] L):
-    cdef int M = L.shape[0]
-    cdef int D = L.shape[2]
+    cdef int D = L.shape[0]
+    cdef int M = L.shape[1]
     cdef int N = M*(M+1)/2
     cdef int count = 0
     cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
@@ -38,7 +38,7 @@ def triang_to_flat(np.ndarray[double, ndim=3] L):
         count = 0
         for m in range(M):
             for mm in range(m+1):
-                flat[count,d] = L[m, mm, d]
+                flat[count,d] = L[d, m, mm]
                 count += 1
     return flat
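The likely motivation for the whole reordering: slices along a leading axis are C-contiguous views, whereas slices along a trailing axis are strided, and the BLAS/LAPACK wrappers used per output (dpotri, dsymm, ...) must copy strided input before calling into Fortran. A rough illustration, not from the commit:

import numpy as np

D, M = 3, 200
new = np.zeros((D, M, M))   # new convention
old = np.zeros((M, M, D))   # old convention

assert new[0].flags['C_CONTIGUOUS']           # a plain view, usable directly
assert not old[:,:,0].flags['C_CONTIGUOUS']   # strided view, typically copied first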