diff --git a/GPy/kern/_src/rbf.py b/GPy/kern/_src/rbf.py index 90c9100b..bd768769 100644 --- a/GPy/kern/_src/rbf.py +++ b/GPy/kern/_src/rbf.py @@ -6,6 +6,7 @@ import numpy as np from stationary import Stationary from psi_comp import PSICOMP_RBF from psi_comp.rbf_psi_gpucomp import PSICOMP_RBF_GPU +from ...util.gpu_init import initGPU from ...util.config import * class RBF(Stationary): @@ -24,6 +25,7 @@ class RBF(Stationary): self.group_spike_prob = False self.psicomp = PSICOMP_RBF() if self.useGPU: + initGPU() self.psicomp = PSICOMP_RBF_GPU() else: self.psicomp = PSICOMP_RBF() diff --git a/GPy/util/gpu_init.py b/GPy/util/gpu_init.py index 98f711e8..e433e5fa 100644 --- a/GPy/util/gpu_init.py +++ b/GPy/util/gpu_init.py @@ -6,17 +6,37 @@ providing CUBLAS handle: cublas_handle """ try: - import pycuda.autoinit from scikits.cuda import cublas import scikits.cuda.linalg as culinalg culinalg.init() cublas_handle = cublas.cublasCreate() - initSuccess = True except: - initSuccess = False + +gpu_initialized = False +gpu_device = None +gpu_context = None def initGPU(gpuid=None): + if gpu_initialized: + return if gpuid==None: - return pycuda.tools.make_default_context() + try: + import pycuda.autoinit + gpu_initialized = True + except: + pass else: - return pycuda.driver.Device(gpuid).make_context() \ No newline at end of file + try: + import pycuda.driver + pycuda.driver.init() + if gpuid>=pycuda.driver.Device.count(): + return + gpu_device = pycuda.driver.Device(gpuid) + gpu_context = gpu_device.make_context() + gpu_initialized = True + except: + pass + +def closeGPU(): + if gpu_context is not None: + gpu_context.detach()