From 8eaa0bbf8a64c8bfd130d6bff135a36c21e09a7a Mon Sep 17 00:00:00 2001 From: James Hensman Date: Fri, 30 May 2014 16:45:51 +0100 Subject: [PATCH] basic vardtc working --- .../latent_function_inference/var_dtc.py | 49 +++++++------------ 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/GPy/inference/latent_function_inference/var_dtc.py b/GPy/inference/latent_function_inference/var_dtc.py index a9a137dc..9ade188f 100644 --- a/GPy/inference/latent_function_inference/var_dtc.py +++ b/GPy/inference/latent_function_inference/var_dtc.py @@ -36,11 +36,11 @@ class VarDTC(LatentFunctionInference): return param_to_array(np.sum(np.square(Y))) def __getstate__(self): - # has to be overridden, as Cacher objects cannot be pickled. + # has to be overridden, as Cacher objects cannot be pickled. return self.limit def __setstate__(self, state): - # has to be overridden, as Cacher objects cannot be pickled. + # has to be overridden, as Cacher objects cannot be pickled. self.limit = state from ...util.caching import Cacher self.get_trYYT = Cacher(self._get_trYYT, self.limit) @@ -62,20 +62,9 @@ class VarDTC(LatentFunctionInference): return Y * prec # TODO chache this, and make it effective def inference(self, kern, X, Z, likelihood, Y, Y_metadata=None): - if isinstance(X, VariationalPosterior): - uncertain_inputs = True - psi0 = kern.psi0(Z, X) - psi1 = kern.psi1(Z, X) - psi2 = kern.psi2(Z, X) - else: - uncertain_inputs = False - psi0 = kern.Kdiag(X) - psi1 = kern.K(X, Z) - psi2 = None - - #see whether we're using variational uncertain inputs _, output_dim = Y.shape + uncertain_inputs = isinstance(X, VariationalPosterior) #see whether we've got a different noise variance for each datum beta = 1./np.fmax(likelihood.gaussian_variance(Y_metadata), 1e-6) @@ -96,23 +85,21 @@ class VarDTC(LatentFunctionInference): diag.add(Kmm, self.const_jitter) Lm = jitchol(Kmm) - # The rather complex computations of A + + # The rather complex computations of A, and the psi stats if uncertain_inputs: + psi0 = kern.psi0(Z, X) + psi1 = kern.psi1(Z, X) if het_noise: - psi2_beta = psi2 * (beta.flatten().reshape(num_data, 1, 1)).sum(0) + psi2_beta = np.sum([kern.psi2(Z,X[i:i+1,:]) * beta_i for i,beta_i in enumerate(beta)],0) else: - psi2_beta = psi2.sum(0) * beta - #if 0: - # evals, evecs = linalg.eigh(psi2_beta) - # clipped_evals = np.clip(evals, 0., 1e6) # TODO: make clipping configurable - # if not np.array_equal(evals, clipped_evals): - # pass # print evals - # tmp = evecs * np.sqrt(clipped_evals) - # tmp = tmp.T - # no backsubstitution because of bound explosion on tr(A) if not... + psi2_beta = kern.psi2(Z,X) * beta LmInv = dtrtri(Lm) A = LmInv.dot(psi2_beta.dot(LmInv.T)) else: + psi0 = kern.Kdiag(X) + psi1 = kern.K(X, Z) + psi2 = None if het_noise: tmp = psi1 * (np.sqrt(beta.reshape(num_data, 1))) else: @@ -149,7 +136,7 @@ class VarDTC(LatentFunctionInference): log_marginal = _compute_log_marginal_likelihood(likelihood, num_data, output_dim, beta, het_noise, psi0, A, LB, trYYT, data_fit, VVT_factor) - #put the gradients in the right places + #noise derivatives dL_dR = _compute_dL_dR(likelihood, het_noise, uncertain_inputs, LB, _LBi_Lmi_psi1Vf, DBi_plus_BiPBi, Lm, A, @@ -158,6 +145,7 @@ class VarDTC(LatentFunctionInference): dL_dthetaL = likelihood.exact_inference_gradients(dL_dR,Y_metadata) + #put the gradients in the right places if uncertain_inputs: grad_dict = {'dL_dKmm': dL_dKmm, 'dL_dpsi0':dL_dpsi0, @@ -203,11 +191,11 @@ class VarDTCMissingData(LatentFunctionInference): self._Y.limit = limit def __getstate__(self): - # has to be overridden, as Cacher objects cannot be pickled. + # has to be overridden, as Cacher objects cannot be pickled. return self._Y.limit, self._inan def __setstate__(self, state): - # has to be overridden, as Cacher objects cannot be pickled. + # has to be overridden, as Cacher objects cannot be pickled. from ...util.caching import Cacher self.limit = state[0] self._inan = state[1] @@ -409,10 +397,7 @@ def _compute_dL_dpsi(num_inducing, num_data, output_dim, beta, Lm, VVT_factor, C dL_dpsi2 = None else: dL_dpsi2 = beta * dL_dpsi2_beta - if uncertain_inputs: - # repeat for each of the N psi_2 matrices - dL_dpsi2 = np.repeat(dL_dpsi2[None, :, :], num_data, axis=0) - else: + if not uncertain_inputs: # subsume back into psi1 (==Kmn) dL_dpsi1 += 2.*np.dot(psi1, dL_dpsi2) dL_dpsi2 = None