[ssgplvm] linear kernel

This commit is contained in:
Zhenwen Dai 2014-05-15 11:43:29 +01:00
parent b65da11df5
commit dad476faf6
7 changed files with 162 additions and 104 deletions

View file

@ -29,6 +29,7 @@ class VarDTC_minibatch(object):
self.batchsize = batchsize
self.mpi_comm = mpi_comm
self.limit = limit
# Cache functions
from ...util.caching import Cacher
@ -37,6 +38,20 @@ class VarDTC_minibatch(object):
self.midRes = {}
self.batch_pos = 0 # the starting position of the current mini-batch
def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled.
return self.batchsize, self.limit
def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled.
self.batchsize, self.limit = state
self.mpi_comm = None
self.midRes = {}
self.batch_pos = 0
from ...util.caching import Cacher
self.get_trYYT = Cacher(self._get_trYYT, self.limit)
self.get_YYTfactor = Cacher(self._get_YYTfactor, self.limit)
def set_limit(self, limit):
self.get_trYYT.limit = limit
@ -334,7 +349,10 @@ def update_gradients(model, mpi_comm=None):
while not isEnd:
isEnd, n_range, grad_dict = model.inference_method.inference_minibatch(model.kern, X, model.Z, model.likelihood, Y)
if isinstance(model.X, VariationalPosterior):
X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
if mpi_comm ==None:
X_slice = model.X[n_range[0]:n_range[1]]
else:
X_slice = model.X[model.Y_range[0]+n_range[0]:model.Y_range[0]+n_range[1]]
#gradients w.r.t. kernel
model.kern.update_gradients_expectations(variational_posterior=X_slice, Z=model.Z, dL_dpsi0=grad_dict['dL_dpsi0'], dL_dpsi1=grad_dict['dL_dpsi1'], dL_dpsi2=grad_dict['dL_dpsi2'])