Ensure postive params, added plot function.

This commit is contained in:
Arno Solin 2013-11-11 17:12:40 +00:00
parent 1b0cb8fd77
commit f4425efdae

View file

@ -18,6 +18,8 @@ import numpy as np
from scipy import linalg from scipy import linalg
from ..core import Model from ..core import Model
from .. import kern from .. import kern
from GPy.util.plot import gpplot, Tango, x_frame1D
import pylab as pb
class StateSpace(Model): class StateSpace(Model):
def __init__(self, X, Y, kernel=None): def __init__(self, X, Y, kernel=None):
@ -42,6 +44,9 @@ class StateSpace(Model):
else: else:
self.kern = kernel self.kern = kernel
# Make sure all parameters are positive
self.ensure_default_constraints()
# Assert that the kernel is supported # Assert that the kernel is supported
#assert self.kern.sde(), "This kernel is not supported for state space estimation" #assert self.kern.sde(), "This kernel is not supported for state space estimation"
@ -79,7 +84,7 @@ class StateSpace(Model):
Y = np.vstack((self.Y, np.nan*np.zeros(Xnew.shape))) Y = np.vstack((self.Y, np.nan*np.zeros(Xnew.shape)))
# Sort the matrix (save the order) # Sort the matrix (save the order)
(Z, return_index, return_inverse) = np.unique(X,True,True) _, return_index, return_inverse = np.unique(X,True,True)
X = X[return_index] X = X[return_index]
Y = Y[return_index] Y = Y[return_index]
@ -103,12 +108,13 @@ class StateSpace(Model):
P = P[:,:,self.num_data:] P = P[:,:,self.num_data:]
# Calculate the mean and variance # Calculate the mean and variance
m = H.dot(M) m = H.dot(M).T
V = np.tensordot(H[0],P,(0,0)) V = np.tensordot(H[0],P,(0,0))
V = np.tensordot(V,H[0],(0,0)) V = np.tensordot(V,H[0],(0,0))
V = V[:,None]
# Return the posterior of the state # Return the posterior of the state
return (m.T, V.T) return (m, V)
def predict(self, Xnew): def predict(self, Xnew):
@ -118,12 +124,51 @@ class StateSpace(Model):
# Add the noise variance to the state variance # Add the noise variance to the state variance
V += self.sigma2 V += self.sigma2
# Return mean and variance # Lower and upper bounds
return (m, V) lower = m - 2*np.sqrt(V)
upper = m + 2*np.sqrt(V)
def plot(self): # Return mean and variance
# TODO return (m, V, lower, upper)
return 0
def plot(self, plot_limits=None, levels=20, samples=0, fignum=None,
ax=None, resolution=None, plot_raw=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
# Deal with optional parameters
if ax is None:
fig = pb.figure(num=fignum)
ax = fig.add_subplot(111)
# Define the frame on which to plot
resolution = resolution or 200
Xgrid, xmin, xmax = x_frame1D(self.X, plot_limits=plot_limits)
# Make a prediction on the frame and plot it
if plot_raw:
m, v = self.predict_raw(Xgrid)
lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v)
Y = self.Y
else:
m, v, lower, upper = self.predict(Xgrid)
Y = self.Y
# Plot the values
gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
ax.plot(self.X, self.Y, 'kx', mew=1.5)
# Optionally plot some samples
if samples:
Ysim = self.posterior_samples(Xgrid, samples)
for yi in Ysim.T:
ax.plot(Xgrid, yi, Tango.colorsHex['darkBlue'], linewidth=0.25)
# Set the limits of the plot to some sensible values
ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten()))
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
def posterior_samples_f(self,X,size=10): def posterior_samples_f(self,X,size=10):
@ -149,7 +194,7 @@ class StateSpace(Model):
def posterior_samples(self, X, size=10): def posterior_samples(self, X, size=10):
# TODO # TODO
return 0 return self.posterior_samples_f(X,size)
def kalman_filter(self,F,L,Qc,H,R,Pinf,X,Y): def kalman_filter(self,F,L,Qc,H,R,Pinf,X,Y):
# KALMAN_FILTER - Run the Kalman filter for a given model and data # KALMAN_FILTER - Run the Kalman filter for a given model and data