mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-14 15:25:15 +02:00
gently fall back if gpu psicomp fail
This commit is contained in:
parent
cb1f6f1486
commit
c639b0d126
2 changed files with 22 additions and 7 deletions
|
|
@ -24,7 +24,7 @@ class PSICOMP(Pickleable):
|
|||
from .gaussherm import PSICOMP_GH
|
||||
|
||||
class PSICOMP_RBF(PSICOMP):
|
||||
@Cache_this(limit=5, ignore_args=(0,))
|
||||
@Cache_this(limit=10, ignore_args=(0,))
|
||||
def psicomputations(self, kern, Z, variational_posterior, return_psi2_n=False):
|
||||
variance, lengthscale = kern.variance, kern.lengthscale
|
||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||
|
|
@ -34,7 +34,7 @@ class PSICOMP_RBF(PSICOMP):
|
|||
else:
|
||||
raise ValueError("unknown distriubtion received for psi-statistics")
|
||||
|
||||
@Cache_this(limit=5, ignore_args=(0,2,3,4))
|
||||
@Cache_this(limit=10, ignore_args=(0,2,3,4))
|
||||
def psiDerivativecomputations(self, kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
variance, lengthscale = kern.variance, kern.lengthscale
|
||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||
|
|
@ -46,7 +46,7 @@ class PSICOMP_RBF(PSICOMP):
|
|||
|
||||
class PSICOMP_Linear(PSICOMP):
|
||||
|
||||
@Cache_this(limit=5, ignore_args=(0,))
|
||||
@Cache_this(limit=10, ignore_args=(0,))
|
||||
def psicomputations(self, kern, Z, variational_posterior, return_psi2_n=False):
|
||||
variances = kern.variances
|
||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||
|
|
@ -56,7 +56,7 @@ class PSICOMP_Linear(PSICOMP):
|
|||
else:
|
||||
raise ValueError("unknown distriubtion received for psi-statistics")
|
||||
|
||||
@Cache_this(limit=2, ignore_args=(0,2,3,4))
|
||||
@Cache_this(limit=10, ignore_args=(0,2,3,4))
|
||||
def psiDerivativecomputations(self, kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
variances = kern.variances
|
||||
if isinstance(variational_posterior, variational.NormalPosterior):
|
||||
|
|
|
|||
|
|
@ -235,6 +235,9 @@ gpu_code = """
|
|||
class PSICOMP_RBF_GPU(PSICOMP_RBF):
|
||||
|
||||
def __init__(self, threadnum=256, blocknum=30, GPU_direct=False):
|
||||
from . import PSICOMP_RBF
|
||||
self.fall_back = PSICOMP_RBF()
|
||||
|
||||
from pycuda.compiler import SourceModule
|
||||
from ....util.gpu_init import initGPU
|
||||
initGPU()
|
||||
|
|
@ -318,8 +321,14 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
|
|||
def get_dimensions(self, Z, variational_posterior):
|
||||
return variational_posterior.mean.shape[0], Z.shape[0], Z.shape[1]
|
||||
|
||||
@Cache_this(limit=5, ignore_args=(0,))
|
||||
def psicomputations(self, kern, Z, variational_posterior, return_psi2_n=False):
|
||||
try:
|
||||
return self._psicomputations(kern, Z, variational_posterior, return_psi2_n)
|
||||
except:
|
||||
return self.fall_back.psicomputations(kern, Z, variational_posterior, return_psi2_n)
|
||||
|
||||
@Cache_this(limit=10, ignore_args=(0,))
|
||||
def _psicomputations(self, kern, Z, variational_posterior, return_psi2_n=False):
|
||||
"""
|
||||
Z - MxQ
|
||||
mu - NxQ
|
||||
|
|
@ -353,9 +362,15 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
|
|||
return psi0, psi1_gpu, psi2_gpu
|
||||
else:
|
||||
return psi0, psi1_gpu.get(), psi2_gpu.get()
|
||||
|
||||
@Cache_this(limit=5, ignore_args=(0,2,3,4))
|
||||
|
||||
def psiDerivativecomputations(self, kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
try:
|
||||
return self._psiDerivativecomputations(kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
except:
|
||||
return self.fall_back.psiDerivativecomputations(kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior)
|
||||
|
||||
@Cache_this(limit=10, ignore_args=(0,2,3,4))
|
||||
def _psiDerivativecomputations(self, kern, dL_dpsi0, dL_dpsi1, dL_dpsi2, Z, variational_posterior):
|
||||
variance, lengthscale = kern.variance, kern.lengthscale
|
||||
from ....util.linalg_gpu import sum_axis
|
||||
ARD = (len(lengthscale)!=1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue