[ard] enhanced ard handling and plotting

Conflicts:
	GPy/kern/_src/linear.py
	GPy/models/ss_gplvm.py
This commit is contained in:
mzwiessele 2014-08-25 09:46:20 -07:00
parent 3972b4bd9a
commit d000893878
8 changed files with 323 additions and 118 deletions

View file

@ -30,18 +30,18 @@ def add_bar_labels(fig, ax, bars, bottom=0):
c = 'k'
transform = transOffsetUp
ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va, transform=transform)
ax.set_xticks([])
def plot_bars(fig, ax, x, ard_params, color, name, bottom=0):
from ...util.misc import param_to_array
return ax.bar(left=x, height=param_to_array(ard_params), width=.8,
bottom=bottom, align='center',
color=color, edgecolor='k', linewidth=1.2,
return ax.bar(left=x, height=param_to_array(ard_params), width=.8,
bottom=bottom, align='center',
color=color, edgecolor='k', linewidth=1.2,
label=name.replace("_"," "))
def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False, filtering=None):
"""
If an ARD kernel is present, plot a bar representation using matplotlib
@ -51,6 +51,10 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
title of the plot,
pass '' to not print a title
pass None for a generic title
:param filtering: list of names, which to use for plotting ARD parameters.
Only kernels which match names in the list of names in filtering
will be used for plotting.
:type filtering: list of names to use for ARD plot
"""
fig, ax = ax_default(fignum,ax)
@ -58,19 +62,25 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
ax.set_title('ARD parameters, %s kernel' % kernel.name)
else:
ax.set_title(title)
Tango.reset()
bars = []
ard_params = np.atleast_2d(kernel.input_sensitivity())
ard_params = np.atleast_2d(kernel.input_sensitivity(summarize=False))
bottom = 0
x = np.arange(kernel.input_dim)
if order is None:
order = kernel.parameter_names(recursive=False)
for i in range(ard_params.shape[0]):
c = Tango.nextMedium()
bars.append(plot_bars(fig, ax, x, ard_params[i,:], c, kernel.parameters[i].name, bottom=bottom))
bottom += ard_params[i,:]
if kernel.parameters[i].name in order:
c = Tango.nextMedium()
bars.append(plot_bars(fig, ax, x, ard_params[i,:], c, kernel.parameters[i].name, bottom=bottom))
bottom += ard_params[i,:]
else:
print "filtering out {}".format(kernel.parameters[i].name)
ax.set_xlim(-.5, kernel.input_dim - .5)
add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[i,:])