mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
rbf gpu usable
This commit is contained in:
parent
e486f3fd99
commit
ca1edecce4
5 changed files with 140 additions and 146 deletions
|
|
@ -73,13 +73,14 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
else:
|
||||
return jitchol(tdot(Y))
|
||||
|
||||
def gatherPsiStat(self, kern, X, Z, Y, beta, uncertain_inputs, het_noise):
|
||||
def gatherPsiStat(self, kern, X, Z, Y, beta, uncertain_inputs):
|
||||
|
||||
num_inducing = Z.shape[0]
|
||||
num_data, output_dim = Y.shape
|
||||
|
||||
if self.batchsize == None:
|
||||
self.batchsize = num_data
|
||||
het_noise = beta.size > 1
|
||||
|
||||
trYYT = self.get_trYYT(Y)
|
||||
|
||||
|
|
@ -88,46 +89,30 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
psi0_full = 0.
|
||||
YRY_full = 0.
|
||||
|
||||
for n_start in xrange(0,num_data,self.batchsize):
|
||||
|
||||
for n_start in xrange(0,num_data,self.batchsize):
|
||||
n_end = min(self.batchsize+n_start, num_data)
|
||||
|
||||
Y_slice = Y[n_start:n_end]
|
||||
X_slice = X[n_start:n_end]
|
||||
|
||||
if het_noise:
|
||||
b = beta[n_start]
|
||||
YRY_full += np.inner(Y_slice, Y_slice)*b
|
||||
else:
|
||||
b = beta
|
||||
|
||||
if uncertain_inputs:
|
||||
psi0 = kern.psi0(Z, X_slice)
|
||||
psi1 = kern.psi1(Z, X_slice)
|
||||
psi2 = kern.psi2(Z, X_slice)
|
||||
psi2_full += kern.psi2(Z, X_slice)*b
|
||||
else:
|
||||
psi0 = kern.Kdiag(X_slice)
|
||||
psi1 = kern.K(X_slice, Z)
|
||||
psi2 = None
|
||||
|
||||
if het_noise:
|
||||
beta_slice = beta[n_start:n_end]
|
||||
psi0_full += (beta_slice*psi0).sum()
|
||||
psi1Y_full += np.dot(beta_slice*Y_slice.T,psi1) # DxM
|
||||
YRY_full += (beta_slice*np.square(Y_slice).sum(axis=-1)).sum()
|
||||
else:
|
||||
psi0_full += psi0.sum()
|
||||
psi1Y_full += np.dot(Y_slice.T,psi1) # DxM
|
||||
|
||||
if uncertain_inputs:
|
||||
if het_noise:
|
||||
psi2_full += beta_slice*psi2
|
||||
else:
|
||||
psi2_full += psi2
|
||||
else:
|
||||
if het_noise:
|
||||
psi2_full += beta_slice*np.outer(psi1,psi1)
|
||||
else:
|
||||
psi2_full += np.dot(psi1.T,psi1)
|
||||
psi2_full += np.dot(psi1.T,psi1)*b
|
||||
|
||||
psi0_full += psi0.sum()*b
|
||||
psi1Y_full += np.dot(Y_slice.T,psi1)*b # DxM
|
||||
|
||||
if not het_noise:
|
||||
psi0_full *= beta
|
||||
psi1Y_full *= beta
|
||||
psi2_full *= beta
|
||||
YRY_full = trYYT*beta
|
||||
|
||||
if self.mpi_comm != None:
|
||||
|
|
@ -168,7 +153,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
if het_noise:
|
||||
self.batchsize = 1
|
||||
|
||||
psi0_full, psi1Y_full, psi2_full, YRY_full = self.gatherPsiStat(kern, X, Z, Y, beta, uncertain_inputs, het_noise)
|
||||
psi0_full, psi1Y_full, psi2_full, YRY_full = self.gatherPsiStat(kern, X, Z, Y, beta, uncertain_inputs)
|
||||
|
||||
#======================================================================
|
||||
# Compute Common Components
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue