[GPU] varDTC_gpu bug fix

This commit is contained in:
Zhenwen Dai 2014-04-04 19:03:35 +01:00
parent 7a74c0b80d
commit 934ecc7e95

View file

@ -30,7 +30,7 @@ class VarDTC_GPU(object):
"""
const_jitter = np.float64(1e-6)
def __init__(self, batchsize, limit=1):
def __init__(self, batchsize=None, limit=1):
self.batchsize = batchsize
@ -45,6 +45,8 @@ class VarDTC_GPU(object):
self.gpuCache = None
def _initGPUCache(self, num_inducing, output_dim, Y):
if self.batchsize==None:
self.batchsize = Y.shape[0]
if self.gpuCache == None:
ndata = Y.shape[0]
self.gpuCache = {# inference_likelihood
@ -398,13 +400,14 @@ class VarDTC_GPU(object):
psi0p_gpu = self.gpuCache['psi0p_gpu']
psi1p_gpu = self.gpuCache['psi1p_gpu']
psi2p_gpu = self.gpuCache['psi2p_gpu']
if psi0p_gpu > nSlice:
if psi0p_gpu.shape[0] > nSlice:
psi0p_gpu = psi0p_gpu[:nSlice]
psi1p_gpu = psi1p_gpu.ravel()[:nSlice*num_inducing].reshape(nSlice,num_inducing)
psi2p_gpu = psi2p_gpu.ravel()[:nSlice*num_inducing*num_inducing].reshape(nSlice,num_inducing,num_inducing)
psi0p_gpu.get(psi0)
psi1p_gpu.get(psi1)
psi2p_gpu.get(psi2)
psi0p_gpu.set(np.asfortranarray(psi0))
psi1p_gpu.set(np.asfortranarray(psi1))
if uncertain_inputs:
psi2p_gpu.set(np.asfortranarray(psi2))
#======================================================================
# Prepare gpu memory