var_dtc_parallel: add YY.T speed-up

This commit is contained in:
Zhenwen Dai 2014-06-23 14:23:43 +01:00
parent 7a35c9a6ad
commit cf2bf3bbe2
3 changed files with 23 additions and 13 deletions

View file

@ -39,14 +39,15 @@ class VarDTC_minibatch(LatentFunctionInference):
self.midRes = {}
self.batch_pos = 0 # the starting position of the current mini-batch
self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong
def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled.
return self.batchsize, self.limit
return self.batchsize, self.limit, self.Y_speedup
def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled.
self.batchsize, self.limit = state
self.batchsize, self.limit, self.Y_speedup = state
self.mpi_comm = None
self.midRes = {}
self.batch_pos = 0
@ -75,14 +76,16 @@ class VarDTC_minibatch(LatentFunctionInference):
def gatherPsiStat(self, kern, X, Z, Y, beta, uncertain_inputs):
num_inducing = Z.shape[0]
num_data, output_dim = Y.shape
if self.batchsize == None:
self.batchsize = num_data
het_noise = beta.size > 1
trYYT = self.get_trYYT(Y)
if self.Y_speedup and not het_noise:
Y = self.get_YYTfactor(Y)
num_inducing = Z.shape[0]
num_data, output_dim = Y.shape
if self.batchsize == None:
self.batchsize = num_data
psi2_full = np.zeros((num_inducing,num_inducing))
psi1Y_full = np.zeros((output_dim,num_inducing)) # DxM
@ -201,7 +204,10 @@ class VarDTC_minibatch(LatentFunctionInference):
# Compute the Posterior distribution of inducing points p(u|Y)
#======================================================================
post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
if not self.Y_speedup or het_noise:
post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
else:
post = None
#======================================================================
# Compute dL_dthetaL for uncertian input and non-heter noise
@ -232,7 +238,10 @@ class VarDTC_minibatch(LatentFunctionInference):
het_noise = beta.size > 1
# VVT_factor is a matrix such that tdot(VVT_factor) = VVT...this is for efficiency!
#self.YYTfactor = beta*self.get_YYTfactor(Y)
YYT_factor = Y
if self.Y_speedup and not het_noise:
YYT_factor = self.get_YYTfactor(Y)
else:
YYT_factor = Y
n_start = self.batch_pos
n_end = min(self.batchsize+n_start, num_data)

View file

@ -31,6 +31,7 @@ class SSGPLVM(SparseGP):
self.mpi_comm = mpi_comm
self.__IN_OPTIMIZATION__ = False
self.group_spike = group_spike
if X == None:
from ..util.initialization import initialize_latent
@ -50,7 +51,7 @@ class SSGPLVM(SparseGP):
gamma[:] = 0.5
if group_spike:
gamma[:] = gamma.mean(axis=0)
gamma[:] = gamma[:,0]
if Z is None:
Z = np.random.permutation(X.copy())[:num_inducing]