Ensure postive params, added plot function.

This commit is contained in:
Arno Solin 2013-11-11 17:12:40 +00:00
parent 1b0cb8fd77
commit f4425efdae

View file

@ -18,6 +18,8 @@ import numpy as np
from scipy import linalg
from ..core import Model
from .. import kern
from GPy.util.plot import gpplot, Tango, x_frame1D
import pylab as pb
class StateSpace(Model):
def __init__(self, X, Y, kernel=None):
@ -42,6 +44,9 @@ class StateSpace(Model):
else:
self.kern = kernel
# Make sure all parameters are positive
self.ensure_default_constraints()
# Assert that the kernel is supported
#assert self.kern.sde(), "This kernel is not supported for state space estimation"
@ -79,7 +84,7 @@ class StateSpace(Model):
Y = np.vstack((self.Y, np.nan*np.zeros(Xnew.shape)))
# Sort the matrix (save the order)
(Z, return_index, return_inverse) = np.unique(X,True,True)
_, return_index, return_inverse = np.unique(X,True,True)
X = X[return_index]
Y = Y[return_index]
@ -103,12 +108,13 @@ class StateSpace(Model):
P = P[:,:,self.num_data:]
# Calculate the mean and variance
m = H.dot(M)
m = H.dot(M).T
V = np.tensordot(H[0],P,(0,0))
V = np.tensordot(V,H[0],(0,0))
V = V[:,None]
# Return the posterior of the state
return (m.T, V.T)
return (m, V)
def predict(self, Xnew):
@ -118,12 +124,51 @@ class StateSpace(Model):
# Add the noise variance to the state variance
V += self.sigma2
# Return mean and variance
return (m, V)
# Lower and upper bounds
lower = m - 2*np.sqrt(V)
upper = m + 2*np.sqrt(V)
def plot(self):
# TODO
return 0
# Return mean and variance
return (m, V, lower, upper)
def plot(self, plot_limits=None, levels=20, samples=0, fignum=None,
ax=None, resolution=None, plot_raw=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
# Deal with optional parameters
if ax is None:
fig = pb.figure(num=fignum)
ax = fig.add_subplot(111)
# Define the frame on which to plot
resolution = resolution or 200
Xgrid, xmin, xmax = x_frame1D(self.X, plot_limits=plot_limits)
# Make a prediction on the frame and plot it
if plot_raw:
m, v = self.predict_raw(Xgrid)
lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v)
Y = self.Y
else:
m, v, lower, upper = self.predict(Xgrid)
Y = self.Y
# Plot the values
gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
ax.plot(self.X, self.Y, 'kx', mew=1.5)
# Optionally plot some samples
if samples:
Ysim = self.posterior_samples(Xgrid, samples)
for yi in Ysim.T:
ax.plot(Xgrid, yi, Tango.colorsHex['darkBlue'], linewidth=0.25)
# Set the limits of the plot to some sensible values
ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten()))
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
def posterior_samples_f(self,X,size=10):
@ -149,7 +194,7 @@ class StateSpace(Model):
def posterior_samples(self, X, size=10):
# TODO
return 0
return self.posterior_samples_f(X,size)
def kalman_filter(self,F,L,Qc,H,R,Pinf,X,Y):
# KALMAN_FILTER - Run the Kalman filter for a given model and data