mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
allow Y_speedup flag for vardtc_minibatch
This commit is contained in:
parent
deb48407ff
commit
3971e68b9c
1 changed files with 2 additions and 2 deletions
|
|
@ -26,7 +26,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
const_jitter = 1e-6
|
const_jitter = 1e-6
|
||||||
def __init__(self, batchsize=None, limit=1, mpi_comm=None):
|
def __init__(self, batchsize=None, limit=1, mpi_comm=None, Y_speedup=False):
|
||||||
|
|
||||||
self.batchsize = batchsize
|
self.batchsize = batchsize
|
||||||
self.mpi_comm = mpi_comm
|
self.mpi_comm = mpi_comm
|
||||||
|
|
@ -39,7 +39,7 @@ class VarDTC_minibatch(LatentFunctionInference):
|
||||||
|
|
||||||
self.midRes = {}
|
self.midRes = {}
|
||||||
self.batch_pos = 0 # the starting position of the current mini-batch
|
self.batch_pos = 0 # the starting position of the current mini-batch
|
||||||
self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong
|
self.Y_speedup = Y_speedup # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
# has to be overridden, as Cacher objects cannot be pickled.
|
# has to be overridden, as Cacher objects cannot be pickled.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue