preliminary reconfiguring or choleskies ordering

This commit is contained in:
James Hensman 2015-05-05 14:13:38 +01:00
parent dde8e4136e
commit 5d1875ec44
6 changed files with 432 additions and 468 deletions

View file

@ -46,7 +46,7 @@ class SVGP(SparseGP):
num_latent_functions = Y.shape[1] num_latent_functions = Y.shape[1]
self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions))) self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,num_latent_functions))) chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[None,:,:], (num_latent_functions, 1,1)))
self.chol = Param('q_u_chol', chol) self.chol = Param('q_u_chol', chol)
self.link_parameter(self.chol) self.link_parameter(self.chol)
self.link_parameter(self.m) self.link_parameter(self.m)

View file

@ -16,11 +16,10 @@ class SVGP(LatentFunctionInference):
S = np.empty((num_outputs, num_inducing, num_inducing)) S = np.empty((num_outputs, num_inducing, num_inducing))
[np.dot(L[:,:,i], L[:,:,i].T, S[i,:,:]) for i in range(num_outputs)] [np.dot(L[i,:,:], L[i,:,:].T, S[i,:,:]) for i in range(num_outputs)]
S = S.swapaxes(0,2)
#Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1) #Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
Si = choleskies.multiple_dpotri(L) Si = choleskies.multiple_dpotri(L)
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[:,:,i])))) for i in range(L.shape[-1])]) logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[i,:,:])))) for i in range(L.shape[0])])
if np.any(np.isinf(Si)): if np.any(np.isinf(Si)):
raise ValueError("Cholesky representation unstable") raise ValueError("Cholesky representation unstable")
@ -46,16 +45,17 @@ class SVGP(LatentFunctionInference):
A = np.dot(Knm, Kmmi) A = np.dot(Knm, Kmmi)
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u) mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1) #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1) #v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + (S.dot(A.T)*A.T[None,:,:]).sum(1).T
#compute the KL term #compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean) Kmmim = np.dot(Kmmi, q_u_mean)
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[:,:,None]*S,0).sum(0) + 0.5*np.sum(q_u_mean*Kmmim,0) KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
KL = KLs.sum() KL = KLs.sum()
#gradient of the KL term (assuming zero mean function) #gradient of the KL term (assuming zero mean function)
dKL_dm = Kmmim.copy() dKL_dm = Kmmim.copy()
dKL_dS = 0.5*(Kmmi[:,:,None] - Si) dKL_dS = 0.5*(Kmmi[None,:,:] - Si)
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(-1)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T) dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(0)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
if mean_function is not None: if mean_function is not None:
#adjust KL term for mean function #adjust KL term for mean function
@ -80,17 +80,14 @@ class SVGP(LatentFunctionInference):
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
#derivatives of expected likelihood, assuming zero mean function #derivatives of expected likelihood, assuming zero mean function
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal Adv = A.T[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
Admu = A.T.dot(dF_dmu) Admu = A.T.dot(dF_dmu)
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)]) AdvA = np.dot(Adv, A) # D, M, M
#tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi) tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi) dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug? dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
#tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None]) tmp = 2.*(S.dot(Kmmi).swapaxes(1,2) - np.eye(num_inducing)[None, :,:]) # TODO: transpose?
tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None]) dF_dKmn = np.sum([np.dot(a,b) for a,b in zip(tmp, Adv)],0) + Kmmim.dot(dF_dmu.T)
#dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
dF_dm = Admu dF_dm = Admu
dF_dS = AdvA dF_dS = AdvA
@ -103,14 +100,16 @@ class SVGP(LatentFunctionInference):
#sum (gradients of) expected likelihood and KL part #sum (gradients of) expected likelihood and KL part
log_marginal = F.sum() - KL log_marginal = F.sum()
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm*0, dF_dS- dKL_dS*0, dF_dKmm- dKL_dKmm*0, dF_dKmn
#log_marginal = F.sum() - KL
#dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)]) dL_dchol = 2.*np.array([np.dot(a,b) for a, b in zip(dL_dS, L) ])
dL_dchol = choleskies.triang_to_flat(dL_dchol) dL_dchol = choleskies.triang_to_flat(dL_dchol)
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL} grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
if mean_function is not None: if mean_function is not None:
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
grad_dict['dL_dmfX'] = dF_dmfX grad_dict['dL_dmfX'] = dF_dmfX
return Posterior(mean=q_u_mean, cov=S, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict return Posterior(mean=q_u_mean, cov=S.T, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict

View file

@ -9,8 +9,8 @@ These tests make sure that the opure python and cython codes work the same
class CythonTestChols(np.testing.TestCase): class CythonTestChols(np.testing.TestCase):
def setUp(self): def setUp(self):
self.flat = np.random.randn(45, 5) self.flat = np.random.randn(5,45)
self.triang = np.dstack([np.eye(20)[:,:,None] for i in range(3)]) self.triang = np.array([np.eye(20) for i in range(3)])
def test_flat_to_triang(self): def test_flat_to_triang(self):
L1 = choleskies._flat_to_triang_pure(self.flat) L1 = choleskies._flat_to_triang_pure(self.flat)
L2 = choleskies._flat_to_triang_cython(self.flat) L2 = choleskies._flat_to_triang_cython(self.flat)

View file

@ -15,33 +15,33 @@ def safe_root(N):
return j return j
def _flat_to_triang_pure(flat_mat): def _flat_to_triang_pure(flat_mat):
N, D = flat_mat.shape D, N = flat_mat.shape
M = (-1 + safe_root(8*N+1))//2 M = (-1 + safe_root(8*N+1))//2
ret = np.zeros((M, M, D)) ret = np.zeros((D, M, M))
count = 0 for d in range(D):
for m in range(M): count = 0
for mm in range(m+1): for m in range(M):
for d in range(D): for mm in range(m+1):
ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count]; ret[d,m, mm] = flat_mat[d, count];
count = count+1 count = count+1
return ret return ret
def _flat_to_triang_cython(flat_mat): def _flat_to_triang_cython(flat_mat):
N, D = flat_mat.shape D, N = flat_mat.shape
M = (-1 + safe_root(8*N+1))//2 M = (-1 + safe_root(8*N+1))//2
return choleskies_cython.flat_to_triang(flat_mat, M) return choleskies_cython.flat_to_triang(flat_mat, M)
def _triang_to_flat_pure(L): def _triang_to_flat_pure(L):
M, _, D = L.shape D, _, M = L.shape
N = M*(M+1)//2 N = M*(M+1)//2
flat = np.empty((N, D)) flat = np.empty((D, N))
count = 0; for d in range(D):
for m in range(M): count = 0;
for mm in range(m+1): for m in range(M):
for d in range(D): for mm in range(m+1):
flat.flat[count] = L.flat[d + m*D*M + mm*D]; flat[d,count] = L[d, m, mm]
count = count +1 count = count +1
return flat return flat
@ -74,7 +74,7 @@ def triang_to_cov(L):
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])]) return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])
def multiple_dpotri(Ls): def multiple_dpotri(Ls):
return np.dstack([linalg.dpotri(np.asfortranarray(Ls[:,:,i]), lower=1)[0] for i in range(Ls.shape[-1])]) return np.array([linalg.dpotri(np.asfortranarray(Ls[i]), lower=1)[0] for i in range(Ls.shape[0])])
def indexes_to_fix_for_low_rank(rank, size): def indexes_to_fix_for_low_rank(rank, size):
""" """

File diff suppressed because it is too large Load diff

View file

@ -8,37 +8,37 @@ import numpy as np
cimport numpy as np cimport numpy as np
def flat_to_triang(np.ndarray[double, ndim=2] flat, int M): def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
"""take a matrix N x D and return a M X M x D array where """take a matrix D x N and return a D X M x M array where
N = M(M+1)/2 N = M(M+1)/2
the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat. the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
""" """
cdef int N = flat.shape[0] cdef int D = flat.shape[0]
cdef int D = flat.shape[1] cdef int N = flat.shape[1]
cdef int count = 0 cdef int count = 0
cdef np.ndarray[double, ndim=3] ret = np.zeros((M, M, D)) cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
cdef int d, m, mm cdef int d, m, mm
for d in range(D): for d in range(D):
count = 0 count = 0
for m in range(M): for m in range(M):
for mm in range(m+1): for mm in range(m+1):
ret[m, mm, d] = flat[count,d] ret[d, m, mm] = flat[d,count]
count += 1 count += 1
return ret return ret
def triang_to_flat(np.ndarray[double, ndim=3] L): def triang_to_flat(np.ndarray[double, ndim=3] L):
cdef int M = L.shape[0] cdef int D = L.shape[0]
cdef int D = L.shape[2] cdef int M = L.shape[1]
cdef int N = M*(M+1)/2 cdef int N = M*(M+1)/2
cdef int count = 0 cdef int count = 0
cdef np.ndarray[double, ndim=2] flat = np.empty((N, D)) cdef np.ndarray[double, ndim=2] flat = np.empty((D, N))
cdef int d, m, mm cdef int d, m, mm
for d in range(D): for d in range(D):
count = 0 count = 0
for m in range(M): for m in range(M):
for mm in range(m+1): for mm in range(m+1):
flat[count,d] = L[m, mm, d] flat[d,count] = L[d, m, mm]
count += 1 count += 1
return flat return flat