diff --git a/GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py b/GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py index 7b604227..00e0d397 100644 --- a/GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py +++ b/GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py @@ -17,6 +17,7 @@ gpu_code = """ // define THREADNUM #define IDX_NMQ(n,m,q) ((q*M+m)*N+n) + #define IDX_NMM(n,m1,m2) ((m2*M+m1)*N+n) #define IDX_NQ(n,q) (q*N+n) #define IDX_NM(n,m) (m*N+n) #define IDX_MQ(m,q) (q*M+m) @@ -66,15 +67,17 @@ gpu_code = """ double log_psi1 = 0; for(int q=0;q