Added option to plot the transformed link function (posterior once the link function has been applied)

This commit is contained in:
Alan Saul 2015-04-10 15:44:15 +01:00
parent dff9ca8e6b
commit f4cf052bce
2 changed files with 77 additions and 17 deletions

View file

@ -6,14 +6,13 @@ import sys
from .. import kern
from .model import Model
from .parameterization import ObsAr
from .model import Model
from .mapping import Mapping
from .parameterization import ObsAr
from .. import likelihoods
from ..inference.latent_function_inference import exact_gaussian_inference, expectation_propagation
from .parameterization.variational import VariationalPosterior
import logging
import warnings
from GPy.util.normalizer import MeanNorm
logger = logging.getLogger("GP")
@ -65,10 +64,14 @@ class GP(Model):
self.Y = ObsAr(Y)
self.Y_normalized = self.Y
assert Y.shape[0] == self.num_data
if Y.shape[0] != self.num_data:
#There can be cases where we want inputs than outputs, for example if we have multiple latent
#function values
warnings.warn("There are more rows in your input data X, \
than in your output data Y, be VERY sure this is what you want")
_, self.output_dim = self.Y.shape
#TODO: check the type of this is okay?
assert ((Y_metadata is None) or isinstance(Y_metadata, dict))
self.Y_metadata = Y_metadata
assert isinstance(kernel, kern.Kern)
@ -326,14 +329,14 @@ class GP(Model):
"""
fsim = self.posterior_samples_f(X, size, full_cov=full_cov)
Ysim = self.likelihood.samples(fsim, Y_metadata)
return Ysim
def plot_f(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=True,
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx'):
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx',
apply_link=False):
"""
Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
This is a call to plot with plot_raw=True.
@ -370,6 +373,8 @@ class GP(Model):
:type Y_metadata: dict
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*
:type apply_link: boolean
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots
@ -382,7 +387,7 @@ class GP(Model):
which_data_ycols, fixed_inputs,
levels, samples, fignum, ax, resolution,
plot_raw=plot_raw, Y_metadata=Y_metadata,
data_symbol=data_symbol, **kw)
data_symbol=data_symbol, apply_link=apply_link, **kw)
def plot(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],