mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 11:02:38 +02:00
improved kernel plotting
This commit is contained in:
parent
4905348dbe
commit
304db40f5b
2 changed files with 40 additions and 14 deletions
|
|
@ -132,13 +132,20 @@ class Kern(Parameterized):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def plot(self, *args, **kwargs):
|
def plot(self, x=None, fignum=None, ax=None, title=None, plot_limits=None, resolution=None, **mpl_kwargs):
|
||||||
"""
|
"""
|
||||||
See GPy.plotting.matplot_dep.plot
|
plot this kernel.
|
||||||
|
:param x: the value to use for the other kernel argument (kernels are a function of two variables!)
|
||||||
|
:param fignum: figure number of the plot
|
||||||
|
:param ax: matplotlib axis to plot on
|
||||||
|
:param title: the matplotlib title
|
||||||
|
:param plot_limits: the range over which to plot the kernel
|
||||||
|
:resolution: the resolution of the lines used in plotting
|
||||||
|
:mpl_kwargs avalid keyword arguments to pass through to matplotlib (e.g. lw=7)
|
||||||
"""
|
"""
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
from ...plotting.matplot_dep import kernel_plots
|
from ...plotting.matplot_dep import kernel_plots
|
||||||
kernel_plots.plot(self,*args)
|
kernel_plots.plot(self, x, fignum, ax, title, plot_limits, resolution, **mpl_kwargs)
|
||||||
|
|
||||||
def plot_ARD(self, *args, **kw):
|
def plot_ARD(self, *args, **kw):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,26 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False, filtering=Non
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
|
|
||||||
|
def plot(kernel,x=None, fignum=None, ax=None, title=None, plot_limits=None, resolution=None, **mpl_kwargs):
|
||||||
|
"""
|
||||||
|
plot a kernel.
|
||||||
|
:param x: the value to use for the other kernel argument (kernels are a function of two variables!)
|
||||||
|
:param fignum: figure number of the plot
|
||||||
|
:param ax: matplotlib axis to plot on
|
||||||
|
:param title: the matplotlib title
|
||||||
|
:param plot_limits: the range over which to plot the kernel
|
||||||
|
:resolution: the resolution of the lines used in plotting
|
||||||
|
:mpl_kwargs avalid keyword arguments to pass through to matplotlib (e.g. lw=7)
|
||||||
|
"""
|
||||||
|
fig, ax = ax_default(fignum,ax)
|
||||||
|
|
||||||
|
if title is None:
|
||||||
|
ax.set_title('%s kernel' % kernel.name)
|
||||||
|
else:
|
||||||
|
ax.set_title(title)
|
||||||
|
|
||||||
|
|
||||||
if kernel.input_dim == 1:
|
if kernel.input_dim == 1:
|
||||||
if x is None:
|
if x is None:
|
||||||
x = np.zeros((1, 1))
|
x = np.zeros((1, 1))
|
||||||
|
|
@ -117,10 +136,10 @@ def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
|
||||||
|
|
||||||
Xnew = np.linspace(xmin, xmax, resolution or 201)[:, None]
|
Xnew = np.linspace(xmin, xmax, resolution or 201)[:, None]
|
||||||
Kx = kernel.K(Xnew, x)
|
Kx = kernel.K(Xnew, x)
|
||||||
pb.plot(Xnew, Kx, *args, **kwargs)
|
ax.plot(Xnew, Kx, **mpl_kwargs)
|
||||||
pb.xlim(xmin, xmax)
|
ax.set_xlim(xmin, xmax)
|
||||||
pb.xlabel("x")
|
ax.set_xlabel("x")
|
||||||
pb.ylabel("k(x,%0.1f)" % x)
|
ax.set_ylabel("k(x,%0.1f)" % x)
|
||||||
|
|
||||||
elif kernel.input_dim == 2:
|
elif kernel.input_dim == 2:
|
||||||
if x is None:
|
if x is None:
|
||||||
|
|
@ -142,11 +161,11 @@ def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs):
|
||||||
Xnew = np.vstack((xx.flatten(), yy.flatten())).T
|
Xnew = np.vstack((xx.flatten(), yy.flatten())).T
|
||||||
Kx = kernel.K(Xnew, x)
|
Kx = kernel.K(Xnew, x)
|
||||||
Kx = Kx.reshape(resolution, resolution).T
|
Kx = Kx.reshape(resolution, resolution).T
|
||||||
pb.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, *args, **kwargs) # @UndefinedVariable
|
ax.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, **mpl_kwargs) # @UndefinedVariable
|
||||||
pb.xlim(xmin[0], xmax[0])
|
ax.set_xlim(xmin[0], xmax[0])
|
||||||
pb.ylim(xmin[1], xmax[1])
|
ax.set_ylim(xmin[1], xmax[1])
|
||||||
pb.xlabel("x1")
|
ax.set_xlabel("x1")
|
||||||
pb.ylabel("x2")
|
ax.set_ylabel("x2")
|
||||||
pb.title("k(x1,x2 ; %0.1f,%0.1f)" % (x[0, 0], x[0, 1]))
|
ax.set_title("k(x1,x2 ; %0.1f,%0.1f)" % (x[0, 0], x[0, 1]))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"
|
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue