From c3482e7a94f82416e7abc4436656b46866f942f7 Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Thu, 26 Jun 2014 13:46:47 +0100 Subject: [PATCH] fix the gpu initialization for multiple cards --- GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py | 1 - GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py | 1 - GPy/kern/_src/rbf.py | 2 - GPy/util/gpu_init.py | 56 ++++++++++++--------- GPy/util/parallel.py | 14 ++++++ 5 files changed, 45 insertions(+), 29 deletions(-) create mode 100644 GPy/util/parallel.py diff --git a/GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py b/GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py index d8fec65b..623c45c4 100644 --- a/GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py +++ b/GPy/kern/_src/psi_comp/rbf_psi_gpucomp.py @@ -242,7 +242,6 @@ gpu_code = """ class PSICOMP_RBF_GPU(PSICOMP_RBF): def __init__(self, threadnum=128, blocknum=15, GPU_direct=False): - assert gpu_init.initSuccess, "GPU initialization failed!" self.GPU_direct = GPU_direct self.cublas_handle = gpu_init.cublas_handle self.gpuCache = None diff --git a/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py b/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py index 9c699daa..0c28794c 100644 --- a/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py +++ b/GPy/kern/_src/psi_comp/ssrbf_psi_gpucomp.py @@ -292,7 +292,6 @@ gpu_code = """ class PSICOMP_SSRBF_GPU(PSICOMP_RBF): def __init__(self, threadnum=128, blocknum=15, GPU_direct=False): - assert gpu_init.initSuccess, "GPU initialization failed!" self.GPU_direct = GPU_direct self.cublas_handle = gpu_init.cublas_handle self.gpuCache = None diff --git a/GPy/kern/_src/rbf.py b/GPy/kern/_src/rbf.py index bd768769..90c9100b 100644 --- a/GPy/kern/_src/rbf.py +++ b/GPy/kern/_src/rbf.py @@ -6,7 +6,6 @@ import numpy as np from stationary import Stationary from psi_comp import PSICOMP_RBF from psi_comp.rbf_psi_gpucomp import PSICOMP_RBF_GPU -from ...util.gpu_init import initGPU from ...util.config import * class RBF(Stationary): @@ -25,7 +24,6 @@ class RBF(Stationary): self.group_spike_prob = False self.psicomp = PSICOMP_RBF() if self.useGPU: - initGPU() self.psicomp = PSICOMP_RBF_GPU() else: self.psicomp = PSICOMP_RBF() diff --git a/GPy/util/gpu_init.py b/GPy/util/gpu_init.py index e433e5fa..845d38a1 100644 --- a/GPy/util/gpu_init.py +++ b/GPy/util/gpu_init.py @@ -5,37 +5,43 @@ Global variables: initSuccess providing CUBLAS handle: cublas_handle """ +gpu_initialized = False +gpu_device = None +gpu_context = None +MPI_enabled = False + +try: + from mpi4py import MPI + MPI_enabled = True +except: + pass + +try: + if MPI_enabled and MPI.COMM_WORLD.size>1: + from .parallel import get_id_within_node + gpuid = get_id_within_node() + import pycuda.driver + pycuda.driver.init() + if gpuid>=pycuda.driver.Device.count(): + print '['+MPI.Get_processor_name()+'] more processes than the GPU numbers!' + MPI.COMM_WORLD.Abort() + raise + gpu_device = pycuda.driver.Device(gpuid) + gpu_context = gpu_device.make_context() + gpu_initialized = True + else: + import pycuda.autoinit + gpu_initialized = True +except: + pass + try: from scikits.cuda import cublas import scikits.cuda.linalg as culinalg culinalg.init() cublas_handle = cublas.cublasCreate() except: - -gpu_initialized = False -gpu_device = None -gpu_context = None - -def initGPU(gpuid=None): - if gpu_initialized: - return - if gpuid==None: - try: - import pycuda.autoinit - gpu_initialized = True - except: - pass - else: - try: - import pycuda.driver - pycuda.driver.init() - if gpuid>=pycuda.driver.Device.count(): - return - gpu_device = pycuda.driver.Device(gpuid) - gpu_context = gpu_device.make_context() - gpu_initialized = True - except: - pass + pass def closeGPU(): if gpu_context is not None: diff --git a/GPy/util/parallel.py b/GPy/util/parallel.py new file mode 100644 index 00000000..fd8791d4 --- /dev/null +++ b/GPy/util/parallel.py @@ -0,0 +1,14 @@ +""" +The module of tools for parallelization (MPI) +""" + +try: + from mpi4py import MPI +except: + pass + +def get_id_within_node(comm=MPI.COMM_WORLD): + rank = comm.rank + nodename = MPI.Get_processor_name() + nodelist = comm.allgather(nodename) + return len([i for i in nodelist[:rank] if i==nodename])