mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 03:22:38 +02:00
added cython code for lengthscale gradients
This commit is contained in:
parent
b36a845821
commit
2e8ce34ee0
5 changed files with 369 additions and 39 deletions
|
|
@ -7,6 +7,9 @@ ctypedef np.float64_t DTYPE_t
|
|||
|
||||
cdef extern from "stationary_utils.h":
|
||||
void _grad_X "_grad_X" (int N, int D, int M, double* X, double* X2, double* tmp, double* grad)
|
||||
|
||||
cdef extern from "stationary_utils.h":
|
||||
void _lengthscale_grads "_lengthscale_grads" (int N, int M, int Q, double* tmp, double* X, double* X2, double* grad)
|
||||
|
||||
def grad_X(int N, int D, int M,
|
||||
np.ndarray[DTYPE_t, ndim=2] _X,
|
||||
|
|
@ -18,3 +21,14 @@ def grad_X(int N, int D, int M,
|
|||
cdef double *tmp = <double*> _tmp.data
|
||||
cdef double *grad = <double*> _grad.data
|
||||
_grad_X(N, D, M, X, X2, tmp, grad) # return nothing, work in place.
|
||||
|
||||
def lengthscale_grads(int N, int M, int Q,
|
||||
np.ndarray[DTYPE_t, ndim=2] _tmp,
|
||||
np.ndarray[DTYPE_t, ndim=2] _X,
|
||||
np.ndarray[DTYPE_t, ndim=2] _X2,
|
||||
np.ndarray[DTYPE_t, ndim=1] _grad):
|
||||
cdef double *tmp = <double*> _tmp.data
|
||||
cdef double *X = <double*> _X.data
|
||||
cdef double *X2 = <double*> _X2.data
|
||||
cdef double *grad = <double*> _grad.data
|
||||
_lengthscale_grads(N, M, Q, tmp, X, X2, grad) # return nothing, work in place.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue