plot_ARD greatly improved, crossterm plotting enabled

This commit is contained in:
Max Zwiessele 2013-07-17 17:49:50 +01:00
parent d2528a811c
commit 06d540f056

View file

@ -66,12 +66,26 @@ class kern(Parameterized):
Parameterized.setstate(self, state)
def plot_ARD(self, fignum=None, ax=None, title=None):
"""If an ARD kernel is present, it bar-plots the ARD parameters"""
def plot_ARD(self, fignum=None, ax=None, title='', legend=False):
"""If an ARD kernel is present, it bar-plots the ARD parameters,
:param fignum: figure number of the plot
:param ax: matplotlib axis to plot on
:param title:
title of the plot,
pass '' to not print a title
pass None for a generic title
"""
if ax is None:
fig = pb.figure(fignum)
ax = fig.add_subplot(111)
from GPy.util import Tango
from matplotlib.textpath import TextPath
Tango.reset()
xticklabels = []
bars = []
x0 = 0
for p in self.parts:
c = Tango.nextMedium()
if hasattr(p, 'ARD') and p.ARD:
if title is None:
ax.set_title('ARD parameters, %s kernel' % p.name)
@ -82,10 +96,32 @@ class kern(Parameterized):
else:
ard_params = 1. / p.lengthscale
x = np.arange(len(ard_params))
ax.bar(x - 0.4, ard_params)
ax.set_xticks(x)
ax.set_xticklabels([r"${}$".format(i) for i in x])
x = np.arange(x0, x0 + len(ard_params))
bars.append(ax.bar(x, ard_params, align='center', color=c, edgecolor='k', linewidth=1.2, label=p.name))
xticklabels.extend([r"$\mathrm{{{name}}}\ {x}$".format(name=p.name, x=i) for i in np.arange(len(ard_params))])
x0 += len(ard_params)
x = np.arange(x0)
for bar in bars:
for patch, num in zip(bar.patches, np.arange(len(bar.patches))):
height = patch.get_height()
xi = patch.get_x() + patch.get_width() / 2.
va = 'top'
c = 'w'
t = TextPath((0, 0), "${xi}$".format(xi=xi), rotation=0, usetex=True, ha='center')
if patch.get_extents().height <= t.get_extents().height + 2:
va = 'bottom'
c = 'k'
ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va)
# for xi, t in zip(x, xticklabels):
# ax.text(xi, maxi / 2, t, rotation=90, ha='center', va='center')
# ax.set_xticklabels(xticklabels, rotation=17)
ax.set_xticks([])
ax.set_xlim(-.5, x0 - .5)
if title is '':
ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
ncol=len(bars), mode="expand", borderaxespad=0.)
else:
ax.legend()
return ax
def _transform_gradients(self, g):