mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 05:22:38 +02:00
preliminary reconfiguring or choleskies ordering
This commit is contained in:
parent
dde8e4136e
commit
5d1875ec44
6 changed files with 432 additions and 468 deletions
|
|
@ -46,7 +46,7 @@ class SVGP(SparseGP):
|
||||||
num_latent_functions = Y.shape[1]
|
num_latent_functions = Y.shape[1]
|
||||||
|
|
||||||
self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
|
self.m = Param('q_u_mean', np.zeros((self.num_inducing, num_latent_functions)))
|
||||||
chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[:,:,None], (1,1,num_latent_functions)))
|
chol = choleskies.triang_to_flat(np.tile(np.eye(self.num_inducing)[None,:,:], (num_latent_functions, 1,1)))
|
||||||
self.chol = Param('q_u_chol', chol)
|
self.chol = Param('q_u_chol', chol)
|
||||||
self.link_parameter(self.chol)
|
self.link_parameter(self.chol)
|
||||||
self.link_parameter(self.m)
|
self.link_parameter(self.m)
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,10 @@ class SVGP(LatentFunctionInference):
|
||||||
|
|
||||||
|
|
||||||
S = np.empty((num_outputs, num_inducing, num_inducing))
|
S = np.empty((num_outputs, num_inducing, num_inducing))
|
||||||
[np.dot(L[:,:,i], L[:,:,i].T, S[i,:,:]) for i in range(num_outputs)]
|
[np.dot(L[i,:,:], L[i,:,:].T, S[i,:,:]) for i in range(num_outputs)]
|
||||||
S = S.swapaxes(0,2)
|
|
||||||
#Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
|
#Si,_ = linalg.dpotri(np.asfortranarray(L), lower=1)
|
||||||
Si = choleskies.multiple_dpotri(L)
|
Si = choleskies.multiple_dpotri(L)
|
||||||
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[:,:,i])))) for i in range(L.shape[-1])])
|
logdetS = np.array([2.*np.sum(np.log(np.abs(np.diag(L[i,:,:])))) for i in range(L.shape[0])])
|
||||||
|
|
||||||
if np.any(np.isinf(Si)):
|
if np.any(np.isinf(Si)):
|
||||||
raise ValueError("Cholesky representation unstable")
|
raise ValueError("Cholesky representation unstable")
|
||||||
|
|
@ -46,16 +45,17 @@ class SVGP(LatentFunctionInference):
|
||||||
A = np.dot(Knm, Kmmi)
|
A = np.dot(Knm, Kmmi)
|
||||||
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
|
mu = prior_mean_f + np.dot(A, q_u_mean - prior_mean_u)
|
||||||
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
|
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * np.einsum('ij,jlk->ilk', A, S),1)
|
||||||
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
|
#v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + np.sum(A[:,:,None] * linalg.ij_jlk_to_ilk(A, S),1)
|
||||||
|
v = Knn_diag[:,None] - np.sum(A*Knm,1)[:,None] + (S.dot(A.T)*A.T[None,:,:]).sum(1).T
|
||||||
|
|
||||||
#compute the KL term
|
#compute the KL term
|
||||||
Kmmim = np.dot(Kmmi, q_u_mean)
|
Kmmim = np.dot(Kmmi, q_u_mean)
|
||||||
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[:,:,None]*S,0).sum(0) + 0.5*np.sum(q_u_mean*Kmmim,0)
|
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
|
||||||
KL = KLs.sum()
|
KL = KLs.sum()
|
||||||
#gradient of the KL term (assuming zero mean function)
|
#gradient of the KL term (assuming zero mean function)
|
||||||
dKL_dm = Kmmim.copy()
|
dKL_dm = Kmmim.copy()
|
||||||
dKL_dS = 0.5*(Kmmi[:,:,None] - Si)
|
dKL_dS = 0.5*(Kmmi[None,:,:] - Si)
|
||||||
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(-1)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
|
dKL_dKmm = 0.5*num_outputs*Kmmi - 0.5*Kmmi.dot(S.sum(0)).dot(Kmmi) - 0.5*Kmmim.dot(Kmmim.T)
|
||||||
|
|
||||||
if mean_function is not None:
|
if mean_function is not None:
|
||||||
#adjust KL term for mean function
|
#adjust KL term for mean function
|
||||||
|
|
@ -80,17 +80,14 @@ class SVGP(LatentFunctionInference):
|
||||||
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
|
dF_dthetaL = dF_dthetaL.sum(1).sum(1)*batch_scale
|
||||||
|
|
||||||
#derivatives of expected likelihood, assuming zero mean function
|
#derivatives of expected likelihood, assuming zero mean function
|
||||||
Adv = A.T[:,:,None]*dF_dv[None,:,:] # As if dF_Dv is diagonal
|
Adv = A.T[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
|
||||||
Admu = A.T.dot(dF_dmu)
|
Admu = A.T.dot(dF_dmu)
|
||||||
AdvA = np.dstack([np.dot(A.T, Adv[:,:,i].T) for i in range(num_outputs)])
|
AdvA = np.dot(Adv, A) # D, M, M
|
||||||
#tmp = np.einsum('ijk,jlk->il', AdvA, S).dot(Kmmi)
|
tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
|
||||||
tmp = linalg.ijk_jlk_to_il(AdvA, S).dot(Kmmi)
|
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T
|
||||||
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(-1) - tmp - tmp.T
|
|
||||||
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
|
dF_dKmm = 0.5*(dF_dKmm + dF_dKmm.T) # necessary? GPy bug?
|
||||||
#tmp = 2.*(np.einsum('ij,jlk->ilk', Kmmi,S) - np.eye(num_inducing)[:,:,None])
|
tmp = 2.*(S.dot(Kmmi).swapaxes(1,2) - np.eye(num_inducing)[None, :,:]) # TODO: transpose?
|
||||||
tmp = 2.*(linalg.ij_jlk_to_ilk(Kmmi, S) - np.eye(num_inducing)[:,:,None])
|
dF_dKmn = np.sum([np.dot(a,b) for a,b in zip(tmp, Adv)],0) + Kmmim.dot(dF_dmu.T)
|
||||||
#dF_dKmn = np.einsum('ijk,jlk->il', tmp, Adv) + Kmmim.dot(dF_dmu.T)
|
|
||||||
dF_dKmn = linalg.ijk_jlk_to_il(tmp, Adv) + Kmmim.dot(dF_dmu.T)
|
|
||||||
dF_dm = Admu
|
dF_dm = Admu
|
||||||
dF_dS = AdvA
|
dF_dS = AdvA
|
||||||
|
|
||||||
|
|
@ -103,14 +100,16 @@ class SVGP(LatentFunctionInference):
|
||||||
|
|
||||||
|
|
||||||
#sum (gradients of) expected likelihood and KL part
|
#sum (gradients of) expected likelihood and KL part
|
||||||
log_marginal = F.sum() - KL
|
log_marginal = F.sum()
|
||||||
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
|
dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm*0, dF_dS- dKL_dS*0, dF_dKmm- dKL_dKmm*0, dF_dKmn
|
||||||
|
#log_marginal = F.sum() - KL
|
||||||
|
#dL_dm, dL_dS, dL_dKmm, dL_dKmn = dF_dm - dKL_dm, dF_dS- dKL_dS, dF_dKmm- dKL_dKmm, dF_dKmn
|
||||||
|
|
||||||
dL_dchol = np.dstack([2.*np.dot(dL_dS[:,:,i], L[:,:,i]) for i in range(num_outputs)])
|
dL_dchol = 2.*np.array([np.dot(a,b) for a, b in zip(dL_dS, L) ])
|
||||||
dL_dchol = choleskies.triang_to_flat(dL_dchol)
|
dL_dchol = choleskies.triang_to_flat(dL_dchol)
|
||||||
|
|
||||||
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
|
grad_dict = {'dL_dKmm':dL_dKmm, 'dL_dKmn':dL_dKmn, 'dL_dKdiag': dF_dv.sum(1), 'dL_dm':dL_dm, 'dL_dchol':dL_dchol, 'dL_dthetaL':dF_dthetaL}
|
||||||
if mean_function is not None:
|
if mean_function is not None:
|
||||||
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
|
grad_dict['dL_dmfZ'] = dF_dmfZ - dKL_dmfZ
|
||||||
grad_dict['dL_dmfX'] = dF_dmfX
|
grad_dict['dL_dmfX'] = dF_dmfX
|
||||||
return Posterior(mean=q_u_mean, cov=S, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
|
return Posterior(mean=q_u_mean, cov=S.T, K=Kmm, prior_mean=prior_mean_u), log_marginal, grad_dict
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ These tests make sure that the opure python and cython codes work the same
|
||||||
|
|
||||||
class CythonTestChols(np.testing.TestCase):
|
class CythonTestChols(np.testing.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.flat = np.random.randn(45, 5)
|
self.flat = np.random.randn(5,45)
|
||||||
self.triang = np.dstack([np.eye(20)[:,:,None] for i in range(3)])
|
self.triang = np.array([np.eye(20) for i in range(3)])
|
||||||
def test_flat_to_triang(self):
|
def test_flat_to_triang(self):
|
||||||
L1 = choleskies._flat_to_triang_pure(self.flat)
|
L1 = choleskies._flat_to_triang_pure(self.flat)
|
||||||
L2 = choleskies._flat_to_triang_cython(self.flat)
|
L2 = choleskies._flat_to_triang_cython(self.flat)
|
||||||
|
|
|
||||||
|
|
@ -15,33 +15,33 @@ def safe_root(N):
|
||||||
return j
|
return j
|
||||||
|
|
||||||
def _flat_to_triang_pure(flat_mat):
|
def _flat_to_triang_pure(flat_mat):
|
||||||
N, D = flat_mat.shape
|
D, N = flat_mat.shape
|
||||||
M = (-1 + safe_root(8*N+1))//2
|
M = (-1 + safe_root(8*N+1))//2
|
||||||
ret = np.zeros((M, M, D))
|
ret = np.zeros((D, M, M))
|
||||||
|
for d in range(D):
|
||||||
count = 0
|
count = 0
|
||||||
for m in range(M):
|
for m in range(M):
|
||||||
for mm in range(m+1):
|
for mm in range(m+1):
|
||||||
for d in range(D):
|
ret[d,m, mm] = flat_mat[d, count];
|
||||||
ret.flat[d + m*D*M + mm*D] = flat_mat.flat[count];
|
|
||||||
count = count+1
|
count = count+1
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _flat_to_triang_cython(flat_mat):
|
def _flat_to_triang_cython(flat_mat):
|
||||||
N, D = flat_mat.shape
|
D, N = flat_mat.shape
|
||||||
M = (-1 + safe_root(8*N+1))//2
|
M = (-1 + safe_root(8*N+1))//2
|
||||||
return choleskies_cython.flat_to_triang(flat_mat, M)
|
return choleskies_cython.flat_to_triang(flat_mat, M)
|
||||||
|
|
||||||
|
|
||||||
def _triang_to_flat_pure(L):
|
def _triang_to_flat_pure(L):
|
||||||
M, _, D = L.shape
|
D, _, M = L.shape
|
||||||
|
|
||||||
N = M*(M+1)//2
|
N = M*(M+1)//2
|
||||||
flat = np.empty((N, D))
|
flat = np.empty((D, N))
|
||||||
|
for d in range(D):
|
||||||
count = 0;
|
count = 0;
|
||||||
for m in range(M):
|
for m in range(M):
|
||||||
for mm in range(m+1):
|
for mm in range(m+1):
|
||||||
for d in range(D):
|
flat[d,count] = L[d, m, mm]
|
||||||
flat.flat[count] = L.flat[d + m*D*M + mm*D];
|
|
||||||
count = count +1
|
count = count +1
|
||||||
return flat
|
return flat
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ def triang_to_cov(L):
|
||||||
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])
|
return np.dstack([np.dot(L[:,:,i], L[:,:,i].T) for i in range(L.shape[-1])])
|
||||||
|
|
||||||
def multiple_dpotri(Ls):
|
def multiple_dpotri(Ls):
|
||||||
return np.dstack([linalg.dpotri(np.asfortranarray(Ls[:,:,i]), lower=1)[0] for i in range(Ls.shape[-1])])
|
return np.array([linalg.dpotri(np.asfortranarray(Ls[i]), lower=1)[0] for i in range(Ls.shape[0])])
|
||||||
|
|
||||||
def indexes_to_fix_for_low_rank(rank, size):
|
def indexes_to_fix_for_low_rank(rank, size):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -8,37 +8,37 @@ import numpy as np
|
||||||
cimport numpy as np
|
cimport numpy as np
|
||||||
|
|
||||||
def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
|
def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
|
||||||
"""take a matrix N x D and return a M X M x D array where
|
"""take a matrix D x N and return a D X M x M array where
|
||||||
|
|
||||||
N = M(M+1)/2
|
N = M(M+1)/2
|
||||||
|
|
||||||
the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
|
the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
|
||||||
"""
|
"""
|
||||||
cdef int N = flat.shape[0]
|
cdef int D = flat.shape[0]
|
||||||
cdef int D = flat.shape[1]
|
cdef int N = flat.shape[1]
|
||||||
cdef int count = 0
|
cdef int count = 0
|
||||||
cdef np.ndarray[double, ndim=3] ret = np.zeros((M, M, D))
|
cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
|
||||||
cdef int d, m, mm
|
cdef int d, m, mm
|
||||||
for d in range(D):
|
for d in range(D):
|
||||||
count = 0
|
count = 0
|
||||||
for m in range(M):
|
for m in range(M):
|
||||||
for mm in range(m+1):
|
for mm in range(m+1):
|
||||||
ret[m, mm, d] = flat[count,d]
|
ret[d, m, mm] = flat[d,count]
|
||||||
count += 1
|
count += 1
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def triang_to_flat(np.ndarray[double, ndim=3] L):
|
def triang_to_flat(np.ndarray[double, ndim=3] L):
|
||||||
cdef int M = L.shape[0]
|
cdef int D = L.shape[0]
|
||||||
cdef int D = L.shape[2]
|
cdef int M = L.shape[1]
|
||||||
cdef int N = M*(M+1)/2
|
cdef int N = M*(M+1)/2
|
||||||
cdef int count = 0
|
cdef int count = 0
|
||||||
cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
|
cdef np.ndarray[double, ndim=2] flat = np.empty((D, N))
|
||||||
cdef int d, m, mm
|
cdef int d, m, mm
|
||||||
for d in range(D):
|
for d in range(D):
|
||||||
count = 0
|
count = 0
|
||||||
for m in range(M):
|
for m in range(M):
|
||||||
for mm in range(m+1):
|
for mm in range(m+1):
|
||||||
flat[count,d] = L[m, mm, d]
|
flat[d,count] = L[d, m, mm]
|
||||||
count += 1
|
count += 1
|
||||||
return flat
|
return flat
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue