From 06d540f056c6b1489c904a80683d96055ebc8629 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Wed, 17 Jul 2013 17:49:50 +0100 Subject: [PATCH] plot_ARD greatly improved, crossterm plotting enabled --- GPy/kern/kern.py | 48 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/GPy/kern/kern.py b/GPy/kern/kern.py index 5cd90749..491f9ed7 100644 --- a/GPy/kern/kern.py +++ b/GPy/kern/kern.py @@ -66,12 +66,26 @@ class kern(Parameterized): Parameterized.setstate(self, state) - def plot_ARD(self, fignum=None, ax=None, title=None): - """If an ARD kernel is present, it bar-plots the ARD parameters""" + def plot_ARD(self, fignum=None, ax=None, title='', legend=False): + """If an ARD kernel is present, it bar-plots the ARD parameters, + :param fignum: figure number of the plot + :param ax: matplotlib axis to plot on + :param title: + title of the plot, + pass '' to not print a title + pass None for a generic title + """ if ax is None: fig = pb.figure(fignum) ax = fig.add_subplot(111) + from GPy.util import Tango + from matplotlib.textpath import TextPath + Tango.reset() + xticklabels = [] + bars = [] + x0 = 0 for p in self.parts: + c = Tango.nextMedium() if hasattr(p, 'ARD') and p.ARD: if title is None: ax.set_title('ARD parameters, %s kernel' % p.name) @@ -82,10 +96,32 @@ class kern(Parameterized): else: ard_params = 1. / p.lengthscale - x = np.arange(len(ard_params)) - ax.bar(x - 0.4, ard_params) - ax.set_xticks(x) - ax.set_xticklabels([r"${}$".format(i) for i in x]) + x = np.arange(x0, x0 + len(ard_params)) + bars.append(ax.bar(x, ard_params, align='center', color=c, edgecolor='k', linewidth=1.2, label=p.name)) + xticklabels.extend([r"$\mathrm{{{name}}}\ {x}$".format(name=p.name, x=i) for i in np.arange(len(ard_params))]) + x0 += len(ard_params) + x = np.arange(x0) + for bar in bars: + for patch, num in zip(bar.patches, np.arange(len(bar.patches))): + height = patch.get_height() + xi = patch.get_x() + patch.get_width() / 2. + va = 'top' + c = 'w' + t = TextPath((0, 0), "${xi}$".format(xi=xi), rotation=0, usetex=True, ha='center') + if patch.get_extents().height <= t.get_extents().height + 2: + va = 'bottom' + c = 'k' + ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va) + # for xi, t in zip(x, xticklabels): + # ax.text(xi, maxi / 2, t, rotation=90, ha='center', va='center') + # ax.set_xticklabels(xticklabels, rotation=17) + ax.set_xticks([]) + ax.set_xlim(-.5, x0 - .5) + if title is '': + ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, + ncol=len(bars), mode="expand", borderaxespad=0.) + else: + ax.legend() return ax def _transform_gradients(self, g):