ard plotting enhanced

This commit is contained in:
Max Zwiessele 2013-07-18 14:20:32 +01:00
parent 14b8fd0c7d
commit dfb96a8405

View file

@ -7,6 +7,7 @@ from ..core.parameterized import Parameterized
from parts.kernpart import Kernpart from parts.kernpart import Kernpart
import itertools import itertools
from parts.prod import Prod as prod from parts.prod import Prod as prod
from matplotlib.transforms import offset_copy
class kern(Parameterized): class kern(Parameterized):
def __init__(self, input_dim, parts=[], input_slices=None): def __init__(self, input_dim, parts=[], input_slices=None):
@ -101,6 +102,8 @@ class kern(Parameterized):
xticklabels.extend([r"$\mathrm{{{name}}}\ {x}$".format(name=p.name, x=i) for i in np.arange(len(ard_params))]) xticklabels.extend([r"$\mathrm{{{name}}}\ {x}$".format(name=p.name, x=i) for i in np.arange(len(ard_params))])
x0 += len(ard_params) x0 += len(ard_params)
x = np.arange(x0) x = np.arange(x0)
transOffset = offset_copy(ax.transData, fig=fig,
x=0., y= -2., units='points')
for bar in bars: for bar in bars:
for patch, num in zip(bar.patches, np.arange(len(bar.patches))): for patch, num in zip(bar.patches, np.arange(len(bar.patches))):
height = patch.get_height() height = patch.get_height()
@ -111,15 +114,19 @@ class kern(Parameterized):
if patch.get_extents().height <= t.get_extents().height + 2: if patch.get_extents().height <= t.get_extents().height + 2:
va = 'bottom' va = 'bottom'
c = 'k' c = 'k'
ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va) ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va, transform=transOffset)
# for xi, t in zip(x, xticklabels): # for xi, t in zip(x, xticklabels):
# ax.text(xi, maxi / 2, t, rotation=90, ha='center', va='center') # ax.text(xi, maxi / 2, t, rotation=90, ha='center', va='center')
# ax.set_xticklabels(xticklabels, rotation=17) # ax.set_xticklabels(xticklabels, rotation=17)
ax.set_xticks([]) ax.set_xticks([])
ax.set_xlim(-.5, x0 - .5) ax.set_xlim(-.5, x0 - .5)
if title is '': if title is '':
ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, mode = 'expand'
ncol=max(2, len(bars)), mode="expand", borderaxespad=0.) if len(bars) > 1:
mode = 'expand'
ax.legend(bbox_to_anchor=(0., 1.02, 1., 1.02), loc=3,
ncol=len(bars), mode=mode, borderaxespad=0.)
fig.tight_layout(rect=(0, 0, 1, .9))
else: else:
ax.legend() ax.legend()
return ax return ax