[statespace] make predict comply to gpy standards (no confidence interval)

This commit is contained in:
Max Zwiessele 2016-04-04 15:37:51 +01:00
parent 5f3956478f
commit d6ccccc7e4

View file

@ -15,19 +15,10 @@
# #
import numpy as np import numpy as np
from scipy import linalg
from scipy import stats from scipy import stats
from ..core import Model
from .. import kern
#from GPy.plotting.matplot_dep.models_plots import gpplot
#from GPy.plotting.matplot_dep.base_plots import x_frame1D
#from GPy.plotting.matplot_dep import Tango
#import pylab as pb
from GPy.core.parameterization.param import Param
import GPy
from .. import likelihoods from .. import likelihoods
#from . import state_space_setup as ss_setup
from ..core import Model
from . import state_space_main as ssm from . import state_space_main as ssm
from . import state_space_setup as ss_setup from . import state_space_setup as ss_setup
@ -37,12 +28,12 @@ class StateSpace(Model):
if len(X.shape) == 1: if len(X.shape) == 1:
X = np.atleast_2d(X).T X = np.atleast_2d(X).T
self.num_data, input_dim = X.shape self.num_data, self.input_dim = X.shape
if len(Y.shape) == 1: if len(Y.shape) == 1:
Y = np.atleast_2d(Y).T Y = np.atleast_2d(Y).T
assert input_dim==1, "State space methods are only for 1D data" assert self.input_dim==1, "State space methods are only for 1D data"
if len(Y.shape)==2: if len(Y.shape)==2:
num_data_Y, self.output_dim = Y.shape num_data_Y, self.output_dim = Y.shape
@ -168,7 +159,7 @@ class StateSpace(Model):
def log_likelihood(self): def log_likelihood(self):
return self._log_marginal_likelihood return self._log_marginal_likelihood
def _raw_predict(self, Xnew=None, Ynew=None, filteronly=False): def _raw_predict(self, Xnew=None, Ynew=None, filteronly=False, **kw):
""" """
Performs the actual prediction for new X points. Performs the actual prediction for new X points.
Inner function. It is called only from inside this class. Inner function. It is called only from inside this class.
@ -270,22 +261,23 @@ class StateSpace(Model):
# Return the posterior of the state # Return the posterior of the state
return (m, V) return (m, V)
def predict(self, Xnew=None, filteronly=False): def predict(self, Xnew=None, filteronly=False, include_likelihood=True, **kw):
# Run the Kalman filter to get the state # Run the Kalman filter to get the state
(m, V) = self._raw_predict(Xnew,filteronly=filteronly) (m, V) = self._raw_predict(Xnew,filteronly=filteronly)
# Add the noise variance to the state variance # Add the noise variance to the state variance
V += float(self.Gaussian_noise.variance) if include_likelihood:
V += float(self.likelihood.variance)
# Lower and upper bounds # Lower and upper bounds
lower = m - 2*np.sqrt(V) #lower = m - 2*np.sqrt(V)
upper = m + 2*np.sqrt(V) #upper = m + 2*np.sqrt(V)
# Return mean and variance # Return mean and variance
return (m, V, lower, upper) return m, V
def predict_quantiles(self, Xnew=None, quantiles=(2.5, 97.5)): def predict_quantiles(self, Xnew=None, quantiles=(2.5, 97.5), **kw):
mu, var = self._raw_predict(Xnew) mu, var = self._raw_predict(Xnew)
#import pdb; pdb.set_trace() #import pdb; pdb.set_trace()
return [stats.norm.ppf(q/100.)*np.sqrt(var + float(self.Gaussian_noise.variance)) + mu for q in quantiles] return [stats.norm.ppf(q/100.)*np.sqrt(var + float(self.Gaussian_noise.variance)) + mu for q in quantiles]