preliminary reconfiguring of choleskies ordering

James Hensman 2015-05-05 14:13:38 +01:00
parent dde8e4136e
commit 5d1875ec44
6 changed files with 432 additions and 468 deletions


@@ -46,7 +46,7 @@ class SVGP(SparseGP):
num_latent_functions = Y.shape[1]
self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
-chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,num_latent_functions)))
+chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[None,:,:], (num_latent_functions, 1,1)))
self.chol = Param('q_u_chol', chol)
self.link_parameter(self.chol)
self.link_parameter(self.m)
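
A quick way to exercise the new parameterisation (a minimal sketch; it assumes the choleskies helpers live at GPy.util.choleskies and follow the post-commit leading-output-axis layout):

    import numpy as np
    from GPy.util import choleskies

    M, D = 10, 3  # num_inducing, num_latent_functions
    # one M x M identity Cholesky per latent function, output axis leading
    L = np.tile(np.eye(M)[None, :, :], (D, 1, 1))   # shape (D, M, M)
    flat = choleskies.triang_to_flat(L)             # shape (D, M*(M+1)/2)
    assert flat.shape == (D, M * (M + 1) // 2)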


@@ -16,11 +16,10 @@ class SVGP(LatentFunctionInference):
S = np.empty((num_outputs, num_inducing, num_inducing))
-[np.dot(L[:,:,i], L[:,:,i].T, S[i,:,:]) for i in range(num_outputs)]
-S = S.swapaxes(0,2)
+[np.dot(L[i,:,:], L[i,:,:].T, S[i,:,:]) for i in range(num_outputs)]
#Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
Si = choleskies.multiple_dpotri(L)
-logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[:,:,i])))) for i in range(L.shape[-1])])
+logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[i,:,:])))) for i in range(L.shape[0])])
if np.any(np.isinf(Si)):
raise ValueError("Cholesky representation unstable")
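
With the output axis leading, each L[i,:,:] is a contiguous slab, and the S and logdetS computations admit a direct vectorised form. A plain-numpy sketch of the same quantities:

    import numpy as np

    D, M = 3, 10
    L = np.tril(np.random.randn(D, M, M)) + 4 * np.eye(M)[None, :, :]
    # S[d] = L[d] L[d].T, one contiguous slab per output
    S = np.einsum('dij,dkj->dik', L, L)
    # log|S[d]| = 2 * sum(log|diag(L[d])|), vectorised over the leading axis
    logdetS = 2.0 * np.sum(np.log(np.abs(np.diagonal(L, axis1=1, axis2=2))), axis=1)
    assert np.allclose(logdetS, np.linalg.slogdet(S)[1])
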
@@ -46,16 +45,17 @@ class SVGP(LatentFunctionInference):
A = np.dot(Knm, Kmmi)
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
-v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
+#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
+v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + (S.dot(A.T)*A.T[None,:,:]).sum(1).T
#compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean)
-KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[:,:,None]*S,0).sum(0) + 0.5*np.sum(q_u_mean*Kmmim,0)
+KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
KL = KLs.sum()
#gradient of the KL term (assuming zero mean function)
dKL_dm = Kmmim.copy()
-dKL_dS = 0.5*(Kmmi[:,:,None] - Si)
-dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(-1)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
+dKL_dS = 0.5*(Kmmi[None,:,:] - Si)
+dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(0)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
if mean_function is not None:
#adjust KL term for mean function
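
For reference, KLs is the standard Gaussian KL divergence per output d, between q(u_d) = N(m_d, S_d) and the prior p(u_d) = N(0, K_mm) (zero mean function assumed):

    \mathrm{KL}\bigl(q(u_d)\,\|\,p(u_d)\bigr) = \tfrac{1}{2}\Bigl[\operatorname{tr}\bigl(K_{mm}^{-1} S_d\bigr) + m_d^\top K_{mm}^{-1} m_d - M + \log\lvert K_{mm}\rvert - \log\lvert S_d\rvert\Bigr]

The five terms line up with KLs one-for-one (M = num_inducing), and differentiating with respect to S_d gives the 0.5*(Kmmi - Si) in dKL_dS.
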
@@ -80,17 +80,14 @@ class SVGP(LatentFunctionInference):
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
#derivatives of expected likelihood, assuming zero mean function
-Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
+Adv = A.T[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
Admu = A.T.dot(dF_dmu)
-AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
-#tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
-tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi)
-dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
+AdvA = np.dot(Adv, A) # D, M, M
+tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
+dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
-#tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
-tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None])
-#dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
-dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
+tmp = 2.*(S.dot(Kmmi).swapaxes(1,2) - np.eye(num_inducing)[None, :,:]) # TODO: transpose?
+dF_dKmn = np.sum([np.dot(a,b) for a,b in zip(tmp, Adv)],0) + Kmmim.dot(dF_dmu.T)
dF_dm = Admu
dF_dS = AdvA
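
The AdvA = np.dot(Adv, A) line leans on np.dot contracting the last axis of a 3-d array against the first axis of a 2-d one, which batches the old per-output dstack loop in a single call. A sanity sketch (shapes illustrative):

    import numpy as np

    D, M, N = 3, 4, 6
    Adv = np.random.randn(D, M, N)   # per-output slabs, output axis leading
    A = np.random.randn(N, M)
    AdvA = np.dot(Adv, A)            # (D, M, M) in one call
    ref = np.array([Adv[d].dot(A) for d in range(D)])
    assert np.allclose(AdvA, ref)
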
@@ -103,14 +100,16 @@ class SVGP(LatentFunctionInference):
#sum (gradients of) expected likelihood and KL part
-log_marginal = F.sum() - KL
-dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
+log_marginal = F.sum()
+dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm*0, dF_dS- dKL_dS*0, dF_dKmm- dKL_dKmm*0, dF_dKmn
+#log_marginal = F.sum() - KL
+#dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
-dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)])
+dL_dchol = 2.*np.array([np.dot(a,b) for a, b in zip(dL_dS, L) ])
dL_dchol = choleskies.triang_to_flat(dL_dchol)
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
if mean_function is not None:
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
grad_dict['dL_dmfX'] = dF_dmfX
-return Posterior(mean=q_u_mean, cov=S, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
+return Posterior(mean=q_u_mean, cov=S.T, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
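
The dL_dchol line is the chain rule through S = L L^T: for a symmetric dL_dS, the gradient in each lower-triangular factor is 2*(dL_dS)*L, applied slab by slab. A finite-difference check of that identity on a toy objective (a sketch, not GPy code):

    import numpy as np

    M = 5
    L = np.tril(np.random.randn(M, M)) + 3 * np.eye(M)
    W = np.random.randn(M, M); W = 0.5 * (W + W.T)  # symmetric stand-in for dL_dS
    F = lambda L_: np.trace(W @ (L_ @ L_.T))        # toy objective F(S), S = L L^T
    g_analytic = 2.0 * W @ L                        # the 2 * dL_dS * L rule
    eps, g_num = 1e-6, np.zeros_like(L)
    for i in range(M):
        for j in range(i + 1):
            Lp, Lm = L.copy(), L.copy()
            Lp[i, j] += eps; Lm[i, j] -= eps
            g_num[i, j] = (F(Lp) - F(Lm)) / (2 * eps)
    assert np.allclose(np.tril(g_analytic), g_num, atol=1e-4)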


@@ -9,8 +9,8 @@ These tests make sure that the pure python and cython codes work the same
class CythonTestChols(np.testing.TestCase):
def setUp(self):
-self.flat = np.random.randn(45, 5)
-self.triang = np.dstack([np.eye(20)[:,:,None] for i in range(3)])
+self.flat = np.random.randn(5,45)
+self.triang = np.array([np.eye(20) for i in range(3)])
def test_flat_to_triang(self):
L1 = choleskies._flat_to_triang_pure(self.flat)
L2 = choleskies._flat_to_triang_cython(self.flat)
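
The reshaped fixtures encode the round-trip property both implementations should satisfy under the new layout (a sketch, assuming the public flat_to_triang/triang_to_flat wrappers follow the same convention as the pure/cython pairs):

    import numpy as np
    from GPy.util import choleskies

    flat = np.random.randn(5, 45)        # D=5 vectors of N=45 entries each, so M=9
    L = choleskies.flat_to_triang(flat)  # expected shape (5, 9, 9)
    assert np.allclose(choleskies.triang_to_flat(L), flat)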


@@ -15,33 +15,33 @@ def safe_root(N):
return j
def _flat_to_triang_pure(flat_mat):
-N, D = flat_mat.shape
+D, N = flat_mat.shape
M = (-1 + safe_root(8*N+1))//2
-ret = np.zeros((M, M, D))
-count = 0
-for m in range(M):
-for mm in range(m+1):
-for d in range(D):
-ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count];
+ret = np.zeros((D, M, M))
+for d in range(D):
+count = 0
+for m in range(M):
+for mm in range(m+1):
+ret[d,m, mm] = flat_mat[d, count];
count = count+1
return ret
def _flat_to_triang_cython(flat_mat):
-N, D = flat_mat.shape
+D, N = flat_mat.shape
M = (-1 + safe_root(8*N+1))//2
return choleskies_cython.flat_to_triang(flat_mat, M)
def _triang_to_flat_pure(L):
-M, _, D = L.shape
+D, _, M = L.shape
N = M*(M+1)//2
-flat = np.empty((N, D))
-count = 0;
-for m in range(M):
-for mm in range(m+1):
-for d in range(D):
-flat.flat[count] = L.flat[d + m*D*M + mm*D];
+flat = np.empty((D, N))
+for d in range(D):
+count = 0;
+for m in range(M):
+for mm in range(m+1):
+flat[d,count] = L[d, m, mm]
count = count +1
return flat
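
The two rewritten loops visit the lower triangle in row-major (m, mm) order, which is exactly the order np.tril_indices produces, so a vectorised equivalent is short (my sketch, not part of the commit):

    import numpy as np

    def flat_to_triang_vec(flat_mat):
        D, N = flat_mat.shape
        M = (-1 + int(round((8 * N + 1) ** 0.5))) // 2
        rows, cols = np.tril_indices(M)   # same (m, mm) visiting order as the loops
        ret = np.zeros((D, M, M))
        ret[:, rows, cols] = flat_mat
        return ret

    def triang_to_flat_vec(L):
        rows, cols = np.tril_indices(L.shape[1])
        return L[:, rows, cols]
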
@@ -74,7 +74,7 @@ def triang_to_cov(L):
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])
def multiple_dpotri(Ls):
-return np.dstack([linalg.dpotri(np.asfortranarray(Ls[:,:,i]), lower=1)[0] for i in range(Ls.shape[-1])])
+return np.array([linalg.dpotri(np.asfortranarray(Ls[i]), lower=1)[0] for i in range(Ls.shape[0])])
def indexes_to_fix_for_low_rank(rank, size):
"""

File diff suppressed because it is too large


@@ -8,37 +8,37 @@ import numpy as np
cimport numpy as np
def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
"""take a matrix N x D and return a M X M x D array where
"""take a matrix D x N and return a D X M x M array where
N = M(M+1)/2
the lower triangular portion of the d'th slice of the result is filled by the d'th column of flat.
"""
-cdef int N = flat.shape[0]
-cdef int D = flat.shape[1]
+cdef int D = flat.shape[0]
+cdef int N = flat.shape[1]
cdef int count = 0
-cdef np.ndarray[double, ndim=3] ret = np.zeros((M, M, D))
+cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
cdef int d, m, mm
for d in range(D):
count = 0
for m in range(M):
for mm in range(m+1):
-ret[m, mm, d] = flat[count,d]
+ret[d, m, mm] = flat[d,count]
count += 1
return ret
def triang_to_flat(np.ndarray[double, ndim=3] L):
-cdef int M = L.shape[0]
-cdef int D = L.shape[2]
+cdef int D = L.shape[0]
+cdef int M = L.shape[1]
cdef int N = M*(M+1)/2
cdef int count = 0
-cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
+cdef np.ndarray[double, ndim=2] flat = np.empty((D, N))
cdef int d, m, mm
for d in range(D):
count = 0
for m in range(M):
for mm in range(m+1):
-flat[count,d] = L[m, mm, d]
+flat[d,count] = L[d, m, mm]
count += 1
return flat
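
One payoff of the leading output axis, for both the numpy and Cython paths: in a C-ordered (D, M, M) array every M x M slab is contiguous, while slicing the old (M, M, D) layout yields strided views that force copies before BLAS calls. For illustration:

    import numpy as np

    new = np.zeros((3, 10, 10))                    # output axis leading
    old = np.zeros((10, 10, 3))                    # output axis trailing
    assert new[0].flags['C_CONTIGUOUS']            # contiguous slab, usable in place
    assert not old[:, :, 0].flags['C_CONTIGUOUS']  # strided view, needs a copy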