New function to plot just the errorbars of the training data

This commit is contained in:
Ricardo 2015-09-02 19:40:11 +01:00
parent a56b6a6d6a
commit 8e437e4bda
3 changed files with 256 additions and 8 deletions

View file

@ -47,6 +47,27 @@ def gpplot(x, mu, lower, upper, edgecol='#3300FF', fillcol='#33CCFF', ax=None, f
return plots
def gperrors(x, mu, lower, upper, edgecol=None, ax=None, fignum=None, **kwargs):
_, axes = ax_default(fignum, ax)
mu = mu.flatten()
x = x.flatten()
lower = lower.flatten()
upper = upper.flatten()
plots = []
if edgecol is None:
edgecol='#3300FF'
if not 'alpha' in kwargs.keys():
kwargs['alpha'] = 0.3
plots.append(axes.errorbar(x,mu,yerr=np.vstack([mu-lower,upper-mu]),color=edgecol,**kwargs))
plots[-1][0].remove()
return plots
def removeRightTicks(ax=None):
ax = ax or pb.gca()
for i, line in enumerate(ax.get_yticklines()):

View file

@ -3,19 +3,79 @@
import numpy as np
from . import Tango
from base_plots import gpplot, x_frame1D, x_frame2D
from base_plots import gpplot, x_frame1D, x_frame2D,gperrors
from ...models.gp_coregionalized_regression import GPCoregionalizedRegression
from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalizedRegression
from scipy import sparse
from ...core.parameterization.variational import VariationalPosterior
from matplotlib import pyplot as plt
def plot_data(model, which_data_rows='all',
which_data_ycols='all', visible_dims=None,
fignum=None, ax=None, data_symbol='kx',mew=1.5):
"""
Plot the training data
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
Can plot only part of the data
using which_data_rows and which_data_ycols.
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
:type visible_dims: a numpy array
:param fignum: figure to plot on.
:type fignum: figure number
:param ax: axes to plot on.
:type ax: axes handle
"""
#deal with optional arguments
if which_data_rows == 'all':
which_data_rows = slice(None)
if which_data_ycols == 'all':
which_data_ycols = np.arange(model.output_dim)
if ax is None:
fig = plt.figure(num=fignum)
ax = fig.add_subplot(111)
#data
X = model.X
Y = model.Y
#work out what the inputs are for plotting (1D or 2D)
if visible_dims is None:
visible_dims = np.arange(model.input_dim)
assert visible_dims.size <= 2, "Visible inputs cannot be larger than two"
free_dims = visible_dims
plots = {}
#one dimensional plotting
if len(free_dims) == 1:
for d in which_data_ycols:
plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=mew)
#2D plotting
elif len(free_dims) == 2:
for d in which_data_ycols:
plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40,
Y[which_data_rows, d], cmap=plt.cm.jet, vmin=Y.min(), vmax=Y.max(), linewidth=0.)
else:
raise NotImplementedError("Cannot define a frame with more than two input dimensions")
return plots
def plot_fit(model, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
apply_link=False, samples_f=0, plot_uncertain_inputs=True, predict_kw=None):
apply_link=False, samples_f=0, plot_uncertain_inputs=True, predict_kw=None, plot_training_data=True):
"""
Plot the posterior of the GP.
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
@ -43,7 +103,6 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
:type fignum: figure number
:param ax: axes to plot on.
:type ax: axes handle
:type output: integer (first output is 0)
:param linecol: color of line to plot.
:type linecol:
:param fillcol: color of fill
@ -52,6 +111,8 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
:type apply_link: boolean
:param samples_f: the number of posteriori f samples to plot p(f*|y)
:type samples_f: int
:param plot_training_data: whether or not to plot the training points
:type plot_training_data: boolean
"""
#deal with optional arguments
if which_data_rows == 'all':
@ -116,7 +177,11 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
for d in which_data_ycols:
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
#if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
if not plot_raw and plot_training_data:
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows,
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
#optionally plot some samples
if samples: #NOTE not tested with fixed_inputs
@ -196,7 +261,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
for d in which_data_ycols:
m_d = m[:,d].reshape(resolution, resolution).T
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet)
if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=plt.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
#if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=plt.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
if not plot_raw and plot_training_data:
plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=plt.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
#set the limits of the plot to some sensible values
ax.set_xlim(xmin[0], xmax[0])
@ -261,3 +328,83 @@ def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_
return f_inputs
else:
return X
def plot_fit_errorbars(model, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
fignum=None, ax=None,
linecol='red', data_symbol='kx',
predict_kw=None, plot_training_data=True):
"""
Plot the posterior error bars corresponding to the training data
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
Can plot only part of the data
using which_data_rows and which_data_ycols.
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param fignum: figure to plot on.
:type fignum: figure number
:param ax: axes to plot on.
:type ax: axes handle
:param plot_training_data: whether or not to plot the training points
:type plot_training_data: boolean
"""
#deal with optional arguments
if which_data_rows == 'all':
which_data_rows = slice(None)
if which_data_ycols == 'all':
which_data_ycols = np.arange(model.output_dim)
if ax is None:
fig = plt.figure(num=fignum)
ax = fig.add_subplot(111)
X = model.X
Y = model.Y
if predict_kw is None:
predict_kw = {}
#work out what the inputs are for plotting (1D or 2D)
fixed_dims = np.array([i for i,v in fixed_inputs])
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
plots = {}
#one dimensional plotting
if len(free_dims) == 1:
m, v = model.predict(X, full_cov=False, Y_metadata=model.Y_metadata, **predict_kw)
fmu, fv = model._raw_predict(X, full_cov=False, **predict_kw)
lower, upper = model.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=model.Y_metadata)
for d in which_data_ycols:
plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum )
if plot_training_data:
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows,
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
#set the limits of the plot to some sensible values
ymin, ymax = min(np.append(Y[which_data_rows, which_data_ycols].flatten(), lower)), max(np.append(Y[which_data_rows, which_data_ycols].flatten(), upper))
ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin)
ax.set_xlim(X[:,free_dims].min(), X[:,free_dims].max())
ax.set_ylim(ymin, ymax)
elif len(free_dims) == 2:
raise NotImplementedError("Not implemented yet")
else:
raise NotImplementedError("Cannot define a frame with more than two input dimensions")
return plots