allow Y_speedup flag for vardtc_minibatch

This commit is contained in:
Zhenwen Dai 2014-10-17 15:34:14 +01:00
parent deb48407ff
commit 3971e68b9c

View file

@ -26,7 +26,7 @@ class VarDTC_minibatch(LatentFunctionInference):
""" """
const_jitter = 1e-6 const_jitter = 1e-6
def __init__(self, batchsize=None, limit=1, mpi_comm=None): def __init__(self, batchsize=None, limit=1, mpi_comm=None, Y_speedup=False):
self.batchsize = batchsize self.batchsize = batchsize
self.mpi_comm = mpi_comm self.mpi_comm = mpi_comm
@ -39,7 +39,7 @@ class VarDTC_minibatch(LatentFunctionInference):
self.midRes = {} self.midRes = {}
self.batch_pos = 0 # the starting position of the current mini-batch self.batch_pos = 0 # the starting position of the current mini-batch
self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong self.Y_speedup = Y_speedup # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong
def __getstate__(self): def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled. # has to be overridden, as Cacher objects cannot be pickled.