mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 09:12:38 +02:00
82 lines
No EOL
3.4 KiB
Python
82 lines
No EOL
3.4 KiB
Python
#===============================================================================
|
|
# Copyright (c) 2015, Max Zwiessele
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# * Neither the name of GPy.plotting.gpy_plot.kernel_plots nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#===============================================================================
|
|
import numpy as np
|
|
from . import pl
|
|
from .. import Tango
|
|
|
|
def plot_ARD(kernel, filtering=None, **kwargs):
|
|
"""
|
|
If an ARD kernel is present, plot a bar representation using matplotlib
|
|
|
|
:param fignum: figure number of the plot
|
|
:param filtering: list of names, which to use for plotting ARD parameters.
|
|
Only kernels which match names in the list of names in filtering
|
|
will be used for plotting.
|
|
:type filtering: list of names to use for ARD plot
|
|
"""
|
|
canvas, kwargs = pl.get_new_canvas(kwargs)
|
|
|
|
Tango.reset()
|
|
|
|
bars = []
|
|
ard_params = np.atleast_2d(kernel.input_sensitivity(summarize=False))
|
|
bottom = 0
|
|
last_bottom = bottom
|
|
|
|
x = np.arange(kernel.input_dim)
|
|
|
|
if filtering is None:
|
|
filtering = kernel.parameter_names(recursive=False)
|
|
|
|
for i in range(ard_params.shape[0]):
|
|
if kernel.parameters[i].name in filtering:
|
|
c = Tango.nextMedium()
|
|
bars.append(pl.barplot(canvas, x, ard_params[i,:], color=c, label=kernel.parameters[i].name, bottom=bottom))
|
|
last_bottom = ard_params[i,:]
|
|
bottom += last_bottom
|
|
else:
|
|
print("filtering out {}".format(kernel.parameters[i].name))
|
|
|
|
plt.show_canvas()
|
|
ax.set_xlim(-.5, kernel.input_dim - .5)
|
|
add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-last_bottom)
|
|
|
|
if legend:
|
|
if title is '':
|
|
mode = 'expand'
|
|
if len(bars) > 1:
|
|
mode = 'expand'
|
|
ax.legend(bbox_to_anchor=(0., 1.02, 1., 1.02), loc=3,
|
|
ncol=len(bars), mode=mode, borderaxespad=0.)
|
|
fig.tight_layout(rect=(0, 0, 1, .9))
|
|
else:
|
|
ax.legend()
|
|
|
|
return dict(barplots=bars) |