Added full cov prediction

This commit is contained in:
Alan Saul 2015-08-17 16:06:43 +01:00
parent ce3681a566
commit 070137504a

View file

@ -78,7 +78,7 @@ class GPVariationalGaussianApproximation(Model):
dF_dK = self.alpha*dF_dm.T + np.dot(tmp*dF_dv, tmp.T) dF_dK = self.alpha*dF_dm.T + np.dot(tmp*dF_dv, tmp.T)
self.kern.update_gradients_full(dF_dK - dKL_dK, self.X) self.kern.update_gradients_full(dF_dK - dKL_dK, self.X)
def _raw_predict(self, Xnew): def _raw_predict(self, Xnew, full_cov=False):
""" """
Predict the function(s) at the new point(s) Xnew. Predict the function(s) at the new point(s) Xnew.
@ -89,7 +89,12 @@ class GPVariationalGaussianApproximation(Model):
Kux = self.kern.K(self.X, Xnew) Kux = self.kern.K(self.X, Xnew)
mu = np.dot(Kux.T, self.alpha) mu = np.dot(Kux.T, self.alpha)
WiKux = np.dot(Wi, Kux) WiKux = np.dot(Wi, Kux)
Kxx = self.kern.Kdiag(Xnew) if full_cov:
var = Kxx - np.sum(WiKux*Kux, 0) Kxx = self.kern.K(Xnew)
var = Kxx - np.dot(Kux.T, WiKux)
else:
Kxx = self.kern.Kdiag(Xnew)
var = Kxx - np.sum(WiKux*Kux, 0)
var = var.reshape(-1,1)
return mu, var.reshape(-1,1) return mu, var