diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py index 506d32e6..57b2bff5 100644 --- a/GPy/kern/_src/kern.py +++ b/GPy/kern/_src/kern.py @@ -132,13 +132,20 @@ class Kern(Parameterized): """ raise NotImplementedError - def plot(self, *args, **kwargs): + def plot(self, x=None, fignum=None, ax=None, title=None, plot_limits=None, resolution=None, **mpl_kwargs): """ - See GPy.plotting.matplot_dep.plot + plot this kernel. + :param x: the value to use for the other kernel argument (kernels are a function of two variables!) + :param fignum: figure number of the plot + :param ax: matplotlib axis to plot on + :param title: the matplotlib title + :param plot_limits: the range over which to plot the kernel + :resolution: the resolution of the lines used in plotting + :mpl_kwargs avalid keyword arguments to pass through to matplotlib (e.g. lw=7) """ assert "matplotlib" in sys.modules, "matplotlib package has not been imported." from ...plotting.matplot_dep import kernel_plots - kernel_plots.plot(self,*args) + kernel_plots.plot(self, x, fignum, ax, title, plot_limits, resolution, **mpl_kwargs) def plot_ARD(self, *args, **kw): """ diff --git a/GPy/plotting/matplot_dep/kernel_plots.py b/GPy/plotting/matplot_dep/kernel_plots.py index c2bd7d38..1e15f224 100644 --- a/GPy/plotting/matplot_dep/kernel_plots.py +++ b/GPy/plotting/matplot_dep/kernel_plots.py @@ -99,7 +99,26 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False, filtering=Non return ax -def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs): + +def plot(kernel,x=None, fignum=None, ax=None, title=None, plot_limits=None, resolution=None, **mpl_kwargs): + """ + plot a kernel. + :param x: the value to use for the other kernel argument (kernels are a function of two variables!) + :param fignum: figure number of the plot + :param ax: matplotlib axis to plot on + :param title: the matplotlib title + :param plot_limits: the range over which to plot the kernel + :resolution: the resolution of the lines used in plotting + :mpl_kwargs avalid keyword arguments to pass through to matplotlib (e.g. lw=7) + """ + fig, ax = ax_default(fignum,ax) + + if title is None: + ax.set_title('%s kernel' % kernel.name) + else: + ax.set_title(title) + + if kernel.input_dim == 1: if x is None: x = np.zeros((1, 1)) @@ -117,10 +136,10 @@ def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs): Xnew = np.linspace(xmin, xmax, resolution or 201)[:, None] Kx = kernel.K(Xnew, x) - pb.plot(Xnew, Kx, *args, **kwargs) - pb.xlim(xmin, xmax) - pb.xlabel("x") - pb.ylabel("k(x,%0.1f)" % x) + ax.plot(Xnew, Kx, **mpl_kwargs) + ax.set_xlim(xmin, xmax) + ax.set_xlabel("x") + ax.set_ylabel("k(x,%0.1f)" % x) elif kernel.input_dim == 2: if x is None: @@ -142,11 +161,11 @@ def plot(kernel, x=None, plot_limits=None, resolution=None, *args, **kwargs): Xnew = np.vstack((xx.flatten(), yy.flatten())).T Kx = kernel.K(Xnew, x) Kx = Kx.reshape(resolution, resolution).T - pb.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, *args, **kwargs) # @UndefinedVariable - pb.xlim(xmin[0], xmax[0]) - pb.ylim(xmin[1], xmax[1]) - pb.xlabel("x1") - pb.ylabel("x2") - pb.title("k(x1,x2 ; %0.1f,%0.1f)" % (x[0, 0], x[0, 1])) + ax.contour(xx, xx, Kx, vmin=Kx.min(), vmax=Kx.max(), cmap=pb.cm.jet, **mpl_kwargs) # @UndefinedVariable + ax.set_xlim(xmin[0], xmax[0]) + ax.set_ylim(xmin[1], xmax[1]) + ax.set_xlabel("x1") + ax.set_ylabel("x2") + ax.set_title("k(x1,x2 ; %0.1f,%0.1f)" % (x[0, 0], x[0, 1])) else: raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"