mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 02:24:17 +02:00
input_sensitivity and ard plotting
This commit is contained in:
parent
d90d67a8c1
commit
b200b9fa90
11 changed files with 108 additions and 83 deletions
|
|
@ -9,8 +9,41 @@ from matplotlib.transforms import offset_copy
|
|||
from ...kern import Linear
|
||||
|
||||
|
||||
|
||||
def add_bar_labels(fig, ax, bars, bottom=0):
|
||||
transOffset = offset_copy(ax.transData, fig=fig,
|
||||
x=0., y= -2., units='points')
|
||||
transOffsetUp = offset_copy(ax.transData, fig=fig,
|
||||
x=0., y=1., units='points')
|
||||
for bar in bars:
|
||||
for i, [patch, num] in enumerate(zip(bar.patches, np.arange(len(bar.patches)))):
|
||||
if len(bottom) == len(bar): b = bottom[i]
|
||||
else: b = bottom
|
||||
height = patch.get_height() + b
|
||||
xi = patch.get_x() + patch.get_width() / 2.
|
||||
va = 'top'
|
||||
c = 'w'
|
||||
t = TextPath((0, 0), "${xi}$".format(xi=xi), rotation=0, usetex=True, ha='center')
|
||||
transform = transOffset
|
||||
if patch.get_extents().height <= t.get_extents().height + 3:
|
||||
va = 'bottom'
|
||||
c = 'k'
|
||||
transform = transOffsetUp
|
||||
ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va, transform=transform)
|
||||
|
||||
ax.set_xticks([])
|
||||
|
||||
|
||||
def plot_bars(fig, ax, x, ard_params, color, name, bottom=0):
|
||||
from ...util.misc import param_to_array
|
||||
return ax.bar(left=x, height=param_to_array(ard_params), width=.8,
|
||||
bottom=bottom, align='center',
|
||||
color=color, edgecolor='k', linewidth=1.2,
|
||||
label=name.replace("_"," "))
|
||||
|
||||
def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
|
||||
"""If an ARD kernel is present, plot a bar representation using matplotlib
|
||||
"""
|
||||
If an ARD kernel is present, plot a bar representation using matplotlib
|
||||
|
||||
:param fignum: figure number of the plot
|
||||
:param ax: matplotlib axis to plot on
|
||||
|
|
@ -24,50 +57,27 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
|
|||
ax = fig.add_subplot(111)
|
||||
else:
|
||||
fig = ax.figure
|
||||
|
||||
if title is None:
|
||||
ax.set_title('ARD parameters, %s kernel' % kernel.name)
|
||||
else:
|
||||
ax.set_title(title)
|
||||
|
||||
Tango.reset()
|
||||
xticklabels = []
|
||||
bars = []
|
||||
x0 = 0
|
||||
#for p in kernel._parameters_:
|
||||
p = kernel
|
||||
c = Tango.nextMedium()
|
||||
if hasattr(p, 'ARD') and p.ARD:
|
||||
if title is None:
|
||||
ax.set_title('ARD parameters, %s kernel' % p.name)
|
||||
else:
|
||||
ax.set_title(title)
|
||||
if isinstance(p, Linear):
|
||||
ard_params = p.variances
|
||||
else:
|
||||
ard_params = 1. / p.lengthscale
|
||||
x = np.arange(x0, x0 + len(ard_params))
|
||||
from ...util.misc import param_to_array
|
||||
bars.append(ax.bar(x, param_to_array(ard_params), align='center', color=c, edgecolor='k', linewidth=1.2, label=p.name.replace("_"," ")))
|
||||
xticklabels.extend([r"$\mathrm{{{name}}}\ {x}$".format(name=p.name, x=i) for i in np.arange(len(ard_params))])
|
||||
x0 += len(ard_params)
|
||||
x = np.arange(x0)
|
||||
transOffset = offset_copy(ax.transData, fig=fig,
|
||||
x=0., y= -2., units='points')
|
||||
transOffsetUp = offset_copy(ax.transData, fig=fig,
|
||||
x=0., y=1., units='points')
|
||||
for bar in bars:
|
||||
for patch, num in zip(bar.patches, np.arange(len(bar.patches))):
|
||||
height = patch.get_height()
|
||||
xi = patch.get_x() + patch.get_width() / 2.
|
||||
va = 'top'
|
||||
c = 'w'
|
||||
t = TextPath((0, 0), "${xi}$".format(xi=xi), rotation=0, usetex=True, ha='center')
|
||||
transform = transOffset
|
||||
if patch.get_extents().height <= t.get_extents().height + 3:
|
||||
va = 'bottom'
|
||||
c = 'k'
|
||||
transform = transOffsetUp
|
||||
ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va, transform=transform)
|
||||
# for xi, t in zip(x, xticklabels):
|
||||
# ax.text(xi, maxi / 2, t, rotation=90, ha='center', va='center')
|
||||
# ax.set_xticklabels(xticklabels, rotation=17)
|
||||
ax.set_xticks([])
|
||||
ax.set_xlim(-.5, x0 - .5)
|
||||
|
||||
ard_params = np.atleast_2d(kernel.input_sensitivity())
|
||||
bottom = 0
|
||||
x = np.arange(kernel.input_dim)
|
||||
|
||||
for i in range(ard_params.shape[-1]):
|
||||
c = Tango.nextMedium()
|
||||
bars.append(plot_bars(fig, ax, x, ard_params[:,i], c, kernel._parameters_[i].name, bottom=bottom))
|
||||
bottom += ard_params[:,i]
|
||||
|
||||
ax.set_xlim(-.5, kernel.input_dim - .5)
|
||||
add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[:,i])
|
||||
|
||||
if legend:
|
||||
if title is '':
|
||||
mode = 'expand'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue