mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-08 15:05:15 +02:00
Ensure postive params, added plot function.
This commit is contained in:
parent
1b0cb8fd77
commit
f4425efdae
1 changed files with 54 additions and 9 deletions
|
|
@ -18,6 +18,8 @@ import numpy as np
|
||||||
from scipy import linalg
|
from scipy import linalg
|
||||||
from ..core import Model
|
from ..core import Model
|
||||||
from .. import kern
|
from .. import kern
|
||||||
|
from GPy.util.plot import gpplot, Tango, x_frame1D
|
||||||
|
import pylab as pb
|
||||||
|
|
||||||
class StateSpace(Model):
|
class StateSpace(Model):
|
||||||
def __init__(self, X, Y, kernel=None):
|
def __init__(self, X, Y, kernel=None):
|
||||||
|
|
@ -42,6 +44,9 @@ class StateSpace(Model):
|
||||||
else:
|
else:
|
||||||
self.kern = kernel
|
self.kern = kernel
|
||||||
|
|
||||||
|
# Make sure all parameters are positive
|
||||||
|
self.ensure_default_constraints()
|
||||||
|
|
||||||
# Assert that the kernel is supported
|
# Assert that the kernel is supported
|
||||||
#assert self.kern.sde(), "This kernel is not supported for state space estimation"
|
#assert self.kern.sde(), "This kernel is not supported for state space estimation"
|
||||||
|
|
||||||
|
|
@ -79,7 +84,7 @@ class StateSpace(Model):
|
||||||
Y = np.vstack((self.Y, np.nan*np.zeros(Xnew.shape)))
|
Y = np.vstack((self.Y, np.nan*np.zeros(Xnew.shape)))
|
||||||
|
|
||||||
# Sort the matrix (save the order)
|
# Sort the matrix (save the order)
|
||||||
(Z, return_index, return_inverse) = np.unique(X,True,True)
|
_, return_index, return_inverse = np.unique(X,True,True)
|
||||||
X = X[return_index]
|
X = X[return_index]
|
||||||
Y = Y[return_index]
|
Y = Y[return_index]
|
||||||
|
|
||||||
|
|
@ -103,12 +108,13 @@ class StateSpace(Model):
|
||||||
P = P[:,:,self.num_data:]
|
P = P[:,:,self.num_data:]
|
||||||
|
|
||||||
# Calculate the mean and variance
|
# Calculate the mean and variance
|
||||||
m = H.dot(M)
|
m = H.dot(M).T
|
||||||
V = np.tensordot(H[0],P,(0,0))
|
V = np.tensordot(H[0],P,(0,0))
|
||||||
V = np.tensordot(V,H[0],(0,0))
|
V = np.tensordot(V,H[0],(0,0))
|
||||||
|
V = V[:,None]
|
||||||
|
|
||||||
# Return the posterior of the state
|
# Return the posterior of the state
|
||||||
return (m.T, V.T)
|
return (m, V)
|
||||||
|
|
||||||
def predict(self, Xnew):
|
def predict(self, Xnew):
|
||||||
|
|
||||||
|
|
@ -118,12 +124,51 @@ class StateSpace(Model):
|
||||||
# Add the noise variance to the state variance
|
# Add the noise variance to the state variance
|
||||||
V += self.sigma2
|
V += self.sigma2
|
||||||
|
|
||||||
# Return mean and variance
|
# Lower and upper bounds
|
||||||
return (m, V)
|
lower = m - 2*np.sqrt(V)
|
||||||
|
upper = m + 2*np.sqrt(V)
|
||||||
|
|
||||||
def plot(self):
|
# Return mean and variance
|
||||||
# TODO
|
return (m, V, lower, upper)
|
||||||
return 0
|
|
||||||
|
def plot(self, plot_limits=None, levels=20, samples=0, fignum=None,
|
||||||
|
ax=None, resolution=None, plot_raw=False,
|
||||||
|
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
|
||||||
|
|
||||||
|
# Deal with optional parameters
|
||||||
|
if ax is None:
|
||||||
|
fig = pb.figure(num=fignum)
|
||||||
|
ax = fig.add_subplot(111)
|
||||||
|
|
||||||
|
# Define the frame on which to plot
|
||||||
|
resolution = resolution or 200
|
||||||
|
Xgrid, xmin, xmax = x_frame1D(self.X, plot_limits=plot_limits)
|
||||||
|
|
||||||
|
# Make a prediction on the frame and plot it
|
||||||
|
if plot_raw:
|
||||||
|
m, v = self.predict_raw(Xgrid)
|
||||||
|
lower = m - 2*np.sqrt(v)
|
||||||
|
upper = m + 2*np.sqrt(v)
|
||||||
|
Y = self.Y
|
||||||
|
else:
|
||||||
|
m, v, lower, upper = self.predict(Xgrid)
|
||||||
|
Y = self.Y
|
||||||
|
|
||||||
|
# Plot the values
|
||||||
|
gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
|
||||||
|
ax.plot(self.X, self.Y, 'kx', mew=1.5)
|
||||||
|
|
||||||
|
# Optionally plot some samples
|
||||||
|
if samples:
|
||||||
|
Ysim = self.posterior_samples(Xgrid, samples)
|
||||||
|
for yi in Ysim.T:
|
||||||
|
ax.plot(Xgrid, yi, Tango.colorsHex['darkBlue'], linewidth=0.25)
|
||||||
|
|
||||||
|
# Set the limits of the plot to some sensible values
|
||||||
|
ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten()))
|
||||||
|
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
|
||||||
|
ax.set_xlim(xmin, xmax)
|
||||||
|
ax.set_ylim(ymin, ymax)
|
||||||
|
|
||||||
def posterior_samples_f(self,X,size=10):
|
def posterior_samples_f(self,X,size=10):
|
||||||
|
|
||||||
|
|
@ -149,7 +194,7 @@ class StateSpace(Model):
|
||||||
|
|
||||||
def posterior_samples(self, X, size=10):
|
def posterior_samples(self, X, size=10):
|
||||||
# TODO
|
# TODO
|
||||||
return 0
|
return self.posterior_samples_f(X,size)
|
||||||
|
|
||||||
def kalman_filter(self,F,L,Qc,H,R,Pinf,X,Y):
|
def kalman_filter(self,F,L,Qc,H,R,Pinf,X,Y):
|
||||||
# KALMAN_FILTER - Run the Kalman filter for a given model and data
|
# KALMAN_FILTER - Run the Kalman filter for a given model and data
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue