diff --git a/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py b/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py index da948661..12c39e16 100644 --- a/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py +++ b/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py @@ -414,7 +414,7 @@ class PSICOMP_SSRBF(object): grad_dl_gpu = self.gpuCache['grad_l_gpu'] # variance - variance.gradient = gpuarray.sum(dL_dpsi0) \ + variance.gradient = gpuarray.sum(dL_dpsi0).get() \ + cublas.cublasDdot(self.cublas_handle, dL_dpsi1.size, dL_dpsi1.gpudata, 1, dpsi1_dvar_gpu.gpudata, 1) \ + cublas.cublasDdot(self.cublas_handle, dL_dpsi2.size, dL_dpsi2.gpudata, 1, dpsi2_dvar_gpu.gpudata, 1) @@ -429,7 +429,7 @@ class PSICOMP_SSRBF(object): else: linalg_gpu.mul_bcast(psi1_comb_gpu, dL_dpsi1, dpsi1_dl_gpu, dL_dpsi1.size) linalg_gpu.mul_bcast(psi2_comb_gpu, dL_dpsi2, dpsi2_dl_gpu, dL_dpsi2.size) - lengthscale.gradient = gpuarray.sum(psi1_comb_gpu) + gpuarray.sum(psi2_comb_gpu) + lengthscale.gradient = gpuarray.sum(psi1_comb_gpu).get() + gpuarray.sum(psi2_comb_gpu).get() def gradients_Z_expectations(self, dL_dpsi1, dL_dpsi2, variance, lengthscale, Z, mu, S, gamma): pass diff --git a/GPy/util/linalg_gpu.py b/GPy/util/linalg_gpu.py index 60eb8101..6f5dc45b 100644 --- a/GPy/util/linalg_gpu.py +++ b/GPy/util/linalg_gpu.py @@ -28,8 +28,8 @@ try: # log(1.0-X) logOne = ElementwiseKernel("double *in, double *out", "out[i] = log(1.-in[i])", "logOne_element") - # multiplication with broadcast on the last dimension - mul_bcast = ElementwiseKernel("double *out, double *shorter, double *longer, int shorter_size", "out[i] = longer[i]*shorter[i%shorter_size]", "mul_bcast") + # multiplication with broadcast on the last dimension (a has to be smaller than b) + mul_bcast = ElementwiseKernel("double *out, double *a, double *b, int a_size", "out[i] = b[i]*a[i % a_size ]", "mul_bcast") # sum through the middle dimension (size_2) of a 3D matrix (size_1, size_2, size_3) sum_axis = ElementwiseKernel("double *out, double *in, int size_1, int size_2", "out[i] += sum_axis_element(in, size_1, size_2, i)", "sum_axis",preamble="""