Mirror of https://github.com/SheffieldML/GPy.git (synced 2026-05-10 04:22:38 +02:00)
var_dtc_parallel make YY.T speed up
This commit is contained in:
parent 7a35c9a6ad
commit cf2bf3bbe2
3 changed files with 23 additions and 13 deletions
@@ -39,14 +39,15 @@ class VarDTC_minibatch(LatentFunctionInference):
         self.midRes = {}
 
         self.batch_pos = 0 # the starting position of the current mini-batch
+        self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong
 
     def __getstate__(self):
         # has to be overridden, as Cacher objects cannot be pickled.
-        return self.batchsize, self.limit
+        return self.batchsize, self.limit, self.Y_speedup
 
     def __setstate__(self, state):
         # has to be overridden, as Cacher objects cannot be pickled.
-        self.batchsize, self.limit = state
+        self.batchsize, self.limit, self.Y_speedup = state
         self.mpi_comm = None
         self.midRes = {}
         self.batch_pos = 0
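The Y_speedup flag introduced above relies on the fact that, with homoscedastic noise, the variational bound touches Y only through YY.T, so the N x D output matrix can be swapped for an N x N factor with the same outer product. A minimal numpy sketch of the idea behind get_YYTfactor (yyt_factor below is a hypothetical stand-in, not GPy's implementation):

    import numpy as np

    def yyt_factor(Y, jitter=1e-8):
        # Return a matrix L with L.dot(L.T) == Y.dot(Y.T).
        # Only worthwhile when the output dimension D exceeds N.
        N, D = Y.shape
        if N >= D:
            return Y
        return np.linalg.cholesky(Y.dot(Y.T) + jitter * np.eye(N))

    rng = np.random.RandomState(0)
    Y = rng.randn(50, 2000)      # N=50 data points, D=2000 output dimensions
    L = yyt_factor(Y)            # 50 x 50 instead of 50 x 2000

    # Terms that involve Y only via YY.T are preserved ...
    assert np.isclose(np.trace(Y.dot(Y.T)), np.trace(L.dot(L.T)))
    # ... but the columns of L no longer correspond to real outputs, which is
    # why the new attribute warns that the posterior inference will be wrong.

Because the factor discards the identity of individual output columns, the speed-up only makes sense when the posterior over the inducing outputs is not needed; the post = None branch further down encodes exactly that.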
@@ -75,14 +76,16 @@ class VarDTC_minibatch(LatentFunctionInference):
 
     def gatherPsiStat(self, kern, X, Z, Y, beta, uncertain_inputs):
 
-        num_inducing = Z.shape[0]
-        num_data, output_dim = Y.shape
-
-        if self.batchsize == None:
-            self.batchsize = num_data
         het_noise = beta.size > 1
 
         trYYT = self.get_trYYT(Y)
+        if self.Y_speedup and not het_noise:
+            Y = self.get_YYTfactor(Y)
+
+        num_inducing = Z.shape[0]
+        num_data, output_dim = Y.shape
+        if self.batchsize == None:
+            self.batchsize = num_data
 
         psi2_full = np.zeros((num_inducing,num_inducing))
         psi1Y_full = np.zeros((output_dim,num_inducing)) # DxM
@@ -201,7 +204,10 @@ class VarDTC_minibatch(LatentFunctionInference):
         # Compute the Posterior distribution of inducing points p(u|Y)
         #======================================================================
 
-        post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
+        if not self.Y_speedup or het_noise:
+            post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm)
+        else:
+            post = None
 
         #======================================================================
         # Compute dL_dthetaL for uncertian input and non-heter noise
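When the flag is on and the noise is homoscedastic, no Posterior object is built, so the log-likelihood and its gradients stay usable while prediction does not. A hedged usage sketch (construction with no arguments and the module path are assumptions based on GPy's source layout, not guaranteed by this diff):

    # Assumed usage: toggle the attribute added in this commit on an existing
    # inference object; it defaults to False, so behaviour is otherwise unchanged.
    from GPy.inference.latent_function_inference.var_dtc_parallel import VarDTC_minibatch

    inf = VarDTC_minibatch()
    inf.Y_speedup = True    # cheaper YY.T handling during optimisation
    # ... optimise the model using this inference object ...
    inf.Y_speedup = False   # switch back before requesting the posterior or predictions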
@@ -232,6 +238,9 @@ class VarDTC_minibatch(LatentFunctionInference):
         het_noise = beta.size > 1
         # VVT_factor is a matrix such that tdot(VVT_factor) = VVT...this is for efficiency!
         #self.YYTfactor = beta*self.get_YYTfactor(Y)
-        YYT_factor = Y
+        if self.Y_speedup and not het_noise:
+            YYT_factor = self.get_YYTfactor(Y)
+        else:
+            YYT_factor = Y
 
         n_start = self.batch_pos
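The VVT_factor comment above states that tdot(VVT_factor) = VVT, which is what makes the substitution safe here: the factor changes, the outer product does not. A small numpy check of that invariant (tdot below is a plain stand-in for GPy.util.linalg.tdot):

    import numpy as np

    def tdot(A):
        # Plain stand-in for GPy.util.linalg.tdot, which returns A.dot(A.T).
        return A.dot(A.T)

    rng = np.random.RandomState(1)
    Y = rng.randn(20, 300)            # N=20, D=300
    L = np.linalg.cholesky(tdot(Y))   # the kind of factor get_YYTfactor would return for D > N

    # VVT is identical whichever factor is used, so everything built from
    # tdot(YYT_factor) downstream is unaffected by the Y_speedup branch.
    assert np.allclose(tdot(Y), tdot(L))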
@@ -31,6 +31,7 @@ class SSGPLVM(SparseGP):
 
         self.mpi_comm = mpi_comm
         self.__IN_OPTIMIZATION__ = False
+        self.group_spike = group_spike
 
         if X == None:
             from ..util.initialization import initialize_latent
@@ -50,7 +51,7 @@ class SSGPLVM(SparseGP):
             gamma[:] = 0.5
 
         if group_spike:
-            gamma[:] = gamma.mean(axis=0)
+            gamma[:] = gamma[:,0]
 
         if Z is None:
             Z = np.random.permutation(X.copy())[:num_inducing]