mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-14 14:32:37 +02:00
ard plotting enhanced
This commit is contained in:
parent
14b8fd0c7d
commit
dfb96a8405
1 changed files with 10 additions and 3 deletions
|
|
@ -7,6 +7,7 @@ from ..core.parameterized import Parameterized
|
||||||
from parts.kernpart import Kernpart
|
from parts.kernpart import Kernpart
|
||||||
import itertools
|
import itertools
|
||||||
from parts.prod import Prod as prod
|
from parts.prod import Prod as prod
|
||||||
|
from matplotlib.transforms import offset_copy
|
||||||
|
|
||||||
class kern(Parameterized):
|
class kern(Parameterized):
|
||||||
def __init__(self, input_dim, parts=[], input_slices=None):
|
def __init__(self, input_dim, parts=[], input_slices=None):
|
||||||
|
|
@ -101,6 +102,8 @@ class kern(Parameterized):
|
||||||
xticklabels.extend([r"$\mathrm{{{name}}}\ {x}$".format(name=p.name, x=i) for i in np.arange(len(ard_params))])
|
xticklabels.extend([r"$\mathrm{{{name}}}\ {x}$".format(name=p.name, x=i) for i in np.arange(len(ard_params))])
|
||||||
x0 += len(ard_params)
|
x0 += len(ard_params)
|
||||||
x = np.arange(x0)
|
x = np.arange(x0)
|
||||||
|
transOffset = offset_copy(ax.transData, fig=fig,
|
||||||
|
x=0., y= -2., units='points')
|
||||||
for bar in bars:
|
for bar in bars:
|
||||||
for patch, num in zip(bar.patches, np.arange(len(bar.patches))):
|
for patch, num in zip(bar.patches, np.arange(len(bar.patches))):
|
||||||
height = patch.get_height()
|
height = patch.get_height()
|
||||||
|
|
@ -111,15 +114,19 @@ class kern(Parameterized):
|
||||||
if patch.get_extents().height <= t.get_extents().height + 2:
|
if patch.get_extents().height <= t.get_extents().height + 2:
|
||||||
va = 'bottom'
|
va = 'bottom'
|
||||||
c = 'k'
|
c = 'k'
|
||||||
ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va)
|
ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va, transform=transOffset)
|
||||||
# for xi, t in zip(x, xticklabels):
|
# for xi, t in zip(x, xticklabels):
|
||||||
# ax.text(xi, maxi / 2, t, rotation=90, ha='center', va='center')
|
# ax.text(xi, maxi / 2, t, rotation=90, ha='center', va='center')
|
||||||
# ax.set_xticklabels(xticklabels, rotation=17)
|
# ax.set_xticklabels(xticklabels, rotation=17)
|
||||||
ax.set_xticks([])
|
ax.set_xticks([])
|
||||||
ax.set_xlim(-.5, x0 - .5)
|
ax.set_xlim(-.5, x0 - .5)
|
||||||
if title is '':
|
if title is '':
|
||||||
ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
|
mode = 'expand'
|
||||||
ncol=max(2, len(bars)), mode="expand", borderaxespad=0.)
|
if len(bars) > 1:
|
||||||
|
mode = 'expand'
|
||||||
|
ax.legend(bbox_to_anchor=(0., 1.02, 1., 1.02), loc=3,
|
||||||
|
ncol=len(bars), mode=mode, borderaxespad=0.)
|
||||||
|
fig.tight_layout(rect=(0, 0, 1, .9))
|
||||||
else:
|
else:
|
||||||
ax.legend()
|
ax.legend()
|
||||||
return ax
|
return ax
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue