diff --git a/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py b/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py index 6ad9b20a..ad186594 100644 --- a/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py +++ b/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py @@ -90,7 +90,7 @@ try: # The kernel form computing psi1 het_noise comp_dpsi1_dvar = ElementwiseKernel( - "double *dpsi1_dvar, double *psi1_neq, double *psi1exp1, double *psi11exp2, double *l, double *Z, double *mu, double *S, double *logGamma, double *log1Gamma, double *logpsi1denom, int N, int M, int Q", + "double *dpsi1_dvar, double *psi1_neq, double *psi1exp1, double *psi1exp2, double *l, double *Z, double *mu, double *S, double *logGamma, double *log1Gamma, double *logpsi1denom, int N, int M, int Q", "dpsi1_dvar[i] = comp_dpsi1_dvar_element(psi1_neq, psi1exp1, psi1exp2, l, Z, mu, S, logGamma, log1Gamma, logpsi1denom, N, M, Q, i)", "comp_dpsi1_dvar", preamble=""" @@ -99,7 +99,7 @@ try: #define IDX_MQ(m,q) (q*M+m) #define LOGEXPSUM(a,b) (a>=b?a+log(1.0+exp(b-a)):b+log(1.0+exp(a-b))) - __device__ double comp_dpsi1_dvar_element(double *psi1_neq, double *psi1exp1, double *psi11exp2, double *l, double *Z, double *mu, double *S, double *logGamma, double *log1Gamma, double *logpsi1denom, int N, int M, int Q, int idx) + __device__ double comp_dpsi1_dvar_element(double *psi1_neq, double *psi1exp1, double *psi1exp2, double *l, double *Z, double *mu, double *S, double *logGamma, double *log1Gamma, double *logpsi1denom, int N, int M, int Q, int idx) { int n = idx%N; int m = idx/N; @@ -107,9 +107,9 @@ try: double psi1_sum = 0; for(int q=0;q=b?a+log(1.0+exp(b-a)):b+log(1.0+exp(a-b))) - __device__ double comp_dpsi1_der_element(double *dpsi1_dmu, double *dpsi1_dS, double *dpsi1_dgamma, double *dpsi1_dZ, double var, double *psi1_neq, double psi1exp1, double *psi11exp2, double *l, double *Z, double *mu, double *S, double *gamma, int N, int M, int Q, int idx) + __device__ double comp_psi1_der_element(double *dpsi1_dmu, double *dpsi1_dS, double 
*dpsi1_dgamma, double *dpsi1_dZ, double *psi1_neq, double *psi1exp1, double *psi1exp2, double var, double *l, double *Z, double *mu, double *S, double *gamma, int N, int M, int Q, int idx) { int q = idx/(M*N); int m = (idx%(M*N))/N; @@ -146,6 +145,7 @@ try: double Z_c = Z[IDX_MQ(m,q)]; double S_c = S[IDX_NQ(n,q)]; double l_c = l[q]; + double l_sqrt_c = sqrt(l[q]); double psi1exp1_c = psi1exp1[IDX_NMQ(n,m,q)]; double psi1exp2_c = psi1exp2[IDX_MQ(m,q)]; @@ -153,13 +153,101 @@ try: double denom_sqrt = sqrt(denom); double Zmu = Z_c-mu[IDX_NQ(n,q)]; double psi1_common = gamma_c/(denom_sqrt*denom*l_c); - double gamma1 = 1-gamma_c + double gamma1 = 1-gamma_c; dpsi1_dgamma[IDX_NMQ(n,m,q)] = var*neq*(psi1exp1_c/denom_sqrt - psi1exp2_c); dpsi1_dmu[IDX_NMQ(n,m,q)] = var*neq*(psi1_common*Zmu*psi1exp1_c); dpsi1_dS[IDX_NMQ(n,m,q)] = var*neq*(psi1_common*(Zmu*Zmu/(S_c+l_c)-1.0)*psi1exp1_c)/2.0; dpsi1_dZ[IDX_NMQ(n,m,q)] = var*neq*(-psi1_common*Zmu*psi1exp1_c-gamma1*Z_c/l_c*psi1exp2_c); - return var*neq*(psi1_common*(S_c/l_c+Zmu*Zmu/(S_c+l_c))*psi1exp1_c+gamma1*Z_c*Z_c/l_c*psi1exp2_c)/2.0; + return var*neq*(psi1_common*(S_c/l_c+Zmu*Zmu/(S_c+l_c))*psi1exp1_c+gamma1*Z_c*Z_c/l_c*psi1exp2_c)*l_sqrt_c; + } + """) + + # The kernel form computing psi2 het_noise + comp_dpsi2_dvar = ElementwiseKernel( + "double *dpsi2_dvar, double *psi2_neq, double *psi2exp1, double *psi2exp2, double var, double *l, double *Z, double *mu, double *S, double *logGamma, double *log1Gamma, double *logpsi2denom, int N, int M, int Q", + "dpsi2_dvar[i] = comp_dpsi2_dvar_element(psi2_neq, psi2exp1, psi2exp2, var, l, Z, mu, S, logGamma, log1Gamma, logpsi2denom, N, M, Q, i)", + "comp_dpsi2_dvar", + preamble=""" + #define IDX_NMMQ(n,m1,m2,q) (((q*M+m2)*M+m1)*N+n) + #define IDX_MMQ(m1,m2,q) ((q*M+m2)*M+m1) + #define IDX_NMQ(n,m,q) ((q*M+m)*N+n) + #define IDX_NQ(n,q) (q*N+n) + #define IDX_MQ(m,q) (q*M+m) + #define LOGEXPSUM(a,b) (a>=b?a+log(1.0+exp(b-a)):b+log(1.0+exp(a-b))) + + __device__ double 
comp_dpsi2_dvar_element(double *psi2_neq, double *psi2exp1, double *psi2exp2, double var, double *l, double *Z, double *mu, double *S, double *logGamma, double *log1Gamma, double *logpsi2denom, int N, int M, int Q, int idx) + { + // psi2 (n,m1,m2) + int m2 = idx/(M*N); + int m1 = (idx%(M*N))/N; + int n = idx%N; + + double psi2_sum=0; + for(int q=0;q