diff --git a/GPy/kern/kern.py b/GPy/kern/kern.py index a09a94f3..8db28e0a 100644 --- a/GPy/kern/kern.py +++ b/GPy/kern/kern.py @@ -7,6 +7,7 @@ from ..core.parameterized import Parameterized from parts.kernpart import Kernpart import itertools from parts.prod import Prod as prod +from matplotlib.transforms import offset_copy class kern(Parameterized): def __init__(self, input_dim, parts=[], input_slices=None): @@ -101,6 +102,8 @@ class kern(Parameterized): xticklabels.extend([r"$\mathrm{{{name}}}\ {x}$".format(name=p.name, x=i) for i in np.arange(len(ard_params))]) x0 += len(ard_params) x = np.arange(x0) + transOffset = offset_copy(ax.transData, fig=fig, + x=0., y= -2., units='points') for bar in bars: for patch, num in zip(bar.patches, np.arange(len(bar.patches))): height = patch.get_height() @@ -111,15 +114,19 @@ class kern(Parameterized): if patch.get_extents().height <= t.get_extents().height + 2: va = 'bottom' c = 'k' - ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va) + ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va, transform=transOffset) # for xi, t in zip(x, xticklabels): # ax.text(xi, maxi / 2, t, rotation=90, ha='center', va='center') # ax.set_xticklabels(xticklabels, rotation=17) ax.set_xticks([]) ax.set_xlim(-.5, x0 - .5) if title is '': - ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, - ncol=max(2, len(bars)), mode="expand", borderaxespad=0.) + mode = 'expand' + if len(bars) > 1: + mode = 'expand' + ax.legend(bbox_to_anchor=(0., 1.02, 1., 1.02), loc=3, + ncol=len(bars), mode=mode, borderaxespad=0.) + fig.tight_layout(rect=(0, 0, 1, .9)) else: ax.legend() return ax