Added option to plot the transformed link function (posterior once the link function has been applied)

This commit is contained in:
Alan Saul 2015-04-10 15:44:15 +01:00
parent dff9ca8e6b
commit f4cf052bce
2 changed files with 77 additions and 17 deletions

View file

@ -1,4 +1,4 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
try:
@ -16,7 +16,8 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx'):
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
apply_link=False, samples_f=0, plot_uncertain_inputs=True):
"""
Plot the posterior of the GP.
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
@ -38,7 +39,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
:type resolution: int
:param levels: number of levels to plot in a contour plot.
:type levels: int
:param samples: the number of a posteriori samples to plot
:param samples: the number of a posteriori samples to plot p(y*|y)
:type samples: int
:param fignum: figure to plot on.
:type fignum: figure number
@ -49,6 +50,10 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
:type linecol:
:param fillcol: color of fill
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
:param apply_link: apply the link function if plotting f (default false)
:type apply_link: boolean
:param samples_f: the number of posteriori f samples to plot p(f*|y)
:type samples_f: int
"""
#deal with optional arguments
if which_data_rows == 'all':
@ -88,8 +93,14 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#make a prediction on the frame and plot it
if plot_raw:
m, v = model._raw_predict(Xgrid)
lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v)
if apply_link:
lower = model.likelihood.gp_link.transf(m - 2*np.sqrt(v))
upper = model.likelihood.gp_link.transf(m + 2*np.sqrt(v))
#Once transformed this is now the median of the function
m = model.likelihood.gp_link.transf(m)
else:
lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v)
else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
meta = {'output_index': Xgrid[:,-1:].astype(np.int)}
@ -110,13 +121,31 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
plots['posterior_samples'] = ax.plot(Xnew, yi[:,None], Tango.colorsHex['darkBlue'], linewidth=0.25)
#ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
if samples_f: #NOTE not tested with fixed_inputs
Fsim = model.posterior_samples_f(Xgrid, samples_f)
for fi in Fsim.T:
plots['posterior_samples_f'] = ax.plot(Xnew, fi[:,None], Tango.colorsHex['darkBlue'], linewidth=0.25)
#ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
#add error bars for uncertain (if input uncertainty is being modelled)
if hasattr(model,"has_uncertain_inputs") and model.has_uncertain_inputs():
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), Y[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
if hasattr(model,"has_uncertain_inputs") and model.has_uncertain_inputs() and plot_uncertain_inputs:
if plot_raw:
#add error bars for uncertain (if input uncertainty is being modelled), for plot_f
#Hack to plot error bars on latent function, rather than on the data
vs = model.X.mean.values.copy()
for i,v in fixed_inputs:
vs[:,i] = v
m_X, _ = model._raw_predict(vs)
if apply_link:
m_X = model.likelihood.gp_link.transf(m_X)
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
else:
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), Y[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
#set the limits of the plot to some sensible values
ymin, ymax = min(np.append(Y[which_data_rows, which_data_ycols].flatten(), lower)), max(np.append(Y[which_data_rows, which_data_ycols].flatten(), upper))
@ -186,3 +215,29 @@ def plot_fit_f(model, *args, **kwargs):
"""
kwargs['plot_raw'] = True
plot_fit(model,*args, **kwargs)
def fixed_inputs(model, non_fixed_inputs, fix_routine='median'):
"""
Convenience function for returning back fixed_inputs where the other inputs
are fixed using fix_routine
:param model: model
:type model: Model
:param non_fixed_inputs: dimensions of non fixed inputs
:type non_fixed_inputs: list
:param fix_routine: fixing routine to use, 'mean', 'median', 'zero'
:type fix_routine: string
"""
f_inputs = []
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean.values.copy()
else:
X = model.X.values.copy()
for i in range(X.shape[1]):
if i not in non_fixed_inputs:
if fix_routine == 'mean':
f_inputs.append( (i, np.mean(X[:,i])) )
if fix_routine == 'median':
f_inputs.append( (i, np.median(X[:,i])) )
elif fix_routine == 'zero':
f_inputs.append( (i, 0) )
return f_inputs