[added testing and plotting] restructuring the plotting library

This commit is contained in:
mzwiessele 2015-10-02 18:32:46 +01:00
parent b9bfd0fc6d
commit c7d50ee83b
15 changed files with 509 additions and 0 deletions

View file

@ -0,0 +1,29 @@
def update_not_existing_kwargs(to_update, update_from):
return to_update.update({k:v for k,v in update_from.items() if k not in to_update})
#===============================================================================
# Implement library specific defaults in the specific plotting librarys defaults.py file.
# The following lines ensure, that an empty kwarg gets returned, when accessing a not
# existing default
from .. import plotting_library as pl
from collections import defaultdict
class defaultdict(defaultdict):
def __getattr__(self, *args, **kwargs):
return defaultdict.__getitem__(self, *args, **kwargs)
defaults = defaultdict(dict, **pl.defaults.__dict__)
pl.defaults = defaults
#===============================================================================
#===============================================================================
# Make sure that the necessary files and functions are
# defined in the plotting library:
assert hasattr(pl, 'get_new_canvas'), "Please implement a function to get a new canvas for the specific library in plotting_library.get_new_canvas(**kwargs)"
assert hasattr(pl, 'plot'), "Please implement a function to plot a simple line"
assert hasattr(pl, 'scatter'), "Please implement a function to plot a simple scatterplot"
#assert hasattr(pl, 'xerrorbar'), "Please implement a function to plot an errorbar along the xaxis"
#assert hasattr(pl, 'xerrorbar'), "Please implement a function to plot an errorbar along the yaxis"
#assert hasattr(pl, 'fill'), "Please implement a function to fill a section between points"
#assert hasattr(pl, 'imshow'), "Please implement a function to plot an image in the given boundaries"
#===============================================================================
from . import data_plots, gp_plots

View file

@ -0,0 +1,104 @@
#===============================================================================
# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt).
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================
from . import pl
from . import update_not_existing_kwargs
from . import defaults
from functools import wraps
import numpy as np
def _plot_data(self, canvas, which_data_rows='all',
which_data_ycols='all', visible_dims=None,
error_kwargs=None, **plot_kwargs):
"""
Plot the training data
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
Can plot only part of the data
using which_data_rows and which_data_ycols.
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice self.X, self.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
:type visible_dims: a numpy array
:param dict error_kwargs: kwargs for the error plot for the plotting library you are using
:param kwargs plot_kwargs: kwargs for the data plot for the plotting library you are using
"""
#deal with optional arguments
if which_data_rows == 'all':
which_data_rows = slice(None)
if which_data_ycols == 'all':
which_data_ycols = np.arange(self.output_dim)
if error_kwargs is None:
error_kwargs = {}
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
X = self.X.mean
X_variance = self.X.variance
else:
X = self.X
X_variance = None
Y = self.Y
#work out what the inputs are for plotting (1D or 2D)
if visible_dims is None:
visible_dims = np.arange(self.input_dim)
assert visible_dims.size <= 2, "Visible inputs cannot be larger than two"
free_dims = visible_dims
#one dimensional plotting
if len(free_dims) == 1:
for d in which_data_ycols:
update_not_existing_kwargs(plot_kwargs, defaults.data_1d)
canvas.append(pl.scatter(canvas, X[which_data_rows, free_dims], Y[which_data_rows, d], **plot_kwargs))
if X_variance is not None:
update_not_existing_kwargs(error_kwargs, defaults.xerrorbar)
canvas.append(pl.xerrorbar(canvas, X[which_data_rows, free_dims].flatten(), Y[which_data_rows, d].flatten(),
2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
**error_kwargs))
#2D plotting
elif len(free_dims) == 2:
for d in which_data_ycols:
update_not_existing_kwargs(plot_kwargs, defaults.data_2d)
canvas = pl.scatter(canvas, X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]],
c=Y[which_data_rows, d], vmin=Y.min(), vmax=Y.max(), **plot_kwargs)
else:
raise NotImplementedError("Cannot plot in more then two dimensions")
return canvas
@wraps(_plot_data)
def plot_data(self, which_data_rows='all',
which_data_ycols='all', visible_dims=None,
error_kwargs=None, **plot_kwargs):
canvas, kwargs = pl.get_new_canvas(plot_kwargs)
_plot_data(self, canvas, which_data_rows, which_data_ycols, visible_dims, error_kwargs, **kwargs)
return pl.show_canvas(canvas)

