From 48fb60489160de6fb0e84f6559b85b07dd16e274 Mon Sep 17 00:00:00 2001 From: James Hensman Date: Wed, 17 Sep 2014 12:30:56 +0100 Subject: [PATCH] some improvements to plotting 2d kernels --- GPy/plotting/matplot_dep/kernel_plots.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/GPy/plotting/matplot_dep/kernel_plots.py b/GPy/plotting/matplot_dep/kernel_plots.py index f2082db0..c0bd1599 100644 --- a/GPy/plotting/matplot_dep/kernel_plots.py +++ b/GPy/plotting/matplot_dep/kernel_plots.py @@ -100,9 +100,7 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False, filtering=Non return ax -def plot(kernel, x=None, plot_limits=None, which_parts='all', resolution=None, *args, **kwargs): - if which_parts == 'all': - which_parts = [True] * kernel.size +def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs): if kernel.input_dim == 1: if x is None: x = np.zeros((1, 1)) @@ -133,7 +131,7 @@ def plot(kernel, x=None, plot_limits=None, which_parts='all', resolution=None, * assert x.size == 2, "The size of the fixed variable x is not 2" x = x.reshape((1, 2)) - if plot_limits == None: + if plot_limits is None: xmin, xmax = (x - 5).flatten(), (x + 5).flatten() elif len(plot_limits) == 2: xmin, xmax = plot_limits @@ -142,12 +140,10 @@ def plot(kernel, x=None, plot_limits=None, which_parts='all', resolution=None, * resolution = resolution or 51 xx, yy = np.mgrid[xmin[0]:xmax[0]:1j * resolution, xmin[1]:xmax[1]:1j * resolution] - xg = np.linspace(xmin[0], xmax[0], resolution) - yg = np.linspace(xmin[1], xmax[1], resolution) Xnew = np.vstack((xx.flatten(), yy.flatten())).T - Kx = kernel.K(Xnew, x, which_parts) + Kx = kernel.K(Xnew, x) Kx = Kx.reshape(resolution, resolution).T - pb.contour(xg, yg, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, *args, **kwargs) # @UndefinedVariable + pb.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, *args, **kwargs) # @UndefinedVariable pb.xlim(xmin[0], xmax[0]) pb.ylim(xmin[1], xmax[1]) pb.xlabel("x1")