mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
Now you can plot the filter estimate.
This commit is contained in:
parent
27a79945a6
commit
8ed9bbcb4d
1 changed files with 8 additions and 7 deletions
|
|
@ -77,7 +77,7 @@ class StateSpace(Model):
|
|||
#return self.kf_likelihood_g(F,L,Qc,self.sigma2,H,Pinf,dF,dQc,dPinf,self.X,self.Y)
|
||||
return False
|
||||
|
||||
def predict_raw(self, Xnew):
|
||||
def predict_raw(self, Xnew, filteronly=False):
|
||||
|
||||
# Make a single matrix containing training and testing points
|
||||
X = np.vstack((self.X, Xnew))
|
||||
|
|
@ -95,7 +95,8 @@ class StateSpace(Model):
|
|||
(M, P) = self.kalman_filter(F,L,Qc,H,self.sigma2,Pinf,X.T,Y.T)
|
||||
|
||||
# Run the Rauch-Tung-Striebel smoother
|
||||
(M, P) = self.rts_smoother(F,L,Qc,X.T,M,P)
|
||||
if not filter:
|
||||
(M, P) = self.rts_smoother(F,L,Qc,X.T,M,P)
|
||||
|
||||
# Put the data back in the original order
|
||||
M = M[:,return_inverse]
|
||||
|
|
@ -114,10 +115,10 @@ class StateSpace(Model):
|
|||
# Return the posterior of the state
|
||||
return (m, V)
|
||||
|
||||
def predict(self, Xnew):
|
||||
def predict(self, Xnew, filteronly=False):
|
||||
|
||||
# Run the Kalman filter to get the state
|
||||
(m, V) = self.predict_raw(Xnew)
|
||||
(m, V) = self.predict_raw(Xnew,filteronly=filteronly)
|
||||
|
||||
# Add the noise variance to the state variance
|
||||
V += self.sigma2
|
||||
|
|
@ -130,7 +131,7 @@ class StateSpace(Model):
|
|||
return (m, V, lower, upper)
|
||||
|
||||
def plot(self, plot_limits=None, levels=20, samples=0, fignum=None,
|
||||
ax=None, resolution=None, plot_raw=False,
|
||||
ax=None, resolution=None, plot_raw=False, plot_filter=False,
|
||||
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
|
||||
|
||||
# Deal with optional parameters
|
||||
|
|
@ -144,12 +145,12 @@ class StateSpace(Model):
|
|||
|
||||
# Make a prediction on the frame and plot it
|
||||
if plot_raw:
|
||||
m, v = self.predict_raw(Xgrid)
|
||||
m, v = self.predict_raw(Xgrid,filteronly=plot_filter)
|
||||
lower = m - 2*np.sqrt(v)
|
||||
upper = m + 2*np.sqrt(v)
|
||||
Y = self.Y
|
||||
else:
|
||||
m, v, lower, upper = self.predict(Xgrid)
|
||||
m, v, lower, upper = self.predict(Xgrid,filteronly=plot_filter)
|
||||
Y = self.Y
|
||||
|
||||
# Plot the values
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue