mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-06-02 14:45:15 +02:00
[ard] enhanced ard handling and plotting
Conflicts: GPy/kern/_src/linear.py GPy/models/ss_gplvm.py
This commit is contained in:
parent
3972b4bd9a
commit
d000893878
8 changed files with 323 additions and 118 deletions
|
|
@ -30,18 +30,18 @@ def add_bar_labels(fig, ax, bars, bottom=0):
|
|||
c = 'k'
|
||||
transform = transOffsetUp
|
||||
ax.text(xi, height, "${xi}$".format(xi=int(num)), color=c, rotation=0, ha='center', va=va, transform=transform)
|
||||
|
||||
|
||||
ax.set_xticks([])
|
||||
|
||||
|
||||
def plot_bars(fig, ax, x, ard_params, color, name, bottom=0):
|
||||
from ...util.misc import param_to_array
|
||||
return ax.bar(left=x, height=param_to_array(ard_params), width=.8,
|
||||
bottom=bottom, align='center',
|
||||
color=color, edgecolor='k', linewidth=1.2,
|
||||
return ax.bar(left=x, height=param_to_array(ard_params), width=.8,
|
||||
bottom=bottom, align='center',
|
||||
color=color, edgecolor='k', linewidth=1.2,
|
||||
label=name.replace("_"," "))
|
||||
|
||||
def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
|
||||
def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False, filtering=None):
|
||||
"""
|
||||
If an ARD kernel is present, plot a bar representation using matplotlib
|
||||
|
||||
|
|
@ -51,6 +51,10 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
|
|||
title of the plot,
|
||||
pass '' to not print a title
|
||||
pass None for a generic title
|
||||
:param filtering: list of names, which to use for plotting ARD parameters.
|
||||
Only kernels which match names in the list of names in filtering
|
||||
will be used for plotting.
|
||||
:type filtering: list of names to use for ARD plot
|
||||
"""
|
||||
fig, ax = ax_default(fignum,ax)
|
||||
|
||||
|
|
@ -58,19 +62,25 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
|
|||
ax.set_title('ARD parameters, %s kernel' % kernel.name)
|
||||
else:
|
||||
ax.set_title(title)
|
||||
|
||||
|
||||
Tango.reset()
|
||||
bars = []
|
||||
|
||||
ard_params = np.atleast_2d(kernel.input_sensitivity())
|
||||
|
||||
ard_params = np.atleast_2d(kernel.input_sensitivity(summarize=False))
|
||||
bottom = 0
|
||||
x = np.arange(kernel.input_dim)
|
||||
|
||||
|
||||
if order is None:
|
||||
order = kernel.parameter_names(recursive=False)
|
||||
|
||||
for i in range(ard_params.shape[0]):
|
||||
c = Tango.nextMedium()
|
||||
bars.append(plot_bars(fig, ax, x, ard_params[i,:], c, kernel.parameters[i].name, bottom=bottom))
|
||||
bottom += ard_params[i,:]
|
||||
|
||||
if kernel.parameters[i].name in order:
|
||||
c = Tango.nextMedium()
|
||||
bars.append(plot_bars(fig, ax, x, ard_params[i,:], c, kernel.parameters[i].name, bottom=bottom))
|
||||
bottom += ard_params[i,:]
|
||||
else:
|
||||
print "filtering out {}".format(kernel.parameters[i].name)
|
||||
|
||||
ax.set_xlim(-.5, kernel.input_dim - .5)
|
||||
add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[i,:])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue