diff --git a/GPy/models/state_space.py b/GPy/models/state_space.py index 601be3b1..27ea7ca3 100644 --- a/GPy/models/state_space.py +++ b/GPy/models/state_space.py @@ -77,7 +77,7 @@ class StateSpace(Model): #return self.kf_likelihood_g(F,L,Qc,self.sigma2,H,Pinf,dF,dQc,dPinf,self.X,self.Y) return False - def predict_raw(self, Xnew): + def predict_raw(self, Xnew, filteronly=False): # Make a single matrix containing training and testing points X = np.vstack((self.X, Xnew)) @@ -95,7 +95,8 @@ class StateSpace(Model): (M, P) = self.kalman_filter(F,L,Qc,H,self.sigma2,Pinf,X.T,Y.T) # Run the Rauch-Tung-Striebel smoother - (M, P) = self.rts_smoother(F,L,Qc,X.T,M,P) + if not filter: + (M, P) = self.rts_smoother(F,L,Qc,X.T,M,P) # Put the data back in the original order M = M[:,return_inverse] @@ -114,10 +115,10 @@ class StateSpace(Model): # Return the posterior of the state return (m, V) - def predict(self, Xnew): + def predict(self, Xnew, filteronly=False): # Run the Kalman filter to get the state - (m, V) = self.predict_raw(Xnew) + (m, V) = self.predict_raw(Xnew,filteronly=filteronly) # Add the noise variance to the state variance V += self.sigma2 @@ -130,7 +131,7 @@ class StateSpace(Model): return (m, V, lower, upper) def plot(self, plot_limits=None, levels=20, samples=0, fignum=None, - ax=None, resolution=None, plot_raw=False, + ax=None, resolution=None, plot_raw=False, plot_filter=False, linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']): # Deal with optional parameters @@ -144,12 +145,12 @@ class StateSpace(Model): # Make a prediction on the frame and plot it if plot_raw: - m, v = self.predict_raw(Xgrid) + m, v = self.predict_raw(Xgrid,filteronly=plot_filter) lower = m - 2*np.sqrt(v) upper = m + 2*np.sqrt(v) Y = self.Y else: - m, v, lower, upper = self.predict(Xgrid) + m, v, lower, upper = self.predict(Xgrid,filteronly=plot_filter) Y = self.Y # Plot the values