diff --git a/GPy/util/gpu_init.py b/GPy/util/gpu_init.py index 0c496db3..94763d8b 100644 --- a/GPy/util/gpu_init.py +++ b/GPy/util/gpu_init.py @@ -10,15 +10,12 @@ gpu_device = None gpu_context = None MPI_enabled = False -try: - from mpi4py import MPI - MPI_enabled = True -except: - pass - - - def initGPU(): + try: + from mpi4py import MPI + MPI_enabled = True + except: + pass try: if MPI_enabled and MPI.COMM_WORLD.size>1: from .parallel import get_id_within_node