improved kernel plotting

This commit is contained in:
James Hensman 2014-11-03 14:03:37 +00:00
parent 4905348dbe
commit 304db40f5b
2 changed files with 40 additions and 14 deletions

View file

@ -132,13 +132,20 @@ class Kern(Parameterized):
""" """
raise NotImplementedError raise NotImplementedError
def plot(self, *args, **kwargs): def plot(self, x=None, fignum=None, ax=None, title=None, plot_limits=None, resolution=None, **mpl_kwargs):
""" """
See GPy.plotting.matplot_dep.plot plot this kernel.
:param x: the value to use for the other kernel argument (kernels are a function of two variables!)
:param fignum: figure number of the plot
:param ax: matplotlib axis to plot on
:param title: the matplotlib title
:param plot_limits: the range over which to plot the kernel
:resolution: the resolution of the lines used in plotting
:mpl_kwargs avalid keyword arguments to pass through to matplotlib (e.g. lw=7)
""" """
assert "matplotlib" in sys.modules, "matplotlib package has not been imported." assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ...plotting.matplot_dep import kernel_plots from ...plotting.matplot_dep import kernel_plots
kernel_plots.plot(self,*args) kernel_plots.plot(self, x, fignum, ax, title, plot_limits, resolution, **mpl_kwargs)
def plot_ARD(self, *args, **kw): def plot_ARD(self, *args, **kw):
""" """

View file

@ -99,7 +99,26 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False, filtering=Non
return ax return ax
def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
def plot(kernel,x=None, fignum=None, ax=None, title=None, plot_limits=None, resolution=None, **mpl_kwargs):
"""
plot a kernel.
:param x: the value to use for the other kernel argument (kernels are a function of two variables!)
:param fignum: figure number of the plot
:param ax: matplotlib axis to plot on
:param title: the matplotlib title
:param plot_limits: the range over which to plot the kernel
:resolution: the resolution of the lines used in plotting
:mpl_kwargs avalid keyword arguments to pass through to matplotlib (e.g. lw=7)
"""
fig, ax = ax_default(fignum,ax)
if title is None:
ax.set_title('%s kernel' % kernel.name)
else:
ax.set_title(title)
if kernel.input_dim == 1: if kernel.input_dim == 1:
if x is None: if x is None:
x = np.zeros((1, 1)) x = np.zeros((1, 1))
@ -117,10 +136,10 @@ def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
Xnew = np.linspace(xmin, xmax, resolution or 201)[:, None] Xnew = np.linspace(xmin, xmax, resolution or 201)[:, None]
Kx = kernel.K(Xnew, x) Kx = kernel.K(Xnew, x)
pb.plot(Xnew, Kx, *args, **kwargs) ax.plot(Xnew, Kx, **mpl_kwargs)
pb.xlim(xmin, xmax) ax.set_xlim(xmin, xmax)
pb.xlabel("x") ax.set_xlabel("x")
pb.ylabel("k(x,%0.1f)" % x) ax.set_ylabel("k(x,%0.1f)" % x)
elif kernel.input_dim == 2: elif kernel.input_dim == 2:
if x is None: if x is None:
@ -142,11 +161,11 @@ def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
Xnew = np.vstack((xx.flatten(), yy.flatten())).T Xnew = np.vstack((xx.flatten(), yy.flatten())).T
Kx = kernel.K(Xnew, x) Kx = kernel.K(Xnew, x)
Kx = Kx.reshape(resolution, resolution).T Kx = Kx.reshape(resolution, resolution).T
pb.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, *args, **kwargs) # @UndefinedVariable ax.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, **mpl_kwargs) # @UndefinedVariable
pb.xlim(xmin[0], xmax[0]) ax.set_xlim(xmin[0], xmax[0])
pb.ylim(xmin[1], xmax[1]) ax.set_ylim(xmin[1], xmax[1])
pb.xlabel("x1") ax.set_xlabel("x1")
pb.ylabel("x2") ax.set_ylabel("x2")
pb.title("k(x1,x2 ; %0.1f,%0.1f)" % (x[0, 0], x[0, 1])) ax.set_title("k(x1,x2 ; %0.1f,%0.1f)" % (x[0, 0], x[0, 1]))
else: else:
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions" raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"