mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
var_dtc_parallel make YY.T speed up
This commit is contained in:
parent
7a35c9a6ad
commit
cf2bf3bbe2
3 changed files with 23 additions and 13 deletions
|
|
@ -39,14 +39,15 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
|
||||
self.midRes = {}
|
||||
self.batch_pos = 0 # the starting position of the current mini-batch
|
||||
self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong
|
||||
|
||||
def __getstate__(self):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
return self.batchsize, self.limit
|
||||
return self.batchsize, self.limit, self.Y_speedup
|
||||
|
||||
def __setstate__(self, state):
|
||||
# has to be overridden, as Cacher objects cannot be pickled.
|
||||
self.batchsize, self.limit = state
|
||||
self.batchsize, self.limit, self.Y_speedup = state
|
||||
self.mpi_comm = None
|
||||
self.midRes = {}
|
||||
self.batch_pos = 0
|
||||
|
|
@ -75,14 +76,16 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
|
||||
def gatherPsiStat(self, kern, X, Z, Y, beta, uncertain_inputs):
|
||||
|
||||
num_inducing = Z.shape[0]
|
||||
num_data, output_dim = Y.shape
|
||||
|
||||
if self.batchsize == None:
|
||||
self.batchsize = num_data
|
||||
het_noise = beta.size > 1
|
||||
|
||||
trYYT = self.get_trYYT(Y)
|
||||
if self.Y_speedup and not het_noise:
|
||||
Y = self.get_YYTfactor(Y)
|
||||
|
||||
num_inducing = Z.shape[0]
|
||||
num_data, output_dim = Y.shape
|
||||
if self.batchsize == None:
|
||||
self.batchsize = num_data
|
||||
|
||||
psi2_full = np.zeros((num_inducing,num_inducing))
|
||||
psi1Y_full = np.zeros((output_dim,num_inducing)) # DxM
|
||||
|
|
@ -201,7 +204,10 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
# Compute the Posterior distribution of inducing points p(u|Y)
|
||||
#======================================================================
|
||||
|
||||
post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
|
||||
if not self.Y_speedup or het_noise:
|
||||
post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
|
||||
else:
|
||||
post = None
|
||||
|
||||
#======================================================================
|
||||
# Compute dL_dthetaL for uncertian input and non-heter noise
|
||||
|
|
@ -232,7 +238,10 @@ class VarDTC_minibatch(LatentFunctionInference):
|
|||
het_noise = beta.size > 1
|
||||
# VVT_factor is a matrix such that tdot(VVT_factor) = VVT...this is for efficiency!
|
||||
#self.YYTfactor = beta*self.get_YYTfactor(Y)
|
||||
YYT_factor = Y
|
||||
if self.Y_speedup and not het_noise:
|
||||
YYT_factor = self.get_YYTfactor(Y)
|
||||
else:
|
||||
YYT_factor = Y
|
||||
|
||||
n_start = self.batch_pos
|
||||
n_end = min(self.batchsize+n_start, num_data)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class SSGPLVM(SparseGP):
|
|||
|
||||
self.mpi_comm = mpi_comm
|
||||
self.__IN_OPTIMIZATION__ = False
|
||||
self.group_spike = group_spike
|
||||
|
||||
if X == None:
|
||||
from ..util.initialization import initialize_latent
|
||||
|
|
@ -50,7 +51,7 @@ class SSGPLVM(SparseGP):
|
|||
gamma[:] = 0.5
|
||||
|
||||
if group_spike:
|
||||
gamma[:] = gamma.mean(axis=0)
|
||||
gamma[:] = gamma[:,0]
|
||||
|
||||
if Z is None:
|
||||
Z = np.random.permutation(X.copy())[:num_inducing]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue