From 51dca0fcbce9e9dec050ab6200d00ea2287525d0 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Wed, 26 Feb 2014 08:21:14 +0000 Subject: [PATCH] ard plotting --- GPy/kern/_src/add.py | 4 +-- GPy/kern/_src/kern.py | 7 +++-- GPy/plotting/matplot_dep/base_plots.py | 37 +++++++++++++++++++----- GPy/plotting/matplot_dep/kernel_plots.py | 16 ++++------ 4 files changed, 41 insertions(+), 23 deletions(-) diff --git a/GPy/kern/_src/add.py b/GPy/kern/_src/add.py index 45800dbf..a91a9c9e 100644 --- a/GPy/kern/_src/add.py +++ b/GPy/kern/_src/add.py @@ -196,9 +196,9 @@ class Add(Kern): kernel_plots.plot(self,*args) def input_sensitivity(self): - in_sen = np.zeros((self.input_dim, self.num_params)) + in_sen = np.zeros((self.num_params, self.input_dim)) for i, [p, i_s] in enumerate(zip(self._parameters_, self.input_slices)): - in_sen[i_s, i] = p.input_sensitivity() + in_sen[i, i_s] = p.input_sensitivity() return in_sen def _getstate(self): diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py index 1eec7af5..9f7dcb0a 100644 --- a/GPy/kern/_src/kern.py +++ b/GPy/kern/_src/kern.py @@ -62,9 +62,10 @@ class Kern(Parameterized): raise NotImplementedError def plot_ARD(self, *args, **kw): - if "matplotlib" in sys.modules: - from ...plotting.matplot_dep import kernel_plots - self.plot_ARD.__doc__ += kernel_plots.plot_ARD.__doc__ + """ + See :class:`~GPy.plotting.matplot_dep.kernel_plots` + """ + import sys assert "matplotlib" in sys.modules, "matplotlib package has not been imported." from ...plotting.matplot_dep import kernel_plots return kernel_plots.plot_ARD(self,*args,**kw) diff --git a/GPy/plotting/matplot_dep/base_plots.py b/GPy/plotting/matplot_dep/base_plots.py index d5d4d6ee..a9d25223 100644 --- a/GPy/plotting/matplot_dep/base_plots.py +++ b/GPy/plotting/matplot_dep/base_plots.py @@ -6,27 +6,48 @@ import Tango import pylab as pb import numpy as np -def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],axes=None,**kwargs): - if axes is None: - axes = pb.gca() +def ax_default(fignum, ax): + if ax is None: + fig = pb.figure(fignum) + ax = fig.add_subplot(111) + else: + fig = ax.figure + return fig, ax + +def meanplot(x, mu, color=Tango.colorsHex['darkBlue'], ax=None, fignum=None, linewidth=2,**kw): + _, axes = ax_default(fignum, ax) + #here's the mean + return axes.plot(x,mu,color=color,linewidth=linewidth,**kw) + +def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],ax=None,fignum=None,xlabel='x',ylabel='y',**kwargs): + _, axes = ax_default(ax, fignum) + mu = mu.flatten() x = x.flatten() lower = lower.flatten() upper = upper.flatten() + plots = [] + #here's the mean - axes.plot(x,mu,color=edgecol,linewidth=2) + plots.append(meanplot(x, mu, edgecol, axes)) #here's the box kwargs['linewidth']=0.5 if not 'alpha' in kwargs.keys(): kwargs['alpha'] = 0.3 - axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs) + plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)) #this is the edge: - axes.plot(x,upper,color=edgecol,linewidth=0.2) - axes.plot(x,lower,color=edgecol,linewidth=0.2) - + plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,axes=axes)) + plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,axes=axes)) + + axes.set_xlabel(xlabel) + axes.set_ylabel(ylabel) + + return plots + + def removeRightTicks(ax=None): ax = ax or pb.gca() for i, line in enumerate(ax.get_yticklines()): diff --git a/GPy/plotting/matplot_dep/kernel_plots.py b/GPy/plotting/matplot_dep/kernel_plots.py index 6d4a7f0f..b55a0e53 100644 --- a/GPy/plotting/matplot_dep/kernel_plots.py +++ b/GPy/plotting/matplot_dep/kernel_plots.py @@ -6,7 +6,7 @@ import pylab as pb import Tango from matplotlib.textpath import TextPath from matplotlib.transforms import offset_copy -from ...kern import Linear +from .base_plots import ax_default @@ -52,11 +52,7 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False): pass '' to not print a title pass None for a generic title """ - if ax is None: - fig = pb.figure(fignum) - ax = fig.add_subplot(111) - else: - fig = ax.figure + fig, ax = ax_default(fignum,ax) if title is None: ax.set_title('ARD parameters, %s kernel' % kernel.name) @@ -70,13 +66,13 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False): bottom = 0 x = np.arange(kernel.input_dim) - for i in range(ard_params.shape[-1]): + for i in range(ard_params.shape[0]): c = Tango.nextMedium() - bars.append(plot_bars(fig, ax, x, ard_params[:,i], c, kernel._parameters_[i].name, bottom=bottom)) - bottom += ard_params[:,i] + bars.append(plot_bars(fig, ax, x, ard_params[i,:], c, kernel._parameters_[i].name, bottom=bottom)) + bottom += ard_params[i,:] ax.set_xlim(-.5, kernel.input_dim - .5) - add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[:,i]) + add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[i,:]) if legend: if title is '':