diff --git a/GPy/plotting/gpy_plot/__init__.py b/GPy/plotting/gpy_plot/__init__.py new file mode 100644 index 00000000..c5a99dad --- /dev/null +++ b/GPy/plotting/gpy_plot/__init__.py @@ -0,0 +1,29 @@ +def update_not_existing_kwargs(to_update, update_from): + return to_update.update({k:v for k,v in update_from.items() if k not in to_update}) + +#=============================================================================== +# Implement library specific defaults in the specific plotting librarys defaults.py file. +# The following lines ensure, that an empty kwarg gets returned, when accessing a not +# existing default +from .. import plotting_library as pl +from collections import defaultdict +class defaultdict(defaultdict): + def __getattr__(self, *args, **kwargs): + return defaultdict.__getitem__(self, *args, **kwargs) +defaults = defaultdict(dict, **pl.defaults.__dict__) +pl.defaults = defaults +#=============================================================================== + +#=============================================================================== +# Make sure that the necessary files and functions are +# defined in the plotting library: +assert hasattr(pl, 'get_new_canvas'), "Please implement a function to get a new canvas for the specific library in plotting_library.get_new_canvas(**kwargs)" +assert hasattr(pl, 'plot'), "Please implement a function to plot a simple line" +assert hasattr(pl, 'scatter'), "Please implement a function to plot a simple scatterplot" +#assert hasattr(pl, 'xerrorbar'), "Please implement a function to plot an errorbar along the xaxis" +#assert hasattr(pl, 'xerrorbar'), "Please implement a function to plot an errorbar along the yaxis" +#assert hasattr(pl, 'fill'), "Please implement a function to fill a section between points" +#assert hasattr(pl, 'imshow'), "Please implement a function to plot an image in the given boundaries" +#=============================================================================== + +from . import data_plots, gp_plots diff --git a/GPy/plotting/gpy_plot/data_plots.py b/GPy/plotting/gpy_plot/data_plots.py new file mode 100644 index 00000000..8840e153 --- /dev/null +++ b/GPy/plotting/gpy_plot/data_plots.py @@ -0,0 +1,104 @@ +#=============================================================================== +# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt). +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of GPy nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#=============================================================================== +from . import pl +from . import update_not_existing_kwargs +from . import defaults + +from functools import wraps +import numpy as np + +def _plot_data(self, canvas, which_data_rows='all', + which_data_ycols='all', visible_dims=None, + error_kwargs=None, **plot_kwargs): + """ + Plot the training data + - For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed. + + Can plot only part of the data + using which_data_rows and which_data_ycols. + + :param which_data_rows: which of the training data to plot (default all) + :type which_data_rows: 'all' or a slice object to slice self.X, self.Y + :param which_data_ycols: when the data has several columns (independant outputs), only plot these + :type which_data_rows: 'all' or a list of integers + :param visible_dims: an array specifying the input dimensions to plot (maximum two) + :type visible_dims: a numpy array + :param dict error_kwargs: kwargs for the error plot for the plotting library you are using + :param kwargs plot_kwargs: kwargs for the data plot for the plotting library you are using + """ + #deal with optional arguments + if which_data_rows == 'all': + which_data_rows = slice(None) + if which_data_ycols == 'all': + which_data_ycols = np.arange(self.output_dim) + if error_kwargs is None: + error_kwargs = {} + + if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs(): + X = self.X.mean + X_variance = self.X.variance + else: + X = self.X + X_variance = None + Y = self.Y + + #work out what the inputs are for plotting (1D or 2D) + if visible_dims is None: + visible_dims = np.arange(self.input_dim) + assert visible_dims.size <= 2, "Visible inputs cannot be larger than two" + free_dims = visible_dims + + #one dimensional plotting + if len(free_dims) == 1: + for d in which_data_ycols: + update_not_existing_kwargs(plot_kwargs, defaults.data_1d) + canvas.append(pl.scatter(canvas, X[which_data_rows, free_dims], Y[which_data_rows, d], **plot_kwargs)) + if X_variance is not None: + update_not_existing_kwargs(error_kwargs, defaults.xerrorbar) + canvas.append(pl.xerrorbar(canvas, X[which_data_rows, free_dims].flatten(), Y[which_data_rows, d].flatten(), + 2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()), + **error_kwargs)) + #2D plotting + elif len(free_dims) == 2: + for d in which_data_ycols: + update_not_existing_kwargs(plot_kwargs, defaults.data_2d) + canvas = pl.scatter(canvas, X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], + c=Y[which_data_rows, d], vmin=Y.min(), vmax=Y.max(), **plot_kwargs) + else: + raise NotImplementedError("Cannot plot in more then two dimensions") + return canvas + +@wraps(_plot_data) +def plot_data(self, which_data_rows='all', + which_data_ycols='all', visible_dims=None, + error_kwargs=None, **plot_kwargs): + canvas, kwargs = pl.get_new_canvas(plot_kwargs) + _plot_data(self, canvas, which_data_rows, which_data_ycols, visible_dims, error_kwargs, **kwargs) + return pl.show_canvas(canvas) diff --git a/GPy/plotting/gpy_plot/gp_plots.py b/GPy/plotting/gpy_plot/gp_plots.py new file mode 100644 index 00000000..2e4ab7e6 --- /dev/null +++ b/GPy/plotting/gpy_plot/gp_plots.py @@ -0,0 +1,105 @@ +#=============================================================================== +# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt). +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of GPy nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#=============================================================================== + +from . import pl +from . import update_not_existing_kwargs, defaults +from .util import x_frame1D +from scipy import sparse +import numpy as np + +def plot_mean(self, plot_limits=None, fixed_inputs=[], + resolution=None, plot_raw=False, + Y_metadata=None, apply_link=False, + plot_uncertain_inputs=True, predict_kw=None, + **kwargs): + """ + Plot the mean of a GP. + + :param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits + :type plot_limits: np.array + :param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v. + :type fixed_inputs: a list of tuples + :param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure + :type levels: int + """ + if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs(): + X = self.X.mean + X_variance = self.X.variance + else: + X = self.X + + Y = self.Y + + if sparse.issparse(Y): Y = Y.todense().view(np.ndarray) + + if predict_kw is None: + predict_kw = {} + + #work out what the inputs are for plotting (1D or 2D) + fixed_dims = np.array([i for i,v in fixed_inputs]) + free_dims = np.setdiff1d(np.arange(self.input_dim),fixed_dims) + + #define the frame on which to plot + Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution or 200) + Xgrid = np.empty((Xnew.shape[0],self.input_dim)) + Xgrid[:,free_dims] = Xnew + for i,v in fixed_inputs: + Xgrid[:,i] = v + + if plot_raw: + mu = self._raw_predict(Xgrid)[0] + + update_not_existing_kwargs(kwargs, defaults.meanplot) + return pl.plot(Xgrid, mu, **kwargs) + +def gpplot(x, mu, lower, upper, edgecol='#3300FF', fillcol='#33CCFF', ax=None, fignum=None, **kwargs): + _, axes = ax_default(fignum, ax) + + mu = mu.flatten() + x = x.flatten() + lower = lower.flatten() + upper = upper.flatten() + + plots = [] + + #here's the mean + plots.append(meanplot(x, mu, edgecol, axes)) + + #here's the box + kwargs['linewidth']=0.5 + if not 'alpha' in kwargs.keys(): + kwargs['alpha'] = 0.3 + plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)) + + #this is the edge: + plots.append(meanplot(x, upper,color=edgecol, linewidth=0.2, ax=axes)) + plots.append(meanplot(x, lower,color=edgecol, linewidth=0.2, ax=axes)) + + return plots diff --git a/GPy/plotting/gpy_plot/plot_util.py b/GPy/plotting/gpy_plot/plot_util.py new file mode 100644 index 00000000..5e38f863 --- /dev/null +++ b/GPy/plotting/gpy_plot/plot_util.py @@ -0,0 +1,29 @@ +#=============================================================================== +# Copyright (c) 2012-2015, GPy authors (see AUTHORS.txt). +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of GPy nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#=============================================================================== diff --git a/GPy/plotting/gpy_plot/util.py b/GPy/plotting/gpy_plot/util.py new file mode 100644 index 00000000..a3c64c90 --- /dev/null +++ b/GPy/plotting/gpy_plot/util.py @@ -0,0 +1,68 @@ +#=============================================================================== +# Copyright (c) 2012-2015 GPy Authors +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of GPy nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#=============================================================================== +import numpy as np + +def x_frame1D(X,plot_limits=None,resolution=None): + """ + Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits + """ + assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs" + if plot_limits is None: + from ...core.parameterization.variational import VariationalPosterior + if isinstance(X, VariationalPosterior): + xmin,xmax = X.mean.min(0),X.mean.max(0) + else: + xmin,xmax = X.min(0),X.max(0) + xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin) + elif len(plot_limits)==2: + xmin, xmax = plot_limits + else: + raise ValueError("Bad limits for plotting") + + Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None] + return Xnew, xmin, xmax + +def x_frame2D(X,plot_limits=None,resolution=None): + """ + Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits + """ + assert X.shape[1] ==2, "x_frame2D is defined for two-dimensional inputs" + if plot_limits is None: + xmin,xmax = X.min(0),X.max(0) + xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin) + elif len(plot_limits)==2: + xmin, xmax = plot_limits + else: + raise ValueError("Bad limits for plotting") + + resolution = resolution or 50 + xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution] + Xnew = np.vstack((xx.flatten(),yy.flatten())).T + return Xnew, xx, yy, xmin, xmax diff --git a/GPy/plotting/matplot_dep/defaults.py b/GPy/plotting/matplot_dep/defaults.py new file mode 100644 index 00000000..d7bdf59f --- /dev/null +++ b/GPy/plotting/matplot_dep/defaults.py @@ -0,0 +1,50 @@ +#=============================================================================== +# Copyright (c) 2015, Max Zwiessele +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of GPy nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#=============================================================================== +''' +This file is for defaults for the gpy plot, specific to the plotting library. + +Create a kwargs dictionary with the right name for the plotting function +you are implementing. If you do not provide defaults, the default behaviour of +the plotting library will be used. + +In the code, always ise plotting.gpy_plots.defaults to get the defaults, as +it gives back an empty default, when defaults are not defined. +''' + +from matplotlib import cm + +# Data: +data_1d = dict(lw=1.5, marker='x', edgecolor='k') +data_2d = dict(s=35, edgecolors='none', linewidth=0., cmap=cm.get_cmap('hot')) +xerrorbar = dict(ecolor='k', fmt='none', elinewidth=.5, alpha=.5) +yerrorbar = dict(ecolor='darkred', fmt='none', elinewidth=.5, alpha=.5) + +# GP plots +meanplot = dict(color='#3300FF', linewidth=2) \ No newline at end of file diff --git a/GPy/testing/plotting_tests.py b/GPy/testing/plotting_tests.py new file mode 100644 index 00000000..0b5c4c7e --- /dev/null +++ b/GPy/testing/plotting_tests.py @@ -0,0 +1,124 @@ +#=============================================================================== +# Copyright (c) 2015, Max Zwiessele +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of GPy nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#=============================================================================== +import numpy as np +import GPy, os, sys +from nose import SkipTest +try: + from matplotlib import cbook +except: + raise SkipTest("Matplotlib not installed, not testing plots") + +def _image_directories(func): + """ + Compute the baseline and result image directories for testing *func*. + Create the result directory if it doesn't exist. + """ + module_name = func.__module__ + + path = module_name + + mods = module_name.split('.') + subdir = os.path.join(*mods) + + basedir = os.path.join(*mods) + + result_dir = os.path.join(basedir, 'testresult') + baseline_dir = os.path.join(basedir, 'baseline') + + if not os.path.exists(result_dir): + cbook.mkdirs(result_dir) + + return baseline_dir, result_dir + +import matplotlib.testing.decorators +matplotlib.testing.decorators._image_directories = _image_directories +from matplotlib.testing.decorators import image_comparison +import matplotlib.pyplot as plt + +@image_comparison(baseline_images=['gp'], extensions=['pdf','png']) +def testPlot(): + fig, ax = plt.subplots() + np.random.seed(11111) + X = np.random.uniform(0, 1, (40, 1)) + f = .2 * np.sin(1.3*X) + 1.3*np.cos(2*X) + Y = f+np.random.normal(0, .1, f.shape) + m = GPy.models.GPRegression(X, Y) + m.optimize() + m.plot_data(ax=ax) + m.plot_mean(ax=ax) + m.plot_confidence(ax=ax) + m.plot_density(ax=ax) + return ax + + +@image_comparison(baseline_images=['gp_class'], extensions=['pdf','png']) +def testPlotClassification(): + fig, ax = plt.subplots() + np.random.seed(11111) + X = np.random.uniform(0, 1, (40, 1)) + f = .2 * np.sin(1.3*X) + 1.3*np.cos(2*X) + Y = f+np.random.normal(0, .1, f.shape) + m = GPy.models.GPClassification(X, Y>Y.mean()) + m.optimize() + m.plot_data(ax=ax) + m.plot_mean(ax=ax) + m.plot_confidence(ax=ax) + m.plot_density(ax=ax) + return ax + +@image_comparison(baseline_images=['sparse_gp_class'], extensions=['pdf','png']) +def testPlotSparseClassification(): + fig, ax = plt.subplots() + np.random.seed(11111) + X = np.random.uniform(0, 1, (40, 1)) + f = .2 * np.sin(1.3*X) + 1.3*np.cos(2*X) + Y = f+np.random.normal(0, .1, f.shape) + m = GPy.models.SparseGPClassification(X, Y>Y.mean()) + m.optimize() + m.plot_data(ax=ax) + m.plot_mean(ax=ax) + m.plot_confidence(ax=ax) + m.plot_density(ax=ax) + return ax + +@image_comparison(baseline_images=['sparse_gp'], extensions=['pdf','png']) +def testPlotSparse(): + fig, ax = plt.subplots() + np.random.seed(11111) + X = np.random.uniform(0, 1, (40, 1)) + f = .2 * np.sin(1.3*X) + 1.3*np.cos(2*X) + Y = f+np.random.normal(0, .1, f.shape) + m = GPy.models.SparseGPRegression(X, Y) + m.optimize() + m.plot_data(ax=ax) + m.plot_mean(ax=ax) + m.plot_confidence(ax=ax) + m.plot_density(ax=ax) + return ax diff --git a/GPy/testing/plotting_tests/baseline/gp.pdf b/GPy/testing/plotting_tests/baseline/gp.pdf new file mode 100644 index 00000000..df4fb267 Binary files /dev/null and b/GPy/testing/plotting_tests/baseline/gp.pdf differ diff --git a/GPy/testing/plotting_tests/baseline/gp.png b/GPy/testing/plotting_tests/baseline/gp.png new file mode 100644 index 00000000..d2b241d2 Binary files /dev/null and b/GPy/testing/plotting_tests/baseline/gp.png differ diff --git a/GPy/testing/plotting_tests/baseline/gp_class.pdf b/GPy/testing/plotting_tests/baseline/gp_class.pdf new file mode 100644 index 00000000..296233ec Binary files /dev/null and b/GPy/testing/plotting_tests/baseline/gp_class.pdf differ diff --git a/GPy/testing/plotting_tests/baseline/gp_class.png b/GPy/testing/plotting_tests/baseline/gp_class.png new file mode 100644 index 00000000..ea65630b Binary files /dev/null and b/GPy/testing/plotting_tests/baseline/gp_class.png differ diff --git a/GPy/testing/plotting_tests/baseline/sparse_gp.pdf b/GPy/testing/plotting_tests/baseline/sparse_gp.pdf new file mode 100644 index 00000000..9072bbc7 Binary files /dev/null and b/GPy/testing/plotting_tests/baseline/sparse_gp.pdf differ diff --git a/GPy/testing/plotting_tests/baseline/sparse_gp.png b/GPy/testing/plotting_tests/baseline/sparse_gp.png new file mode 100644 index 00000000..44a1dfc0 Binary files /dev/null and b/GPy/testing/plotting_tests/baseline/sparse_gp.png differ diff --git a/GPy/testing/plotting_tests/baseline/sparse_gp_class.pdf b/GPy/testing/plotting_tests/baseline/sparse_gp_class.pdf new file mode 100644 index 00000000..b9e926c2 Binary files /dev/null and b/GPy/testing/plotting_tests/baseline/sparse_gp_class.pdf differ diff --git a/GPy/testing/plotting_tests/baseline/sparse_gp_class.png b/GPy/testing/plotting_tests/baseline/sparse_gp_class.png new file mode 100644 index 00000000..fc461de3 Binary files /dev/null and b/GPy/testing/plotting_tests/baseline/sparse_gp_class.png differ