This commit is contained in:
mu 2014-02-24 11:26:23 +00:00
parent 55a9c5a423
commit 8b9170ad96

View file

@ -31,7 +31,7 @@ from GPy.util.plot import gpplot, Tango, x_frame1D
import pylab as pb
class StateSpace_1(Model):
def __init__(self, SXP, SI, X, Y, kernel=None):
def __init__(self, SXP, SI, X, Y, tempokernel=None,spacekernel=None):
super(StateSpace_1, self).__init__()
self.num_data, input_dim = X.shape
assert input_dim==1, "State space methods for time and space 2"
@ -55,13 +55,18 @@ class StateSpace_1(Model):
self.sigma2 = 1.
# Default kernel
if kernel is None:
if tempokernel is None:
self.kern = kern.Matern32(1,lengthscale=1)
else:
self.kern = tempokernel
if spacekernel is None:
#self.kern = kern.Matern32(1,lengthscale=1)
#self.spacekern = kern.rbf(1,lengthscale=0.1)
self.spacekern = kern.exponential(1,lengthscale=1)
#self.spacekern = kern.Matern52(1,lengthscale=1)
else:
self.kern = kernel
self.spacekern = spacekernel
# Make sure all parameters are positive
self.ensure_default_constraints()
@ -225,9 +230,9 @@ class StateSpace_1(Model):
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
# Deal with optional parameters
if ax is None:
fig = pb.figure(num=fignum)
ax = fig.add_subplot(111)
#if ax is None:
#fig = pb.figure(num=fignum)
#ax = fig.add_subplot(111)
# Define the frame on which to plot
resolution = resolution or 200
@ -285,9 +290,9 @@ class StateSpace_1(Model):
Y = self.Y
# Plot the values
gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
#gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
#gpplot(self.X, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
ax.plot(self.X, self.Y, 'kx', mew=1.5)
#ax.plot(self.X, self.Y, 'kx', mew=1.5)
# Optionally plot some samples
if samples:
@ -296,10 +301,10 @@ class StateSpace_1(Model):
ax.plot(Xgrid, yi, Tango.colorsHex['darkBlue'], linewidth=0.25)
# Set the limits of the plot to some sensible values
ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten()))
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
#ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten()))
#ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
#ax.set_xlim(xmin, xmax)
#ax.set_ylim(ymin, ymax)
def posterior_samples_f(self,X,size=10):