mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-05 14:55:15 +02:00
xt
This commit is contained in:
parent
55a9c5a423
commit
8b9170ad96
1 changed files with 17 additions and 12 deletions
|
|
@ -31,7 +31,7 @@ from GPy.util.plot import gpplot, Tango, x_frame1D
|
|||
import pylab as pb
|
||||
|
||||
class StateSpace_1(Model):
|
||||
def __init__(self, SXP, SI, X, Y, kernel=None):
|
||||
def __init__(self, SXP, SI, X, Y, tempokernel=None,spacekernel=None):
|
||||
super(StateSpace_1, self).__init__()
|
||||
self.num_data, input_dim = X.shape
|
||||
assert input_dim==1, "State space methods for time and space 2"
|
||||
|
|
@ -55,13 +55,18 @@ class StateSpace_1(Model):
|
|||
self.sigma2 = 1.
|
||||
|
||||
# Default kernel
|
||||
if kernel is None:
|
||||
if tempokernel is None:
|
||||
self.kern = kern.Matern32(1,lengthscale=1)
|
||||
else:
|
||||
self.kern = tempokernel
|
||||
|
||||
if spacekernel is None:
|
||||
#self.kern = kern.Matern32(1,lengthscale=1)
|
||||
#self.spacekern = kern.rbf(1,lengthscale=0.1)
|
||||
self.spacekern = kern.exponential(1,lengthscale=1)
|
||||
#self.spacekern = kern.Matern52(1,lengthscale=1)
|
||||
else:
|
||||
self.kern = kernel
|
||||
self.spacekern = spacekernel
|
||||
|
||||
# Make sure all parameters are positive
|
||||
self.ensure_default_constraints()
|
||||
|
|
@ -225,9 +230,9 @@ class StateSpace_1(Model):
|
|||
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']):
|
||||
|
||||
# Deal with optional parameters
|
||||
if ax is None:
|
||||
fig = pb.figure(num=fignum)
|
||||
ax = fig.add_subplot(111)
|
||||
#if ax is None:
|
||||
#fig = pb.figure(num=fignum)
|
||||
#ax = fig.add_subplot(111)
|
||||
|
||||
# Define the frame on which to plot
|
||||
resolution = resolution or 200
|
||||
|
|
@ -285,9 +290,9 @@ class StateSpace_1(Model):
|
|||
Y = self.Y
|
||||
|
||||
# Plot the values
|
||||
gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
|
||||
#gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
|
||||
#gpplot(self.X, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol)
|
||||
ax.plot(self.X, self.Y, 'kx', mew=1.5)
|
||||
#ax.plot(self.X, self.Y, 'kx', mew=1.5)
|
||||
|
||||
# Optionally plot some samples
|
||||
if samples:
|
||||
|
|
@ -296,10 +301,10 @@ class StateSpace_1(Model):
|
|||
ax.plot(Xgrid, yi, Tango.colorsHex['darkBlue'], linewidth=0.25)
|
||||
|
||||
# Set the limits of the plot to some sensible values
|
||||
ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten()))
|
||||
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
|
||||
ax.set_xlim(xmin, xmax)
|
||||
ax.set_ylim(ymin, ymax)
|
||||
#ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten()))
|
||||
#ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
|
||||
#ax.set_xlim(xmin, xmax)
|
||||
#ax.set_ylim(ymin, ymax)
|
||||
|
||||
def posterior_samples_f(self,X,size=10):
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue