Added option to plot the transformed link function (posterior once the link function has been applied)

This commit is contained in:
Alan Saul 2015-04-10 15:44:15 +01:00
parent dff9ca8e6b
commit f4cf052bce
2 changed files with 77 additions and 17 deletions

View file

@ -6,14 +6,13 @@ import sys
from .. import kern from .. import kern
from .model import Model from .model import Model
from .parameterization import ObsAr from .parameterization import ObsAr
from .model import Model
from .mapping import Mapping from .mapping import Mapping
from .parameterization import ObsAr
from .. import likelihoods from .. import likelihoods
from ..inference.latent_function_inference import exact_gaussian_inference, expectation_propagation from ..inference.latent_function_inference import exact_gaussian_inference, expectation_propagation
from .parameterization.variational import VariationalPosterior from .parameterization.variational import VariationalPosterior
import logging import logging
import warnings
from GPy.util.normalizer import MeanNorm from GPy.util.normalizer import MeanNorm
logger = logging.getLogger("GP") logger = logging.getLogger("GP")
@ -65,10 +64,14 @@ class GP(Model):
self.Y = ObsAr(Y) self.Y = ObsAr(Y)
self.Y_normalized = self.Y self.Y_normalized = self.Y
assert Y.shape[0] == self.num_data if Y.shape[0] != self.num_data:
#There can be cases where we want inputs than outputs, for example if we have multiple latent
#function values
warnings.warn("There are more rows in your input data X, \
than in your output data Y, be VERY sure this is what you want")
_, self.output_dim = self.Y.shape _, self.output_dim = self.Y.shape
#TODO: check the type of this is okay? assert ((Y_metadata is None) or isinstance(Y_metadata, dict))
self.Y_metadata = Y_metadata self.Y_metadata = Y_metadata
assert isinstance(kernel, kern.Kern) assert isinstance(kernel, kern.Kern)
@ -326,14 +329,14 @@ class GP(Model):
""" """
fsim = self.posterior_samples_f(X, size, full_cov=full_cov) fsim = self.posterior_samples_f(X, size, full_cov=full_cov)
Ysim = self.likelihood.samples(fsim, Y_metadata) Ysim = self.likelihood.samples(fsim, Y_metadata)
return Ysim return Ysim
def plot_f(self, plot_limits=None, which_data_rows='all', def plot_f(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[], which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None, levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=True, plot_raw=True,
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx'): linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx',
apply_link=False):
""" """
Plot the GP's view of the world, where the data is normalized and before applying a likelihood. Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
This is a call to plot with plot_raw=True. This is a call to plot with plot_raw=True.
@ -370,6 +373,8 @@ class GP(Model):
:type Y_metadata: dict :type Y_metadata: dict
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx') :param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib. :type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*
:type apply_link: boolean
""" """
assert "matplotlib" in sys.modules, "matplotlib package has not been imported." assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots from ..plotting.matplot_dep import models_plots
@ -382,7 +387,7 @@ class GP(Model):
which_data_ycols, fixed_inputs, which_data_ycols, fixed_inputs,
levels, samples, fignum, ax, resolution, levels, samples, fignum, ax, resolution,
plot_raw=plot_raw, Y_metadata=Y_metadata, plot_raw=plot_raw, Y_metadata=Y_metadata,
data_symbol=data_symbol, **kw) data_symbol=data_symbol, apply_link=apply_link, **kw)
def plot(self, plot_limits=None, which_data_rows='all', def plot(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[], which_data_ycols='all', fixed_inputs=[],

View file

@ -1,4 +1,4 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt). # Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
try: try:
@ -16,7 +16,8 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[], which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None, levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False, plot_raw=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx'): linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
apply_link=False, samples_f=0, plot_uncertain_inputs=True):
""" """
Plot the posterior of the GP. Plot the posterior of the GP.
- In one dimension, the function is plotted with a shaded region identifying two standard deviations. - In one dimension, the function is plotted with a shaded region identifying two standard deviations.
@ -38,7 +39,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
:type resolution: int :type resolution: int
:param levels: number of levels to plot in a contour plot. :param levels: number of levels to plot in a contour plot.
:type levels: int :type levels: int
:param samples: the number of a posteriori samples to plot :param samples: the number of a posteriori samples to plot p(y*|y)
:type samples: int :type samples: int
:param fignum: figure to plot on. :param fignum: figure to plot on.
:type fignum: figure number :type fignum: figure number
@ -49,6 +50,10 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
:type linecol: :type linecol:
:param fillcol: color of fill :param fillcol: color of fill
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure :param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
:param apply_link: apply the link function if plotting f (default false)
:type apply_link: boolean
:param samples_f: the number of posteriori f samples to plot p(f*|y)
:type samples_f: int
""" """
#deal with optional arguments #deal with optional arguments
if which_data_rows == 'all': if which_data_rows == 'all':
@ -88,8 +93,14 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#make a prediction on the frame and plot it #make a prediction on the frame and plot it
if plot_raw: if plot_raw:
m, v = model._raw_predict(Xgrid) m, v = model._raw_predict(Xgrid)
lower = m - 2*np.sqrt(v) if apply_link:
upper = m + 2*np.sqrt(v) lower = model.likelihood.gp_link.transf(m - 2*np.sqrt(v))
upper = model.likelihood.gp_link.transf(m + 2*np.sqrt(v))
#Once transformed this is now the median of the function
m = model.likelihood.gp_link.transf(m)
else:
lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v)
else: else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression): if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
meta = {'output_index': Xgrid[:,-1:].astype(np.int)} meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
@ -110,13 +121,31 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
plots['posterior_samples'] = ax.plot(Xnew, yi[:,None], Tango.colorsHex['darkBlue'], linewidth=0.25) plots['posterior_samples'] = ax.plot(Xnew, yi[:,None], Tango.colorsHex['darkBlue'], linewidth=0.25)
#ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs. #ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
if samples_f: #NOTE not tested with fixed_inputs
Fsim = model.posterior_samples_f(Xgrid, samples_f)
for fi in Fsim.T:
plots['posterior_samples_f'] = ax.plot(Xnew, fi[:,None], Tango.colorsHex['darkBlue'], linewidth=0.25)
#ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
#add error bars for uncertain (if input uncertainty is being modelled) #add error bars for uncertain (if input uncertainty is being modelled)
if hasattr(model,"has_uncertain_inputs") and model.has_uncertain_inputs(): if hasattr(model,"has_uncertain_inputs") and model.has_uncertain_inputs() and plot_uncertain_inputs:
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), Y[which_data_rows, which_data_ycols].flatten(), if plot_raw:
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()), #add error bars for uncertain (if input uncertainty is being modelled), for plot_f
ecolor='k', fmt=None, elinewidth=.5, alpha=.5) #Hack to plot error bars on latent function, rather than on the data
vs = model.X.mean.values.copy()
for i,v in fixed_inputs:
vs[:,i] = v
m_X, _ = model._raw_predict(vs)
if apply_link:
m_X = model.likelihood.gp_link.transf(m_X)
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
else:
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), Y[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
#set the limits of the plot to some sensible values #set the limits of the plot to some sensible values
ymin, ymax = min(np.append(Y[which_data_rows, which_data_ycols].flatten(), lower)), max(np.append(Y[which_data_rows, which_data_ycols].flatten(), upper)) ymin, ymax = min(np.append(Y[which_data_rows, which_data_ycols].flatten(), lower)), max(np.append(Y[which_data_rows, which_data_ycols].flatten(), upper))
@ -186,3 +215,29 @@ def plot_fit_f(model, *args, **kwargs):
""" """
kwargs['plot_raw'] = True kwargs['plot_raw'] = True
plot_fit(model,*args, **kwargs) plot_fit(model,*args, **kwargs)
def fixed_inputs(model, non_fixed_inputs, fix_routine='median'):
"""
Convenience function for returning back fixed_inputs where the other inputs
are fixed using fix_routine
:param model: model
:type model: Model
:param non_fixed_inputs: dimensions of non fixed inputs
:type non_fixed_inputs: list
:param fix_routine: fixing routine to use, 'mean', 'median', 'zero'
:type fix_routine: string
"""
f_inputs = []
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean.values.copy()
else:
X = model.X.values.copy()
for i in range(X.shape[1]):
if i not in non_fixed_inputs:
if fix_routine == 'mean':
f_inputs.append( (i, np.mean(X[:,i])) )
if fix_routine == 'median':
f_inputs.append( (i, np.median(X[:,i])) )
elif fix_routine == 'zero':
f_inputs.append( (i, 0) )
return f_inputs