View file

@ -0,0 +1,105 @@
#===============================================================================
# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt).
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================
from . import pl
from . import update_not_existing_kwargs, defaults
from .util import x_frame1D
from scipy import sparse
import numpy as np
def plot_mean(self, plot_limits=None, fixed_inputs=[],
resolution=None, plot_raw=False,
Y_metadata=None, apply_link=False,
plot_uncertain_inputs=True, predict_kw=None,
**kwargs):
"""
Plot the mean of a GP.
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
:type levels: int
"""
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
X = self.X.mean
X_variance = self.X.variance
else:
X = self.X
Y = self.Y
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
if predict_kw is None:
predict_kw = {}
#work out what the inputs are for plotting (1D or 2D)
fixed_dims = np.array([i for i,v in fixed_inputs])
free_dims = np.setdiff1d(np.arange(self.input_dim),fixed_dims)
#define the frame on which to plot
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution or 200)
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs:
Xgrid[:,i] = v
if plot_raw:
mu = self._raw_predict(Xgrid)[0]
update_not_existing_kwargs(kwargs, defaults.meanplot)
return pl.plot(Xgrid, mu, **kwargs)
def gpplot(x, mu, lower, upper, edgecol='#3300FF', fillcol='#33CCFF', ax=None, fignum=None, **kwargs):
_, axes = ax_default(fignum, ax)
mu = mu.flatten()
x = x.flatten()
lower = lower.flatten()
upper = upper.flatten()
plots = []
#here's the mean
plots.append(meanplot(x, mu, edgecol, axes))
#here's the box
kwargs['linewidth']=0.5
if not 'alpha' in kwargs.keys():
kwargs['alpha'] = 0.3
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
#this is the edge:
plots.append(meanplot(x, upper,color=edgecol, linewidth=0.2, ax=axes))
plots.append(meanplot(x, lower,color=edgecol, linewidth=0.2, ax=axes))
return plots

View file

@ -0,0 +1,29 @@
#===============================================================================
# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt).
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================

View file

@ -0,0 +1,68 @@
#===============================================================================
# Copyright (c) 2012-2015 GPy Authors
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================
import numpy as np
def x_frame1D(X,plot_limits=None,resolution=None):
"""
Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits
"""
assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs"
if plot_limits is None:
from ...core.parameterization.variational import VariationalPosterior
if isinstance(X, VariationalPosterior):
xmin,xmax = X.mean.min(0),X.mean.max(0)
else:
xmin,xmax = X.min(0),X.max(0)
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
elif len(plot_limits)==2:
xmin, xmax = plot_limits
else:
raise ValueError("Bad limits for plotting")
Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None]
return Xnew, xmin, xmax
def x_frame2D(X,plot_limits=None,resolution=None):
"""
Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits
"""
assert X.shape[1] ==2, "x_frame2D is defined for two-dimensional inputs"
if plot_limits is None:
xmin,xmax = X.min(0),X.max(0)
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
elif len(plot_limits)==2:
xmin, xmax = plot_limits
else:
raise ValueError("Bad limits for plotting")
resolution = resolution or 50
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
return Xnew, xx, yy, xmin, xmax

View file

@ -0,0 +1,50 @@
#===============================================================================
# Copyright (c) 2015, Max Zwiessele
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================
'''
This file is for defaults for the gpy plot, specific to the plotting library.
Create a kwargs dictionary with the right name for the plotting function
you are implementing. If you do not provide defaults, the default behaviour of
the plotting library will be used.
In the code, always ise plotting.gpy_plots.defaults to get the defaults, as
it gives back an empty default, when defaults are not defined.
'''
from matplotlib import cm
# Data:
data_1d = dict(lw=1.5, marker='x', edgecolor='k')
data_2d = dict(s=35, edgecolors='none', linewidth=0., cmap=cm.get_cmap('hot'))
xerrorbar = dict(ecolor='k', fmt='none', elinewidth=.5, alpha=.5)
yerrorbar = dict(ecolor='darkred', fmt='none', elinewidth=.5, alpha=.5)
# GP plots
meanplot = dict(color='#3300FF', linewidth=2)