Now you can plot the filter estimate.

This commit is contained in:
Arno Solin 2013-11-22 17:01:13 +00:00
parent 27a79945a6
commit 8ed9bbcb4d

View file

@ -77,7 +77,7 @@ class StateSpace(Model):
#return self.kf_likelihood_g(F,L,Qc,self.sigma2,H,Pinf,dF,dQc,dPinf,self.X,self.Y) #return self.kf_likelihood_g(F,L,Qc,self.sigma2,H,Pinf,dF,dQc,dPinf,self.X,self.Y)
return False return False
def predict_raw(self, Xnew): def predict_raw(self, Xnew, filteronly=False):
# Make a single matrix containing training and testing points # Make a single matrix containing training and testing points
X = np.vstack((self.X, Xnew)) X = np.vstack((self.X, Xnew))
@ -95,7 +95,8 @@ class StateSpace(Model):
(M, P) = self.kalman_filter(F,L,Qc,H,self.sigma2,Pinf,X.T,Y.T) (M, P) = self.kalman_filter(F,L,Qc,H,self.sigma2,Pinf,X.T,Y.T)
# Run the Rauch-Tung-Striebel smoother # Run the Rauch-Tung-Striebel smoother
(M, P) = self.rts_smoother(F,L,Qc,X.T,M,P) if not filter:
(M, P) = self.rts_smoother(F,L,Qc,X.T,M,P)
# Put the data back in the original order # Put the data back in the original order
M = M[:,return_inverse] M = M[:,return_inverse]
@ -114,10 +115,10 @@ class StateSpace(Model):
# Return the posterior of the state # Return the posterior of the state
return (m, V) return (m, V)
def predict(self, Xnew): def predict(self, Xnew, filteronly=False):
# Run the Kalman filter to get the state # Run the Kalman filter to get the state
(m, V) = self.predict_raw(Xnew) (m, V) = self.predict_raw(Xnew,filteronly=filteronly)
# Add the noise variance to the state variance # Add the noise variance to the state variance
V += self.sigma2 V += self.sigma2
@ -130,7 +131,7 @@ class StateSpace(Model):
return (m, V, lower, upper) return (m, V, lower, upper)
def plot(self, plot_limits=None, levels=20, samples=0, fignum=None, def plot(self, plot_limits=None, levels=20, samples=0, fignum=None,
ax=None, resolution=None, plot_raw=False, ax=None, resolution=None, plot_raw=False, plot_filter=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']): linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
# Deal with optional parameters # Deal with optional parameters
@ -144,12 +145,12 @@ class StateSpace(Model):
# Make a prediction on the frame and plot it # Make a prediction on the frame and plot it
if plot_raw: if plot_raw:
m, v = self.predict_raw(Xgrid) m, v = self.predict_raw(Xgrid,filteronly=plot_filter)
lower = m - 2*np.sqrt(v) lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v) upper = m + 2*np.sqrt(v)
Y = self.Y Y = self.Y
else: else:
m, v, lower, upper = self.predict(Xgrid) m, v, lower, upper = self.predict(Xgrid,filteronly=plot_filter)
Y = self.Y Y = self.Y
# Plot the values # Plot the values