diff --git a/GPy/models/state_space_xt_sep.py b/GPy/models/state_space_xt_sep.py index b25e3235..5127cb98 100644 --- a/GPy/models/state_space_xt_sep.py +++ b/GPy/models/state_space_xt_sep.py @@ -31,7 +31,7 @@ from GPy.util.plot import gpplot, Tango, x_frame1D import pylab as pb class StateSpace_1(Model): - def __init__(self, SXP, SI, X, Y, kernel=None): + def __init__(self, SXP, SI, X, Y, tempokernel=None,spacekernel=None): super(StateSpace_1, self).__init__() self.num_data, input_dim = X.shape assert input_dim==1, "State space methods for time and space 2" @@ -55,13 +55,18 @@ class StateSpace_1(Model): self.sigma2 = 1. # Default kernel - if kernel is None: + if tempokernel is None: self.kern = kern.Matern32(1,lengthscale=1) + else: + self.kern = tempokernel + + if spacekernel is None: + #self.kern = kern.Matern32(1,lengthscale=1) #self.spacekern = kern.rbf(1,lengthscale=0.1) self.spacekern = kern.exponential(1,lengthscale=1) #self.spacekern = kern.Matern52(1,lengthscale=1) else: - self.kern = kernel + self.spacekern = spacekernel # Make sure all parameters are positive self.ensure_default_constraints() @@ -225,9 +230,9 @@ class StateSpace_1(Model): linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue']): # Deal with optional parameters - if ax is None: - fig = pb.figure(num=fignum) - ax = fig.add_subplot(111) + #if ax is None: + #fig = pb.figure(num=fignum) + #ax = fig.add_subplot(111) # Define the frame on which to plot resolution = resolution or 200 @@ -285,9 +290,9 @@ class StateSpace_1(Model): Y = self.Y # Plot the values - gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol) + #gpplot(Xgrid, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol) #gpplot(self.X, m, lower, upper, axes=ax, edgecol=linecol, fillcol=fillcol) - ax.plot(self.X, self.Y, 'kx', mew=1.5) + #ax.plot(self.X, self.Y, 'kx', mew=1.5) # Optionally plot some samples if samples: @@ -296,10 +301,10 @@ class StateSpace_1(Model): ax.plot(Xgrid, yi, Tango.colorsHex['darkBlue'], linewidth=0.25) # Set the limits of the plot to some sensible values - ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten())) - ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin) - ax.set_xlim(xmin, xmax) - ax.set_ylim(ymin, ymax) + #ymin, ymax = min(np.append(Y.flatten(), lower.flatten())), max(np.append(Y.flatten(), upper.flatten())) + #ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin) + #ax.set_xlim(xmin, xmax) + #ax.set_ylim(ymin, ymax) def posterior_samples_f(self,X,size=10):