diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index dfa16c4a..43e52e2b 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -26,7 +26,7 @@ class VarDTC_minibatch(LatentFunctionInference): """ const_jitter = 1e-6 - def __init__(self, batchsize=None, limit=1, mpi_comm=None): + def __init__(self, batchsize=None, limit=1, mpi_comm=None, Y_speedup=False): self.batchsize = batchsize self.mpi_comm = mpi_comm @@ -39,7 +39,7 @@ class VarDTC_minibatch(LatentFunctionInference): self.midRes = {} self.batch_pos = 0 # the starting position of the current mini-batch - self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong + self.Y_speedup = Y_speedup # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong def __getstate__(self): # has to be overridden, as Cacher objects cannot be pickled.