mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
[GPU] varDTC_gpu bug fix
This commit is contained in:
parent
7a74c0b80d
commit
934ecc7e95
1 changed files with 8 additions and 5 deletions
|
|
@ -30,7 +30,7 @@ class VarDTC_GPU(object):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
const_jitter = np.float64(1e-6)
|
const_jitter = np.float64(1e-6)
|
||||||
def __init__(self, batchsize, limit=1):
|
def __init__(self, batchsize=None, limit=1):
|
||||||
|
|
||||||
self.batchsize = batchsize
|
self.batchsize = batchsize
|
||||||
|
|
||||||
|
|
@ -45,6 +45,8 @@ class VarDTC_GPU(object):
|
||||||
self.gpuCache = None
|
self.gpuCache = None
|
||||||
|
|
||||||
def _initGPUCache(self, num_inducing, output_dim, Y):
|
def _initGPUCache(self, num_inducing, output_dim, Y):
|
||||||
|
if self.batchsize==None:
|
||||||
|
self.batchsize = Y.shape[0]
|
||||||
if self.gpuCache == None:
|
if self.gpuCache == None:
|
||||||
ndata = Y.shape[0]
|
ndata = Y.shape[0]
|
||||||
self.gpuCache = {# inference_likelihood
|
self.gpuCache = {# inference_likelihood
|
||||||
|
|
@ -398,13 +400,14 @@ class VarDTC_GPU(object):
|
||||||
psi0p_gpu = self.gpuCache['psi0p_gpu']
|
psi0p_gpu = self.gpuCache['psi0p_gpu']
|
||||||
psi1p_gpu = self.gpuCache['psi1p_gpu']
|
psi1p_gpu = self.gpuCache['psi1p_gpu']
|
||||||
psi2p_gpu = self.gpuCache['psi2p_gpu']
|
psi2p_gpu = self.gpuCache['psi2p_gpu']
|
||||||
if psi0p_gpu > nSlice:
|
if psi0p_gpu.shape[0] > nSlice:
|
||||||
psi0p_gpu = psi0p_gpu[:nSlice]
|
psi0p_gpu = psi0p_gpu[:nSlice]
|
||||||
psi1p_gpu = psi1p_gpu.ravel()[:nSlice*num_inducing].reshape(nSlice,num_inducing)
|
psi1p_gpu = psi1p_gpu.ravel()[:nSlice*num_inducing].reshape(nSlice,num_inducing)
|
||||||
psi2p_gpu = psi2p_gpu.ravel()[:nSlice*num_inducing*num_inducing].reshape(nSlice,num_inducing,num_inducing)
|
psi2p_gpu = psi2p_gpu.ravel()[:nSlice*num_inducing*num_inducing].reshape(nSlice,num_inducing,num_inducing)
|
||||||
psi0p_gpu.get(psi0)
|
psi0p_gpu.set(np.asfortranarray(psi0))
|
||||||
psi1p_gpu.get(psi1)
|
psi1p_gpu.set(np.asfortranarray(psi1))
|
||||||
psi2p_gpu.get(psi2)
|
if uncertain_inputs:
|
||||||
|
psi2p_gpu.set(np.asfortranarray(psi2))
|
||||||
|
|
||||||
#======================================================================
|
#======================================================================
|
||||||
# Prepare gpu memory
|
# Prepare gpu memory
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue