mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-24 14:15:14 +02:00
[GPU] psi varDTC ready
This commit is contained in:
parent
5e01b94d37
commit
f776db45df
3 changed files with 8 additions and 6 deletions
|
|
@ -328,7 +328,7 @@ class PSICOMP_SSRBF(object):
|
|||
self.gpuCacheAll = None
|
||||
self.gpuCache = None
|
||||
|
||||
@Cache_this(limit=1)
|
||||
@Cache_this(limit=1,ignore_args=(0,))
|
||||
def psicomputations(self, variance, lengthscale, Z, mu, S, gamma):
|
||||
"""Compute Psi statitsitcs"""
|
||||
if isinstance(lengthscale, np.ndarray) and len(lengthscale)>1:
|
||||
|
|
@ -373,7 +373,7 @@ class PSICOMP_SSRBF(object):
|
|||
|
||||
return psi0_gpu, psi1_gpu, psi2_gpu
|
||||
|
||||
@Cache_this(limit=1)
|
||||
@Cache_this(limit=1,ignore_args=(0,))
|
||||
def _psiDercomputations(self, variance, lengthscale, Z, mu, S, gamma):
|
||||
"""Compute the derivatives w.r.t. Psi statistics"""
|
||||
N, M, Q = mu.shape[0],Z.shape[0], mu.shape[1]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue