Changes to allow multiple output plotting

This commit is contained in:
Ricardo 2013-07-31 19:00:54 +01:00
parent 7e1e8de5e4
commit 1c2a4c5c64
6 changed files with 109 additions and 27 deletions

View file

@ -5,7 +5,7 @@ import numpy as np
import pylab as pb
from ..util.linalg import mdot, jitchol, tdot, symmetrify, backsub_both_sides, chol_inv, dtrtrs, dpotrs, dpotri
from scipy import linalg
from ..likelihoods import Gaussian
from ..likelihoods import Gaussian, EP,EP_Mixed_Noise
from gp_base import GPBase
class SparseGP(GPBase):
@ -314,3 +314,37 @@ class SparseGP(GPBase):
elif self.X.shape[1] == 2:
Zu = self.Z * self._Xscale + self._Xoffset
ax.plot(Zu[:, 0], Zu[:, 1], 'wo')
def predict_single_output(self, Xnew, output=0, which_parts='all', full_cov=False):
"""
Predict the function(s) at the new point(s) Xnew.
Arguments
---------
:param Xnew: The points at which to make a prediction
:type Xnew: np.ndarray, Nnew x self.input_dim
:param which_parts: specifies which outputs kernel(s) to use in prediction
:type which_parts: ('all', list of bools)
:param full_cov: whether to return the folll covariance matrix, or just the diagonal
:type full_cov: bool
:rtype: posterior mean, a Numpy array, Nnew x self.input_dim
:rtype: posterior variance, a Numpy array, Nnew x 1 if full_cov=False, Nnew x Nnew otherwise
:rtype: lower and upper boundaries of the 95% confidence intervals, Numpy arrays, Nnew x self.input_dim
If full_cov and self.input_dim > 1, the return shape of var is Nnew x Nnew x self.input_dim. If self.input_dim == 1, the return shape is Nnew x Nnew.
This is to allow for different normalizations of the output dimensions.
"""
assert isinstance(self.likelihood,EP_Mixed_Noise)
index = np.ones_like(Xnew)*output
Xnew = np.hstack((Xnew,index))
# normalize X values
Xnew = (Xnew.copy() - self._Xoffset) / self._Xscale
mu, var = self._raw_predict(Xnew, full_cov=full_cov, which_parts=which_parts)
# now push through likelihood
mean, var, _025pm, _975pm = self.likelihood.predictive_values(mu, var, full_cov, noise_model = output)
return mean, var, _025pm, _975pm