mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
[statespace] make predict comply to gpy standards (no confidence interval)
This commit is contained in:
parent
5f3956478f
commit
d6ccccc7e4
1 changed files with 76 additions and 84 deletions
|
|
@ -15,19 +15,10 @@
|
|||
#
|
||||
|
||||
import numpy as np
|
||||
from scipy import linalg
|
||||
from scipy import stats
|
||||
from ..core import Model
|
||||
from .. import kern
|
||||
#from GPy.plotting.matplot_dep.models_plots import gpplot
|
||||
#from GPy.plotting.matplot_dep.base_plots import x_frame1D
|
||||
#from GPy.plotting.matplot_dep import Tango
|
||||
#import pylab as pb
|
||||
from GPy.core.parameterization.param import Param
|
||||
|
||||
import GPy
|
||||
from .. import likelihoods
|
||||
|
||||
#from . import state_space_setup as ss_setup
|
||||
from ..core import Model
|
||||
from . import state_space_main as ssm
|
||||
from . import state_space_setup as ss_setup
|
||||
|
||||
|
|
@ -37,12 +28,12 @@ class StateSpace(Model):
|
|||
|
||||
if len(X.shape) == 1:
|
||||
X = np.atleast_2d(X).T
|
||||
self.num_data, input_dim = X.shape
|
||||
self.num_data, self.input_dim = X.shape
|
||||
|
||||
if len(Y.shape) == 1:
|
||||
Y = np.atleast_2d(Y).T
|
||||
|
||||
assert input_dim==1, "State space methods are only for 1D data"
|
||||
assert self.input_dim==1, "State space methods are only for 1D data"
|
||||
|
||||
if len(Y.shape)==2:
|
||||
num_data_Y, self.output_dim = Y.shape
|
||||
|
|
@ -168,7 +159,7 @@ class StateSpace(Model):
|
|||
def log_likelihood(self):
|
||||
return self._log_marginal_likelihood
|
||||
|
||||
def _raw_predict(self, Xnew=None, Ynew=None, filteronly=False):
|
||||
def _raw_predict(self, Xnew=None, Ynew=None, filteronly=False, **kw):
|
||||
"""
|
||||
Performs the actual prediction for new X points.
|
||||
Inner function. It is called only from inside this class.
|
||||
|
|
@ -270,22 +261,23 @@ class StateSpace(Model):
|
|||
# Return the posterior of the state
|
||||
return (m, V)
|
||||
|
||||
def predict(self, Xnew=None, filteronly=False):
|
||||
def predict(self, Xnew=None, filteronly=False, include_likelihood=True, **kw):
|
||||
|
||||
# Run the Kalman filter to get the state
|
||||
(m, V) = self._raw_predict(Xnew,filteronly=filteronly)
|
||||
|
||||
# Add the noise variance to the state variance
|
||||
V += float(self.Gaussian_noise.variance)
|
||||
if include_likelihood:
|
||||
V += float(self.likelihood.variance)
|
||||
|
||||
# Lower and upper bounds
|
||||
lower = m - 2*np.sqrt(V)
|
||||
upper = m + 2*np.sqrt(V)
|
||||
#lower = m - 2*np.sqrt(V)
|
||||
#upper = m + 2*np.sqrt(V)
|
||||
|
||||
# Return mean and variance
|
||||
return (m, V, lower, upper)
|
||||
return m, V
|
||||
|
||||
def predict_quantiles(self, Xnew=None, quantiles=(2.5, 97.5)):
|
||||
def predict_quantiles(self, Xnew=None, quantiles=(2.5, 97.5), **kw):
|
||||
mu, var = self._raw_predict(Xnew)
|
||||
#import pdb; pdb.set_trace()
|
||||
return [stats.norm.ppf(q/100.)*np.sqrt(var + float(self.Gaussian_noise.variance)) + mu for q in quantiles]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue