From 070137504aff443342cc8ce10f66cec89d3c3966 Mon Sep 17 00:00:00 2001 From: Alan Saul Date: Mon, 17 Aug 2015 16:06:43 +0100 Subject: [PATCH] Added full cov prediction --- GPy/models/gp_var_gauss.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/GPy/models/gp_var_gauss.py b/GPy/models/gp_var_gauss.py index 729b6bb8..c7926c52 100644 --- a/GPy/models/gp_var_gauss.py +++ b/GPy/models/gp_var_gauss.py @@ -78,7 +78,7 @@ class GPVariationalGaussianApproximation(Model): dF_dK = self.alpha*dF_dm.T + np.dot(tmp*dF_dv, tmp.T) self.kern.update_gradients_full(dF_dK - dKL_dK, self.X) - def _raw_predict(self, Xnew): + def _raw_predict(self, Xnew, full_cov=False): """ Predict the function(s) at the new point(s) Xnew. @@ -89,7 +89,12 @@ class GPVariationalGaussianApproximation(Model): Kux = self.kern.K(self.X, Xnew) mu = np.dot(Kux.T, self.alpha) WiKux = np.dot(Wi, Kux) - Kxx = self.kern.Kdiag(Xnew) - var = Kxx - np.sum(WiKux*Kux, 0) + if full_cov: + Kxx = self.kern.K(Xnew) + var = Kxx - np.dot(Kux.T, WiKux) + else: + Kxx = self.kern.Kdiag(Xnew) + var = Kxx - np.sum(WiKux*Kux, 0) + var = var.reshape(-1,1) - return mu, var.reshape(-1,1) + return mu, var