From cf2bf3bbe2c59910a7a45df1270154983d0be01f Mon Sep 17 00:00:00 2001 From: Zhenwen Dai Date: Mon, 23 Jun 2014 14:23:43 +0100 Subject: [PATCH] var_dtc_parallel make YY.T speed up --- GPy/core/gp.py | 2 +- .../var_dtc_parallel.py | 29 ++++++++++++------- GPy/models/ss_gplvm.py | 5 ++-- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/GPy/core/gp.py b/GPy/core/gp.py index f4b02128..3825a48e 100644 --- a/GPy/core/gp.py +++ b/GPy/core/gp.py @@ -277,4 +277,4 @@ class GP(Model): """ self.inference_method.on_optimization_start() super(GP, self).optimize(optimizer, start, **kwargs) - self.inference_method.on_optimization_end() \ No newline at end of file + self.inference_method.on_optimization_end() diff --git a/GPy/inference/latent_function_inference/var_dtc_parallel.py b/GPy/inference/latent_function_inference/var_dtc_parallel.py index 54bc11ac..450e767a 100644 --- a/GPy/inference/latent_function_inference/var_dtc_parallel.py +++ b/GPy/inference/latent_function_inference/var_dtc_parallel.py @@ -39,14 +39,15 @@ class VarDTC_minibatch(LatentFunctionInference): self.midRes = {} self.batch_pos = 0 # the starting position of the current mini-batch + self.Y_speedup = False # Replace Y with the cholesky factor of YY.T, but the posterior inference will be wrong def __getstate__(self): # has to be overridden, as Cacher objects cannot be pickled. - return self.batchsize, self.limit + return self.batchsize, self.limit, self.Y_speedup def __setstate__(self, state): # has to be overridden, as Cacher objects cannot be pickled. - self.batchsize, self.limit = state + self.batchsize, self.limit, self.Y_speedup = state self.mpi_comm = None self.midRes = {} self.batch_pos = 0 @@ -75,14 +76,16 @@ class VarDTC_minibatch(LatentFunctionInference): def gatherPsiStat(self, kern, X, Z, Y, beta, uncertain_inputs): - num_inducing = Z.shape[0] - num_data, output_dim = Y.shape - - if self.batchsize == None: - self.batchsize = num_data het_noise = beta.size > 1 trYYT = self.get_trYYT(Y) + if self.Y_speedup and not het_noise: + Y = self.get_YYTfactor(Y) + + num_inducing = Z.shape[0] + num_data, output_dim = Y.shape + if self.batchsize == None: + self.batchsize = num_data psi2_full = np.zeros((num_inducing,num_inducing)) psi1Y_full = np.zeros((output_dim,num_inducing)) # DxM @@ -200,8 +203,11 @@ class VarDTC_minibatch(LatentFunctionInference): #====================================================================== # Compute the Posterior distribution of inducing points p(u|Y) #====================================================================== - - post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm) + + if not self.Y_speedup or het_noise: + post = Posterior(woodbury_inv=KmmInvPsi2P, woodbury_vector=v, K=Kmm, mean=None, cov=None, K_chol=Lm) + else: + post = None #====================================================================== # Compute dL_dthetaL for uncertian input and non-heter noise @@ -232,7 +238,10 @@ class VarDTC_minibatch(LatentFunctionInference): het_noise = beta.size > 1 # VVT_factor is a matrix such that tdot(VVT_factor) = VVT...this is for efficiency! #self.YYTfactor = beta*self.get_YYTfactor(Y) - YYT_factor = Y + if self.Y_speedup and not het_noise: + YYT_factor = self.get_YYTfactor(Y) + else: + YYT_factor = Y n_start = self.batch_pos n_end = min(self.batchsize+n_start, num_data) diff --git a/GPy/models/ss_gplvm.py b/GPy/models/ss_gplvm.py index b10991ab..ab32e462 100644 --- a/GPy/models/ss_gplvm.py +++ b/GPy/models/ss_gplvm.py @@ -31,7 +31,8 @@ class SSGPLVM(SparseGP): self.mpi_comm = mpi_comm self.__IN_OPTIMIZATION__ = False - + self.group_spike = group_spike + if X == None: from ..util.initialization import initialize_latent X, fracs = initialize_latent(init, input_dim, Y) @@ -50,7 +51,7 @@ class SSGPLVM(SparseGP): gamma[:] = 0.5 if group_spike: - gamma[:] = gamma.mean(axis=0) + gamma[:] = gamma[:,0] if Z is None: Z = np.random.permutation(X.copy())[:num_inducing]