mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
Now you can plot the filter estimate.
This commit is contained in:
parent
27a79945a6
commit
8ed9bbcb4d
1 changed files with 8 additions and 7 deletions
|
|
@ -77,7 +77,7 @@ class StateSpace(Model):
|
||||||
#return self.kf_likelihood_g(F,L,Qc,self.sigma2,H,Pinf,dF,dQc,dPinf,self.X,self.Y)
|
#return self.kf_likelihood_g(F,L,Qc,self.sigma2,H,Pinf,dF,dQc,dPinf,self.X,self.Y)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def predict_raw(self, Xnew):
|
def predict_raw(self, Xnew, filteronly=False):
|
||||||
|
|
||||||
# Make a single matrix containing training and testing points
|
# Make a single matrix containing training and testing points
|
||||||
X = np.vstack((self.X, Xnew))
|
X = np.vstack((self.X, Xnew))
|
||||||
|
|
@ -95,6 +95,7 @@ class StateSpace(Model):
|
||||||
(M, P) = self.kalman_filter(F,L,Qc,H,self.sigma2,Pinf,X.T,Y.T)
|
(M, P) = self.kalman_filter(F,L,Qc,H,self.sigma2,Pinf,X.T,Y.T)
|
||||||
|
|
||||||
# Run the Rauch-Tung-Striebel smoother
|
# Run the Rauch-Tung-Striebel smoother
|
||||||
|
if not filter:
|
||||||
(M, P) = self.rts_smoother(F,L,Qc,X.T,M,P)
|
(M, P) = self.rts_smoother(F,L,Qc,X.T,M,P)
|
||||||
|
|
||||||
# Put the data back in the original order
|
# Put the data back in the original order
|
||||||
|
|
@ -114,10 +115,10 @@ class StateSpace(Model):
|
||||||
# Return the posterior of the state
|
# Return the posterior of the state
|
||||||
return (m, V)
|
return (m, V)
|
||||||
|
|
||||||
def predict(self, Xnew):
|
def predict(self, Xnew, filteronly=False):
|
||||||
|
|
||||||
# Run the Kalman filter to get the state
|
# Run the Kalman filter to get the state
|
||||||
(m, V) = self.predict_raw(Xnew)
|
(m, V) = self.predict_raw(Xnew,filteronly=filteronly)
|
||||||
|
|
||||||
# Add the noise variance to the state variance
|
# Add the noise variance to the state variance
|
||||||
V += self.sigma2
|
V += self.sigma2
|
||||||
|
|
@ -130,7 +131,7 @@ class StateSpace(Model):
|
||||||
return (m, V, lower, upper)
|
return (m, V, lower, upper)
|
||||||
|
|
||||||
def plot(self, plot_limits=None, levels=20, samples=0, fignum=None,
|
def plot(self, plot_limits=None, levels=20, samples=0, fignum=None,
|
||||||
ax=None, resolution=None, plot_raw=False,
|
ax=None, resolution=None, plot_raw=False, plot_filter=False,
|
||||||
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
|
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
|
||||||
|
|
||||||
# Deal with optional parameters
|
# Deal with optional parameters
|
||||||
|
|
@ -144,12 +145,12 @@ class StateSpace(Model):
|
||||||
|
|
||||||
# Make a prediction on the frame and plot it
|
# Make a prediction on the frame and plot it
|
||||||
if plot_raw:
|
if plot_raw:
|
||||||
m, v = self.predict_raw(Xgrid)
|
m, v = self.predict_raw(Xgrid,filteronly=plot_filter)
|
||||||
lower = m - 2*np.sqrt(v)
|
lower = m - 2*np.sqrt(v)
|
||||||
upper = m + 2*np.sqrt(v)
|
upper = m + 2*np.sqrt(v)
|
||||||
Y = self.Y
|
Y = self.Y
|
||||||
else:
|
else:
|
||||||
m, v, lower, upper = self.predict(Xgrid)
|
m, v, lower, upper = self.predict(Xgrid,filteronly=plot_filter)
|
||||||
Y = self.Y
|
Y = self.Y
|
||||||
|
|
||||||
# Plot the values
|
# Plot the values
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue