mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
Added full cov prediction
This commit is contained in:
parent
ce3681a566
commit
070137504a
1 changed files with 9 additions and 4 deletions
|
|
@ -78,7 +78,7 @@ class GPVariationalGaussianApproximation(Model):
|
|||
dF_dK = self.alpha*dF_dm.T + np.dot(tmp*dF_dv, tmp.T)
|
||||
self.kern.update_gradients_full(dF_dK - dKL_dK, self.X)
|
||||
|
||||
def _raw_predict(self, Xnew):
|
||||
def _raw_predict(self, Xnew, full_cov=False):
|
||||
"""
|
||||
Predict the function(s) at the new point(s) Xnew.
|
||||
|
||||
|
|
@ -89,7 +89,12 @@ class GPVariationalGaussianApproximation(Model):
|
|||
Kux = self.kern.K(self.X, Xnew)
|
||||
mu = np.dot(Kux.T, self.alpha)
|
||||
WiKux = np.dot(Wi, Kux)
|
||||
Kxx = self.kern.Kdiag(Xnew)
|
||||
var = Kxx - np.sum(WiKux*Kux, 0)
|
||||
if full_cov:
|
||||
Kxx = self.kern.K(Xnew)
|
||||
var = Kxx - np.dot(Kux.T, WiKux)
|
||||
else:
|
||||
Kxx = self.kern.Kdiag(Xnew)
|
||||
var = Kxx - np.sum(WiKux*Kux, 0)
|
||||
var = var.reshape(-1,1)
|
||||
|
||||
return mu, var.reshape(-1,1)
|
||||
return mu, var
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue