mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 03:52:39 +02:00
fix gpu initialization
This commit is contained in:
parent
0dd52981d0
commit
a77a675549
3 changed files with 29 additions and 28 deletions
|
|
@ -5,7 +5,6 @@ The module for psi-statistics for RBF kernel
|
|||
import numpy as np
|
||||
from paramz.caching import Cache_this
|
||||
from . import PSICOMP_RBF
|
||||
from ....util import gpu_init
|
||||
|
||||
gpu_code = """
|
||||
// define THREADNUM
|
||||
|
|
@ -238,8 +237,6 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
|
|||
self.fall_back = PSICOMP_RBF()
|
||||
|
||||
from pycuda.compiler import SourceModule
|
||||
from ....util.gpu_init import initGPU
|
||||
initGPU()
|
||||
|
||||
self.GPU_direct = GPU_direct
|
||||
self.gpuCache = None
|
||||
|
|
|
|||
|
|
@ -287,8 +287,6 @@ class PSICOMP_SSRBF_GPU(PSICOMP_RBF):
|
|||
def __init__(self, threadnum=128, blocknum=15, GPU_direct=False):
|
||||
|
||||
from pycuda.compiler import SourceModule
|
||||
from ....util.gpu_init import initGPU
|
||||
initGPU()
|
||||
|
||||
self.GPU_direct = GPU_direct
|
||||
self.gpuCache = None
|
||||
|
|
|
|||
|
|
@ -10,29 +10,35 @@ gpu_device = None
|
|||
gpu_context = None
|
||||
MPI_enabled = False
|
||||
|
||||
def initGPU():
    """Initialise the GPU (pycuda) for this process.

    When running under MPI with more than one process, each process picks
    the CUDA device matching its rank within the node and creates a
    dedicated context; otherwise (or if that fails) it falls back to
    ``pycuda.autoinit``.  Outcomes are recorded in the module-level
    globals ``gpu_device``, ``gpu_context``, ``gpu_initialized`` and
    ``MPI_enabled``.  Initialisation failures are silently ignored
    (best-effort: callers check ``gpu_initialized``).
    """
    # BUG FIX: the original assigned these names without a ``global``
    # statement, so the module-level flags were never actually updated
    # (and reading the local ``MPI_enabled`` raised UnboundLocalError
    # when the mpi4py import failed, silently hidden by a bare except).
    global gpu_device, gpu_context, gpu_initialized, MPI_enabled

    try:
        from mpi4py import MPI
        MPI_enabled = True
    except ImportError:
        # mpi4py not installed: single-process mode.
        MPI_enabled = False

    try:
        if MPI_enabled and MPI.COMM_WORLD.size > 1:
            # One GPU per process: map this process's rank within the
            # node onto a CUDA device index.
            from .parallel import get_id_within_node
            gpuid = get_id_within_node()
            import pycuda.driver
            pycuda.driver.init()
            if gpuid >= pycuda.driver.Device.count():
                print('['+MPI.Get_processor_name()+'] more processes than the GPU numbers!')
                # Original used a bare ``raise`` with no active exception,
                # which itself raised RuntimeError; make that explicit.
                raise RuntimeError('more MPI processes than available GPUs')
            gpu_device = pycuda.driver.Device(gpuid)
            gpu_context = gpu_device.make_context()
            gpu_initialized = True
            return
        import pycuda.autoinit
        gpu_initialized = True
        return
    except Exception:
        # Best-effort: fall through to the plain autoinit path below.
        pass

    # BUG FIX: the original ran this block unconditionally, importing
    # pycuda.autoinit even after the MPI branch had already created a
    # dedicated context (double initialisation).  It is now reached only
    # when the branch above failed.
    try:
        import pycuda.autoinit
        gpu_initialized = True
    except Exception:
        # No usable GPU; leave gpu_initialized unset/False.
        pass
|
||||
|
||||
# def initGPU():
|
||||
# try:
|
||||
# from mpi4py import MPI
|
||||
# MPI_enabled = True
|
||||
# except:
|
||||
# pass
|
||||
# try:
|
||||
# if MPI_enabled and MPI.COMM_WORLD.size>1:
|
||||
# from .parallel import get_id_within_node
|
||||
# gpuid = get_id_within_node()
|
||||
# import pycuda.driver
|
||||
# pycuda.driver.init()
|
||||
# if gpuid>=pycuda.driver.Device.count():
|
||||
# print('['+MPI.Get_processor_name()+'] more processes than the GPU numbers!')
|
||||
# raise
|
||||
# gpu_device = pycuda.driver.Device(gpuid)
|
||||
# gpu_context = gpu_device.make_context()
|
||||
# gpu_initialized = True
|
||||
# else:
|
||||
# import pycuda.autoinit
|
||||
# gpu_initialized = True
|
||||
# except:
|
||||
# pass
|
||||
|
||||
def closeGPU():
|
||||
if gpu_context is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue