improved kernel plotting

This commit is contained in:
James Hensman 2014-11-03 14:03:37 +00:00
parent 4905348dbe
commit 304db40f5b
2 changed files with 40 additions and 14 deletions

View file

@ -99,7 +99,26 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False, filtering=Non
return ax
def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
def plot(kernel,x=None, fignum=None, ax=None, title=None, plot_limits=None, resolution=None, **mpl_kwargs):
"""
plot a kernel.
:param x: the value to use for the other kernel argument (kernels are a function of two variables!)
:param fignum: figure number of the plot
:param ax: matplotlib axis to plot on
:param title: the matplotlib title
:param plot_limits: the range over which to plot the kernel
:resolution: the resolution of the lines used in plotting
:mpl_kwargs avalid keyword arguments to pass through to matplotlib (e.g. lw=7)
"""
fig, ax = ax_default(fignum,ax)
if title is None:
ax.set_title('%s kernel' % kernel.name)
else:
ax.set_title(title)
if kernel.input_dim == 1:
if x is None:
x = np.zeros((1, 1))
@ -117,10 +136,10 @@ def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
Xnew = np.linspace(xmin, xmax, resolution or 201)[:, None]
Kx = kernel.K(Xnew, x)
pb.plot(Xnew, Kx, *args, **kwargs)
pb.xlim(xmin, xmax)
pb.xlabel("x")
pb.ylabel("k(x,%0.1f)" % x)
ax.plot(Xnew, Kx, **mpl_kwargs)
ax.set_xlim(xmin, xmax)
ax.set_xlabel("x")
ax.set_ylabel("k(x,%0.1f)" % x)
elif kernel.input_dim == 2:
if x is None:
@ -142,11 +161,11 @@ def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
Xnew = np.vstack((xx.flatten(), yy.flatten())).T
Kx = kernel.K(Xnew, x)
Kx = Kx.reshape(resolution, resolution).T
pb.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, *args, **kwargs) # @UndefinedVariable
pb.xlim(xmin[0], xmax[0])
pb.ylim(xmin[1], xmax[1])
pb.xlabel("x1")
pb.ylabel("x2")
pb.title("k(x1,x2 ; %0.1f,%0.1f)" % (x[0, 0], x[0, 1]))
ax.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, **mpl_kwargs) # @UndefinedVariable
ax.set_xlim(xmin[0], xmax[0])
ax.set_ylim(xmin[1], xmax[1])
ax.set_xlabel("x1")
ax.set_ylabel("x2")
ax.set_title("k(x1,x2 ; %0.1f,%0.1f)" % (x[0, 0], x[0, 1]))
else:
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"