mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
[statespace] make predict comply to gpy standards (no confidence interval)
This commit is contained in:
parent
5f3956478f
commit
d6ccccc7e4
1 changed files with 76 additions and 84 deletions
|
|
@ -15,19 +15,10 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import linalg
|
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
from ..core import Model
|
|
||||||
from .. import kern
|
|
||||||
#from GPy.plotting.matplot_dep.models_plots import gpplot
|
|
||||||
#from GPy.plotting.matplot_dep.base_plots import x_frame1D
|
|
||||||
#from GPy.plotting.matplot_dep import Tango
|
|
||||||
#import pylab as pb
|
|
||||||
from GPy.core.parameterization.param import Param
|
|
||||||
|
|
||||||
import GPy
|
|
||||||
from .. import likelihoods
|
from .. import likelihoods
|
||||||
|
#from . import state_space_setup as ss_setup
|
||||||
|
from ..core import Model
|
||||||
from . import state_space_main as ssm
|
from . import state_space_main as ssm
|
||||||
from . import state_space_setup as ss_setup
|
from . import state_space_setup as ss_setup
|
||||||
|
|
||||||
|
|
@ -37,12 +28,12 @@ class StateSpace(Model):
|
||||||
|
|
||||||
if len(X.shape) == 1:
|
if len(X.shape) == 1:
|
||||||
X = np.atleast_2d(X).T
|
X = np.atleast_2d(X).T
|
||||||
self.num_data, input_dim = X.shape
|
self.num_data, self.input_dim = X.shape
|
||||||
|
|
||||||
if len(Y.shape) == 1:
|
if len(Y.shape) == 1:
|
||||||
Y = np.atleast_2d(Y).T
|
Y = np.atleast_2d(Y).T
|
||||||
|
|
||||||
assert input_dim==1, "State space methods are only for 1D data"
|
assert self.input_dim==1, "State space methods are only for 1D data"
|
||||||
|
|
||||||
if len(Y.shape)==2:
|
if len(Y.shape)==2:
|
||||||
num_data_Y, self.output_dim = Y.shape
|
num_data_Y, self.output_dim = Y.shape
|
||||||
|
|
@ -168,7 +159,7 @@ class StateSpace(Model):
|
||||||
def log_likelihood(self):
|
def log_likelihood(self):
|
||||||
return self._log_marginal_likelihood
|
return self._log_marginal_likelihood
|
||||||
|
|
||||||
def _raw_predict(self, Xnew=None, Ynew=None, filteronly=False):
|
def _raw_predict(self, Xnew=None, Ynew=None, filteronly=False, **kw):
|
||||||
"""
|
"""
|
||||||
Performs the actual prediction for new X points.
|
Performs the actual prediction for new X points.
|
||||||
Inner function. It is called only from inside this class.
|
Inner function. It is called only from inside this class.
|
||||||
|
|
@ -270,22 +261,23 @@ class StateSpace(Model):
|
||||||
# Return the posterior of the state
|
# Return the posterior of the state
|
||||||
return (m, V)
|
return (m, V)
|
||||||
|
|
||||||
def predict(self, Xnew=None, filteronly=False):
|
def predict(self, Xnew=None, filteronly=False, include_likelihood=True, **kw):
|
||||||
|
|
||||||
# Run the Kalman filter to get the state
|
# Run the Kalman filter to get the state
|
||||||
(m, V) = self._raw_predict(Xnew,filteronly=filteronly)
|
(m, V) = self._raw_predict(Xnew,filteronly=filteronly)
|
||||||
|
|
||||||
# Add the noise variance to the state variance
|
# Add the noise variance to the state variance
|
||||||
V += float(self.Gaussian_noise.variance)
|
if include_likelihood:
|
||||||
|
V += float(self.likelihood.variance)
|
||||||
|
|
||||||
# Lower and upper bounds
|
# Lower and upper bounds
|
||||||
lower = m - 2*np.sqrt(V)
|
#lower = m - 2*np.sqrt(V)
|
||||||
upper = m + 2*np.sqrt(V)
|
#upper = m + 2*np.sqrt(V)
|
||||||
|
|
||||||
# Return mean and variance
|
# Return mean and variance
|
||||||
return (m, V, lower, upper)
|
return m, V
|
||||||
|
|
||||||
def predict_quantiles(self, Xnew=None, quantiles=(2.5, 97.5)):
|
def predict_quantiles(self, Xnew=None, quantiles=(2.5, 97.5), **kw):
|
||||||
mu, var = self._raw_predict(Xnew)
|
mu, var = self._raw_predict(Xnew)
|
||||||
#import pdb; pdb.set_trace()
|
#import pdb; pdb.set_trace()
|
||||||
return [stats.norm.ppf(q/100.)*np.sqrt(var + float(self.Gaussian_noise.variance)) + mu for q in quantiles]
|
return [stats.norm.ppf(q/100.)*np.sqrt(var + float(self.Gaussian_noise.variance)) + mu for q in quantiles]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue