[GPU] bug fix

This commit is contained in:
Zhenwen Dai 2014-04-01 18:10:35 +01:00
parent af56b9951c
commit 5875a33d21
3 changed files with 7 additions and 5 deletions

View file

@ -15,7 +15,7 @@ try:
from scikits.cuda import cublas
import pycuda.autoinit
from pycuda.reduction import ReductionKernel
from ...util.linalg_gpu import logDiagSum
from ...util.linalg_gpu import logDiagSum, strideSum
except:
pass
@ -212,7 +212,8 @@ class VarDTC_GPU(object):
cublas.cublasDcopy(self.cublas_handle, psi2_gpu.size, psi2_gpu.gpudata, 1, LmInvPsi2LmInvT_gpu.gpudata, 1)
cublas.cublasDtrsm(self.cublas_handle , 'L', 'L', 'N', 'N', num_inducing, num_inducing, np.float64(1.0), Lm_gpu.gpudata, num_inducing, LmInvPsi2LmInvT_gpu.gpudata, num_inducing)
cublas.cublasDtrsm(self.cublas_handle , 'r', 'L', 'T', 'N', num_inducing, num_inducing, np.float64(1.0), Lm_gpu.gpudata, num_inducing, LmInvPsi2LmInvT_gpu.gpudata, num_inducing)
tr_LmInvPsi2LmInvT = cublas.cublasDasum(self.cublas_handle, num_inducing, LmInvPsi2LmInvT_gpu.gpudata, num_inducing+1)
#tr_LmInvPsi2LmInvT = cublas.cublasDasum(self.cublas_handle, num_inducing, LmInvPsi2LmInvT_gpu.gpudata, num_inducing+1)
tr_LmInvPsi2LmInvT = strideSum(LmInvPsi2LmInvT_gpu, num_inducing+1)
print np.abs(vvt-vvt_gpu.get()).max()
print np.abs(np.trace(LmInvPsi2LmInvT)-tr_LmInvPsi2LmInvT)