mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 20:12:38 +02:00
Added full cov prediction
This commit is contained in:
parent
ce3681a566
commit
070137504a
1 changed files with 9 additions and 4 deletions
|
|
@ -78,7 +78,7 @@ class GPVariationalGaussianApproximation(Model):
|
||||||
dF_dK = self.alpha*dF_dm.T + np.dot(tmp*dF_dv, tmp.T)
|
dF_dK = self.alpha*dF_dm.T + np.dot(tmp*dF_dv, tmp.T)
|
||||||
self.kern.update_gradients_full(dF_dK - dKL_dK, self.X)
|
self.kern.update_gradients_full(dF_dK - dKL_dK, self.X)
|
||||||
|
|
||||||
def _raw_predict(self, Xnew):
|
def _raw_predict(self, Xnew, full_cov=False):
|
||||||
"""
|
"""
|
||||||
Predict the function(s) at the new point(s) Xnew.
|
Predict the function(s) at the new point(s) Xnew.
|
||||||
|
|
||||||
|
|
@ -89,7 +89,12 @@ class GPVariationalGaussianApproximation(Model):
|
||||||
Kux = self.kern.K(self.X, Xnew)
|
Kux = self.kern.K(self.X, Xnew)
|
||||||
mu = np.dot(Kux.T, self.alpha)
|
mu = np.dot(Kux.T, self.alpha)
|
||||||
WiKux = np.dot(Wi, Kux)
|
WiKux = np.dot(Wi, Kux)
|
||||||
Kxx = self.kern.Kdiag(Xnew)
|
if full_cov:
|
||||||
var = Kxx - np.sum(WiKux*Kux, 0)
|
Kxx = self.kern.K(Xnew)
|
||||||
|
var = Kxx - np.dot(Kux.T, WiKux)
|
||||||
|
else:
|
||||||
|
Kxx = self.kern.Kdiag(Xnew)
|
||||||
|
var = Kxx - np.sum(WiKux*Kux, 0)
|
||||||
|
var = var.reshape(-1,1)
|
||||||
|
|
||||||
return mu, var.reshape(-1,1)
|
return mu, var
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue