mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
fix add kernel and VarDTC_minibatch speed tuning
This commit is contained in:
parent
129985998c
commit
216de32c0c
3 changed files with 18 additions and 9 deletions
|
|
@ -94,8 +94,12 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
|
||||
for n_start in xrange(0,num_data,self.batchsize):
|
||||
n_end = min(self.batchsize+n_start, num_data)
|
||||
Y_slice = Y[n_start:n_end]
|
||||
X_slice = X[n_start:n_end]
|
||||
if (n_end-n_start)==num_data:
|
||||
Y_slice = Y
|
||||
X_slice = X
|
||||
else:
|
||||
Y_slice = Y[n_start:n_end]
|
||||
X_slice = X[n_start:n_end]
|
||||
|
||||
if het_noise:
|
||||
b = beta[n_start]
|
||||
|
|
@ -347,7 +351,9 @@ def update_gradients(model, mpi_comm=None):
|
|||
while not isEnd:
|
||||
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, X, model.Z, model.likelihood, Y)
|
||||
if isinstance(model.X, VariationalPosterior):
|
||||
if mpi_comm ==None:
|
||||
if (n_range[1]-n_range[0])==X.shape[0]:
|
||||
X_slice = X
|
||||
elif mpi_comm ==None:
|
||||
X_slice = model.X[n_range[0]:n_range[1]]
|
||||
else:
|
||||
X_slice = model.X[model.N_range[0]+n_range[0]:model.N_range[0]+n_range[1]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue