svgp working with reordered chols

This commit is contained in:
James Hensman 2015-05-15 08:59:19 +01:00
parent 2249ec06a5
commit 77093a304c
4 changed files with 51 additions and 59 deletions

View file

@ -24,8 +24,6 @@ class SVGP(LatentFunctionInference):
if np.any(np.isinf(Si)):
raise ValueError("Cholesky representation unstable")
#S = S + np.eye(S.shape[0])*1e-5*np.max(np.max(S))
#Si, Lnew, _,_ = linalg.pdinv(S)
#compute mean function stuff
if mean_function is not None:
@ -35,27 +33,21 @@ class SVGP(LatentFunctionInference):
prior_mean_u = np.zeros((num_inducing, num_outputs))
prior_mean_f = np.zeros((num_data, num_outputs))
#compute kernel related stuff
Kmm = kern.K(Z)
Kmn = kern.K(Z, X)
Knn_diag = kern.Kdiag(X)
Kmmi, Lm, Lmi, logdetKmm = linalg.pdinv(Kmm)
Lm = linalg.jitchol(Kmm)
logdetKmm = 2.*np.sum(np.log(np.diag(Lm)))
Kmmi, _ = linalg.dpotri(Lm)
#compute the marginal means and variances of q(f)
A = np.dot(Kmmi, Kmn)
A, _ = linalg.dpotrs(Lm, Kmn)
mu = prior_mean_f + np.dot(A.T, q_u_mean - prior_mean_u)
LA = L.reshape(-1, num_inducing).dot(A).reshape(num_outputs, num_inducing, num_data)
#LA = np.empty((num_outputs, num_inducing, num_data))
#Af = np.asfortranarray(A)
#for Li, LAi in zip(L, LA):
#LAi[:,:] = dtrmm(1., Li.T, Af, side=0, lower=0, trans_a=1, overwrite_b=0)
#stop
#assert np.allclose(LA, LA_)
#TODO? possibly use dtrmm for the above line?
v = (Knn_diag - np.sum(A*Kmn,0))[:,None] + np.sum(np.square(LA),1).T
#compute the KL term
Kmmim = np.dot(Kmmi, q_u_mean)
KLs = -0.5*logdetS -0.5*num_inducing + 0.5*logdetKmm + 0.5*np.sum(Kmmi[None,:,:]*S,1).sum(1) + 0.5*np.sum(q_u_mean*Kmmim,0)
@ -90,7 +82,7 @@ class SVGP(LatentFunctionInference):
#derivatives of expected likelihood, assuming zero mean function
Adv = A[None,:,:]*dF_dv.T[:,None,:] # As if dF_Dv is diagonal, D, M, N
Admu = A.dot(dF_dmu)
Adv = np.ascontiguousarray(Adv) # makes for faster operations later...
Adv = np.ascontiguousarray(Adv) # makes for faster operations later...(inc dsymm)
AdvA = np.dot(Adv.reshape(-1, num_data),A.T).reshape(num_outputs, num_inducing, num_inducing )
tmp = np.sum([np.dot(a,s) for a, s in zip(AdvA, S)],0).dot(Kmmi)
dF_dKmm = -Admu.dot(Kmmim.T) + AdvA.sum(0) - tmp - tmp.T

View file

@ -15,19 +15,19 @@ def safe_root(N):
return j
def _flat_to_triang_pure(flat_mat):
    """Unpack flattened lower triangles into a stack of square matrices.

    Pure-Python counterpart of ``choleskies_cython.flat_to_triang``.

    :param flat_mat: 2-D array indexed as (N, D), where N = M*(M+1)/2.
        Column ``d`` holds the lower-triangular entries of the d'th matrix,
        listed row by row.
    :returns: array of shape (D, M, M); entries above the diagonal are zero.
    """
    # Layout is (N, D): packed entries run down the rows, one column per matrix.
    N, D = flat_mat.shape
    # Invert N = M*(M+1)/2 for M; safe_root is presumably an exact integer
    # square root -- confirm against its definition.
    M = (-1 + safe_root(8*N+1))//2
    ret = np.zeros((D, M, M))
    for d in range(D):
        # count walks the packed entries of column d in row-major order
        count = 0
        for m in range(M):
            for mm in range(m+1):
                ret[d, m, mm] = flat_mat[count, d]
                count += 1
    return ret
def _flat_to_triang_cython(flat_mat):
    """Unpack flattened lower triangles using the compiled Cython routine.

    :param flat_mat: 2-D array indexed as (N, D), where N = M*(M+1)/2.
    :returns: whatever ``choleskies_cython.flat_to_triang`` returns for
        ``flat_mat`` and the derived M (documented there as a D x M x M array).
    """
    # Unpacking two values keeps the implicit "must be 2-D" check; only the
    # packed length N (first axis) is needed to recover M.
    N, _ = flat_mat.shape
    # Invert N = M*(M+1)/2 for M; safe_root is presumably an exact integer
    # square root -- confirm against its definition.
    M = (-1 + safe_root(8*N+1))//2
    return choleskies_cython.flat_to_triang(flat_mat, M)
def _triang_to_flat_pure(L):
    """Pack the lower triangles of a stack of square matrices into columns.

    :param L: array indexed as (D, M, M); only entries on or below the
        diagonal of each slice are read.
    :returns: array of shape (N, D) with N = M*(M+1)/2. Column ``d`` holds
        the lower-triangular entries of ``L[d]``, listed row by row.
    """
    D, _, M = L.shape
    N = M*(M+1)//2
    # Output layout is (N, D). Every cell is written by the loops below,
    # so the uninitialised np.empty buffer never leaks through.
    flat = np.empty((N, D))
    for d in range(D):
        # count walks the packed positions of column d in row-major order
        count = 0
        for m in range(M):
            for mm in range(m+1):
                flat[count, d] = L[d, m, mm]
                count += 1
    return flat

View file

@ -1167,13 +1167,13 @@ static PyObject *__pyx_codeobj__12;
* cimport numpy as np
*
* def flat_to_triang(np.ndarray[double, ndim=2] flat, int M): # <<<<<<<<<<<<<<
* """take a matrix D x N and return a D X M x M array where
* """take a matrix N x D and return a D X M x M array where
*
*/
/* Python wrapper */
static PyObject *__pyx_pw_3GPy_4util_17choleskies_cython_1flat_to_triang(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
static char __pyx_doc_3GPy_4util_17choleskies_cython_flat_to_triang[] = "take a matrix D x N and return a D X M x M array where\n\n N = M(M+1)/2\n\n the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.\n ";
static char __pyx_doc_3GPy_4util_17choleskies_cython_flat_to_triang[] = "take a matrix N x D and return a D X M x M array where\n\n N = M(M+1)/2\n\n the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.\n ";
static PyMethodDef __pyx_mdef_3GPy_4util_17choleskies_cython_1flat_to_triang = {"flat_to_triang", (PyCFunction)__pyx_pw_3GPy_4util_17choleskies_cython_1flat_to_triang, METH_VARARGS|METH_KEYWORDS, __pyx_doc_3GPy_4util_17choleskies_cython_flat_to_triang};
static PyObject *__pyx_pw_3GPy_4util_17choleskies_cython_1flat_to_triang(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
PyArrayObject *__pyx_v_flat = 0;
@ -1292,24 +1292,24 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_flat_to_triang(CYTHON_U
/* "GPy/util/choleskies_cython.pyx":17
* the lower triangluar portion of the d'th slice of the result is filled by the d'th column of flat.
* """
* cdef int D = flat.shape[0] # <<<<<<<<<<<<<<
* cdef int N = flat.shape[1]
* cdef int D = flat.shape[1] # <<<<<<<<<<<<<<
* cdef int N = flat.shape[0]
* cdef int count = 0
*/
__pyx_v_D = (__pyx_v_flat->dimensions[0]);
__pyx_v_D = (__pyx_v_flat->dimensions[1]);
/* "GPy/util/choleskies_cython.pyx":18
* """
* cdef int D = flat.shape[0]
* cdef int N = flat.shape[1] # <<<<<<<<<<<<<<
* cdef int D = flat.shape[1]
* cdef int N = flat.shape[0] # <<<<<<<<<<<<<<
* cdef int count = 0
* cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
*/
__pyx_v_N = (__pyx_v_flat->dimensions[1]);
__pyx_v_N = (__pyx_v_flat->dimensions[0]);
/* "GPy/util/choleskies_cython.pyx":19
* cdef int D = flat.shape[0]
* cdef int N = flat.shape[1]
* cdef int D = flat.shape[1]
* cdef int N = flat.shape[0]
* cdef int count = 0 # <<<<<<<<<<<<<<
* cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
* cdef int d, m, mm
@ -1317,7 +1317,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_flat_to_triang(CYTHON_U
__pyx_v_count = 0;
/* "GPy/util/choleskies_cython.pyx":20
* cdef int N = flat.shape[1]
* cdef int N = flat.shape[0]
* cdef int count = 0
* cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M)) # <<<<<<<<<<<<<<
* cdef int d, m, mm
@ -1410,7 +1410,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_flat_to_triang(CYTHON_U
* count = 0
* for m in range(M): # <<<<<<<<<<<<<<
* for mm in range(m+1):
* ret[d, m, mm] = flat[d,count]
* ret[d, m, mm] = flat[count,d]
*/
__pyx_t_10 = __pyx_v_M;
for (__pyx_t_11 = 0; __pyx_t_11 < __pyx_t_10; __pyx_t_11+=1) {
@ -1420,7 +1420,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_flat_to_triang(CYTHON_U
* count = 0
* for m in range(M):
* for mm in range(m+1): # <<<<<<<<<<<<<<
* ret[d, m, mm] = flat[d,count]
* ret[d, m, mm] = flat[count,d]
* count += 1
*/
__pyx_t_12 = (__pyx_v_m + 1);
@ -1430,12 +1430,12 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_flat_to_triang(CYTHON_U
/* "GPy/util/choleskies_cython.pyx":26
* for m in range(M):
* for mm in range(m+1):
* ret[d, m, mm] = flat[d,count] # <<<<<<<<<<<<<<
* ret[d, m, mm] = flat[count,d] # <<<<<<<<<<<<<<
* count += 1
* return ret
*/
__pyx_t_14 = __pyx_v_d;
__pyx_t_15 = __pyx_v_count;
__pyx_t_14 = __pyx_v_count;
__pyx_t_15 = __pyx_v_d;
if (__pyx_t_14 < 0) __pyx_t_14 += __pyx_pybuffernd_flat.diminfo[0].shape;
if (__pyx_t_15 < 0) __pyx_t_15 += __pyx_pybuffernd_flat.diminfo[1].shape;
__pyx_t_16 = __pyx_v_d;
@ -1448,7 +1448,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_flat_to_triang(CYTHON_U
/* "GPy/util/choleskies_cython.pyx":27
* for mm in range(m+1):
* ret[d, m, mm] = flat[d,count]
* ret[d, m, mm] = flat[count,d]
* count += 1 # <<<<<<<<<<<<<<
* return ret
*
@ -1459,7 +1459,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_flat_to_triang(CYTHON_U
}
/* "GPy/util/choleskies_cython.pyx":28
* ret[d, m, mm] = flat[d,count]
* ret[d, m, mm] = flat[count,d]
* count += 1
* return ret # <<<<<<<<<<<<<<
*
@ -1474,7 +1474,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_flat_to_triang(CYTHON_U
* cimport numpy as np
*
* def flat_to_triang(np.ndarray[double, ndim=2] flat, int M): # <<<<<<<<<<<<<<
* """take a matrix D x N and return a D X M x M array where
* """take a matrix N x D and return a D X M x M array where
*
*/
@ -1607,7 +1607,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
* cdef int M = L.shape[1]
* cdef int N = M*(M+1)/2 # <<<<<<<<<<<<<<
* cdef int count = 0
* cdef np.ndarray[double, ndim=2] flat = np.empty((D, N))
* cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
*/
__pyx_v_N = __Pyx_div_long((__pyx_v_M * (__pyx_v_M + 1)), 2);
@ -1615,7 +1615,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
* cdef int M = L.shape[1]
* cdef int N = M*(M+1)/2
* cdef int count = 0 # <<<<<<<<<<<<<<
* cdef np.ndarray[double, ndim=2] flat = np.empty((D, N))
* cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
* cdef int d, m, mm
*/
__pyx_v_count = 0;
@ -1623,7 +1623,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
/* "GPy/util/choleskies_cython.pyx":35
* cdef int N = M*(M+1)/2
* cdef int count = 0
* cdef np.ndarray[double, ndim=2] flat = np.empty((D, N)) # <<<<<<<<<<<<<<
* cdef np.ndarray[double, ndim=2] flat = np.empty((N, D)) # <<<<<<<<<<<<<<
* cdef int d, m, mm
* for d in range(D):
*/
@ -1632,9 +1632,9 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
__pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_empty); if (unlikely(!__pyx_t_3)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 35; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_3);
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
__pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_D); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 35; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__pyx_t_2 = __Pyx_PyInt_From_int(__pyx_v_N); if (unlikely(!__pyx_t_2)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 35; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_2);
__pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_N); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 35; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_D); if (unlikely(!__pyx_t_4)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 35; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_4);
__pyx_t_5 = PyTuple_New(2); if (unlikely(!__pyx_t_5)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 35; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
__Pyx_GOTREF(__pyx_t_5);
@ -1685,7 +1685,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
__pyx_t_1 = 0;
/* "GPy/util/choleskies_cython.pyx":37
* cdef np.ndarray[double, ndim=2] flat = np.empty((D, N))
* cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
* cdef int d, m, mm
* for d in range(D): # <<<<<<<<<<<<<<
* count = 0
@ -1709,7 +1709,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
* count = 0
* for m in range(M): # <<<<<<<<<<<<<<
* for mm in range(m+1):
* flat[d,count] = L[d, m, mm]
* flat[count,d] = L[d, m, mm]
*/
__pyx_t_9 = __pyx_v_M;
for (__pyx_t_10 = 0; __pyx_t_10 < __pyx_t_9; __pyx_t_10+=1) {
@ -1719,7 +1719,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
* count = 0
* for m in range(M):
* for mm in range(m+1): # <<<<<<<<<<<<<<
* flat[d,count] = L[d, m, mm]
* flat[count,d] = L[d, m, mm]
* count += 1
*/
__pyx_t_11 = (__pyx_v_m + 1);
@ -1729,7 +1729,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
/* "GPy/util/choleskies_cython.pyx":41
* for m in range(M):
* for mm in range(m+1):
* flat[d,count] = L[d, m, mm] # <<<<<<<<<<<<<<
* flat[count,d] = L[d, m, mm] # <<<<<<<<<<<<<<
* count += 1
* return flat
*/
@ -1739,15 +1739,15 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
if (__pyx_t_13 < 0) __pyx_t_13 += __pyx_pybuffernd_L.diminfo[0].shape;
if (__pyx_t_14 < 0) __pyx_t_14 += __pyx_pybuffernd_L.diminfo[1].shape;
if (__pyx_t_15 < 0) __pyx_t_15 += __pyx_pybuffernd_L.diminfo[2].shape;
__pyx_t_16 = __pyx_v_d;
__pyx_t_17 = __pyx_v_count;
__pyx_t_16 = __pyx_v_count;
__pyx_t_17 = __pyx_v_d;
if (__pyx_t_16 < 0) __pyx_t_16 += __pyx_pybuffernd_flat.diminfo[0].shape;
if (__pyx_t_17 < 0) __pyx_t_17 += __pyx_pybuffernd_flat.diminfo[1].shape;
*__Pyx_BufPtrStrided2d(double *, __pyx_pybuffernd_flat.rcbuffer->pybuffer.buf, __pyx_t_16, __pyx_pybuffernd_flat.diminfo[0].strides, __pyx_t_17, __pyx_pybuffernd_flat.diminfo[1].strides) = (*__Pyx_BufPtrStrided3d(double *, __pyx_pybuffernd_L.rcbuffer->pybuffer.buf, __pyx_t_13, __pyx_pybuffernd_L.diminfo[0].strides, __pyx_t_14, __pyx_pybuffernd_L.diminfo[1].strides, __pyx_t_15, __pyx_pybuffernd_L.diminfo[2].strides));
/* "GPy/util/choleskies_cython.pyx":42
* for mm in range(m+1):
* flat[d,count] = L[d, m, mm]
* flat[count,d] = L[d, m, mm]
* count += 1 # <<<<<<<<<<<<<<
* return flat
*
@ -1758,7 +1758,7 @@ static PyObject *__pyx_pf_3GPy_4util_17choleskies_cython_2triang_to_flat(CYTHON_
}
/* "GPy/util/choleskies_cython.pyx":43
* flat[d,count] = L[d, m, mm]
* flat[count,d] = L[d, m, mm]
* count += 1
* return flat # <<<<<<<<<<<<<<
*
@ -4381,7 +4381,7 @@ static int __Pyx_InitCachedConstants(void) {
* cimport numpy as np
*
* def flat_to_triang(np.ndarray[double, ndim=2] flat, int M): # <<<<<<<<<<<<<<
* """take a matrix D x N and return a D X M x M array where
* """take a matrix N x D and return a D X M x M array where
*
*/
__pyx_tuple__7 = PyTuple_Pack(9, __pyx_n_s_flat, __pyx_n_s_M, __pyx_n_s_D, __pyx_n_s_N, __pyx_n_s_count, __pyx_n_s_ret, __pyx_n_s_d, __pyx_n_s_m, __pyx_n_s_mm); if (unlikely(!__pyx_tuple__7)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 10; __pyx_clineno = __LINE__; goto __pyx_L1_error;}
@ -4539,7 +4539,7 @@ PyMODINIT_FUNC PyInit_choleskies_cython(void)
* cimport numpy as np
*
* def flat_to_triang(np.ndarray[double, ndim=2] flat, int M): # <<<<<<<<<<<<<<
* """take a matrix D x N and return a D X M x M array where
* """take a matrix N x D and return a D X M x M array where
*
*/
__pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_3GPy_4util_17choleskies_cython_1flat_to_triang, NULL, __pyx_n_s_GPy_util_choleskies_cython); if (unlikely(!__pyx_t_1)) {__pyx_filename = __pyx_f[0]; __pyx_lineno = 10; __pyx_clineno = __LINE__; goto __pyx_L1_error;}

View file

@ -8,14 +8,14 @@ import numpy as np
cimport numpy as np
def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
"""take a matrix D x N and return a D X M x M array where
"""take a matrix N x D and return a D X M x M array where
N = M(M+1)/2
the lower triangular portion of the d'th slice of the result is filled by the d'th column of flat.
"""
cdef int D = flat.shape[0]
cdef int N = flat.shape[1]
cdef int D = flat.shape[1]
cdef int N = flat.shape[0]
cdef int count = 0
cdef np.ndarray[double, ndim=3] ret = np.zeros((D, M, M))
cdef int d, m, mm
@ -23,7 +23,7 @@ def flat_to_triang(np.ndarray[double, ndim=2] flat, int M):
count = 0
for m in range(M):
for mm in range(m+1):
ret[d, m, mm] = flat[d,count]
ret[d, m, mm] = flat[count,d]
count += 1
return ret
@ -32,13 +32,13 @@ def triang_to_flat(np.ndarray[double, ndim=3] L):
cdef int M = L.shape[1]
cdef int N = M*(M+1)/2
cdef int count = 0
cdef np.ndarray[double, ndim=2] flat = np.empty((D, N))
cdef np.ndarray[double, ndim=2] flat = np.empty((N, D))
cdef int d, m, mm
for d in range(D):
count = 0
for m in range(M):
for mm in range(m+1):
flat[d,count] = L[d, m, mm]
flat[count,d] = L[d, m, mm]
count += 1
return flat