more edits to stationary to clean up for cython

This commit is contained in:
James Hensman 2015-04-28 11:38:33 +01:00
parent 780cf85687
commit a11cf422c2
3 changed files with 424 additions and 93 deletions

View file

@ -22,7 +22,7 @@ def grad_X(int N, int D, int M,
cdef double *grad = <double*> _grad.data
_grad_X(N, D, M, X, X2, tmp, grad) # return nothing, work in place.
def lengthscale_grads(int N, int M, int Q,
def lengthscale_grads_c(int N, int M, int Q,
np.ndarray[DTYPE_t, ndim=2] _tmp,
np.ndarray[DTYPE_t, ndim=2] _X,
np.ndarray[DTYPE_t, ndim=2] _X2,
@ -32,3 +32,14 @@ def lengthscale_grads(int N, int M, int Q,
cdef double *X2 = <double*> _X2.data
cdef double *grad = <double*> _grad.data
_lengthscale_grads(N, M, Q, tmp, X, X2, grad) # return nothing, work in place.
def lengthscale_grads(int N, int M, int Q,
np.ndarray[DTYPE_t, ndim=2] tmp,
np.ndarray[DTYPE_t, ndim=2] X,
np.ndarray[DTYPE_t, ndim=2] X2,
np.ndarray[DTYPE_t, ndim=1] grad):
for q in range(Q):
for i in range(N):
for j in range(M):
grad[q] += tmp[i,j]*(X[i,q]-X2[j,q])**2