some improvements to plotting 2d kernels

This commit is contained in:
James Hensman 2014-09-17 12:30:56 +01:00
parent 31478d4d59
commit 48fb604891

View file

@ -100,9 +100,7 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False, filtering=Non
return ax return ax
def plot(kernel, x=None, plot_limits=None, which_parts='all', resolution=None, *args, **kwargs): def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
if which_parts == 'all':
which_parts = [True] * kernel.size
if kernel.input_dim == 1: if kernel.input_dim == 1:
if x is None: if x is None:
x = np.zeros((1, 1)) x = np.zeros((1, 1))
@ -133,7 +131,7 @@ def plot(kernel, x=None, plot_limits=None, which_parts='all', resolution=None, *
assert x.size == 2, "The size of the fixed variable x is not 2" assert x.size == 2, "The size of the fixed variable x is not 2"
x = x.reshape((1, 2)) x = x.reshape((1, 2))
if plot_limits == None: if plot_limits is None:
xmin, xmax = (x - 5).flatten(), (x + 5).flatten() xmin, xmax = (x - 5).flatten(), (x + 5).flatten()
elif len(plot_limits) == 2: elif len(plot_limits) == 2:
xmin, xmax = plot_limits xmin, xmax = plot_limits
@ -142,12 +140,10 @@ def plot(kernel, x=None, plot_limits=None, which_parts='all', resolution=None, *
resolution = resolution or 51 resolution = resolution or 51
xx, yy = np.mgrid[xmin[0]:xmax[0]:1j * resolution, xmin[1]:xmax[1]:1j * resolution] xx, yy = np.mgrid[xmin[0]:xmax[0]:1j * resolution, xmin[1]:xmax[1]:1j * resolution]
xg = np.linspace(xmin[0], xmax[0], resolution)
yg = np.linspace(xmin[1], xmax[1], resolution)
Xnew = np.vstack((xx.flatten(), yy.flatten())).T Xnew = np.vstack((xx.flatten(), yy.flatten())).T
Kx = kernel.K(Xnew, x, which_parts) Kx = kernel.K(Xnew, x)
Kx = Kx.reshape(resolution, resolution).T Kx = Kx.reshape(resolution, resolution).T
pb.contour(xg, yg, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, *args, **kwargs) # @UndefinedVariable pb.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, *args, **kwargs) # @UndefinedVariable
pb.xlim(xmin[0], xmax[0]) pb.xlim(xmin[0], xmax[0])
pb.ylim(xmin[1], xmax[1]) pb.ylim(xmin[1], xmax[1])
pb.xlabel("x1") pb.xlabel("x1")