mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
[GPU] varDTC_gpu bug fix
This commit is contained in:
parent
7a74c0b80d
commit
934ecc7e95
1 changed files with 8 additions and 5 deletions
|
|
@ -30,7 +30,7 @@ class VarDTC_GPU(object):
|
|||
|
||||
"""
|
||||
const_jitter = np.float64(1e-6)
|
||||
def __init__(self, batchsize, limit=1):
|
||||
def __init__(self, batchsize=None, limit=1):
|
||||
|
||||
self.batchsize = batchsize
|
||||
|
||||
|
|
@ -45,6 +45,8 @@ class VarDTC_GPU(object):
|
|||
self.gpuCache = None
|
||||
|
||||
def _initGPUCache(self, num_inducing, output_dim, Y):
|
||||
if self.batchsize==None:
|
||||
self.batchsize = Y.shape[0]
|
||||
if self.gpuCache == None:
|
||||
ndata = Y.shape[0]
|
||||
self.gpuCache = {# inference_likelihood
|
||||
|
|
@ -398,13 +400,14 @@ class VarDTC_GPU(object):
|
|||
psi0p_gpu = self.gpuCache['psi0p_gpu']
|
||||
psi1p_gpu = self.gpuCache['psi1p_gpu']
|
||||
psi2p_gpu = self.gpuCache['psi2p_gpu']
|
||||
if psi0p_gpu > nSlice:
|
||||
if psi0p_gpu.shape[0] > nSlice:
|
||||
psi0p_gpu = psi0p_gpu[:nSlice]
|
||||
psi1p_gpu = psi1p_gpu.ravel()[:nSlice*num_inducing].reshape(nSlice,num_inducing)
|
||||
psi2p_gpu = psi2p_gpu.ravel()[:nSlice*num_inducing*num_inducing].reshape(nSlice,num_inducing,num_inducing)
|
||||
psi0p_gpu.get(psi0)
|
||||
psi1p_gpu.get(psi1)
|
||||
psi2p_gpu.get(psi2)
|
||||
psi0p_gpu.set(np.asfortranarray(psi0))
|
||||
psi1p_gpu.set(np.asfortranarray(psi1))
|
||||
if uncertain_inputs:
|
||||
psi2p_gpu.set(np.asfortranarray(psi2))
|
||||
|
||||
#======================================================================
|
||||
# Prepare gpu memory
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue