mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-07 11:02:38 +02:00
ard plotting
This commit is contained in:
parent
2f3e0611f8
commit
51dca0fcbc
4 changed files with 41 additions and 23 deletions
|
|
@ -196,9 +196,9 @@ class Add(Kern):
|
||||||
kernel_plots.plot(self,*args)
|
kernel_plots.plot(self,*args)
|
||||||
|
|
||||||
def input_sensitivity(self):
|
def input_sensitivity(self):
|
||||||
in_sen = np.zeros((self.input_dim, self.num_params))
|
in_sen = np.zeros((self.num_params, self.input_dim))
|
||||||
for i, [p, i_s] in enumerate(zip(self._parameters_, self.input_slices)):
|
for i, [p, i_s] in enumerate(zip(self._parameters_, self.input_slices)):
|
||||||
in_sen[i_s, i] = p.input_sensitivity()
|
in_sen[i, i_s] = p.input_sensitivity()
|
||||||
return in_sen
|
return in_sen
|
||||||
|
|
||||||
def _getstate(self):
|
def _getstate(self):
|
||||||
|
|
|
||||||
|
|
@ -62,9 +62,10 @@ class Kern(Parameterized):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def plot_ARD(self, *args, **kw):
|
def plot_ARD(self, *args, **kw):
|
||||||
if "matplotlib" in sys.modules:
|
"""
|
||||||
from ...plotting.matplot_dep import kernel_plots
|
See :class:`~GPy.plotting.matplot_dep.kernel_plots`
|
||||||
self.plot_ARD.__doc__ += kernel_plots.plot_ARD.__doc__
|
"""
|
||||||
|
import sys
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
from ...plotting.matplot_dep import kernel_plots
|
from ...plotting.matplot_dep import kernel_plots
|
||||||
return kernel_plots.plot_ARD(self,*args,**kw)
|
return kernel_plots.plot_ARD(self,*args,**kw)
|
||||||
|
|
|
||||||
|
|
@ -6,27 +6,48 @@ import Tango
|
||||||
import pylab as pb
|
import pylab as pb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],axes=None,**kwargs):
|
def ax_default(fignum, ax):
|
||||||
if axes is None:
|
if ax is None:
|
||||||
axes = pb.gca()
|
fig = pb.figure(fignum)
|
||||||
|
ax = fig.add_subplot(111)
|
||||||
|
else:
|
||||||
|
fig = ax.figure
|
||||||
|
return fig, ax
|
||||||
|
|
||||||
|
def meanplot(x, mu, color=Tango.colorsHex['darkBlue'], ax=None, fignum=None, linewidth=2,**kw):
|
||||||
|
_, axes = ax_default(fignum, ax)
|
||||||
|
#here's the mean
|
||||||
|
return axes.plot(x,mu,color=color,linewidth=linewidth,**kw)
|
||||||
|
|
||||||
|
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],ax=None,fignum=None,xlabel='x',ylabel='y',**kwargs):
|
||||||
|
_, axes = ax_default(ax, fignum)
|
||||||
|
|
||||||
mu = mu.flatten()
|
mu = mu.flatten()
|
||||||
x = x.flatten()
|
x = x.flatten()
|
||||||
lower = lower.flatten()
|
lower = lower.flatten()
|
||||||
upper = upper.flatten()
|
upper = upper.flatten()
|
||||||
|
|
||||||
|
plots = []
|
||||||
|
|
||||||
#here's the mean
|
#here's the mean
|
||||||
axes.plot(x,mu,color=edgecol,linewidth=2)
|
plots.append(meanplot(x, mu, edgecol, axes))
|
||||||
|
|
||||||
#here's the box
|
#here's the box
|
||||||
kwargs['linewidth']=0.5
|
kwargs['linewidth']=0.5
|
||||||
if not 'alpha' in kwargs.keys():
|
if not 'alpha' in kwargs.keys():
|
||||||
kwargs['alpha'] = 0.3
|
kwargs['alpha'] = 0.3
|
||||||
axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)
|
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
|
||||||
|
|
||||||
#this is the edge:
|
#this is the edge:
|
||||||
axes.plot(x,upper,color=edgecol,linewidth=0.2)
|
plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,axes=axes))
|
||||||
axes.plot(x,lower,color=edgecol,linewidth=0.2)
|
plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,axes=axes))
|
||||||
|
|
||||||
|
axes.set_xlabel(xlabel)
|
||||||
|
axes.set_ylabel(ylabel)
|
||||||
|
|
||||||
|
return plots
|
||||||
|
|
||||||
|
|
||||||
def removeRightTicks(ax=None):
|
def removeRightTicks(ax=None):
|
||||||
ax = ax or pb.gca()
|
ax = ax or pb.gca()
|
||||||
for i, line in enumerate(ax.get_yticklines()):
|
for i, line in enumerate(ax.get_yticklines()):
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import pylab as pb
|
||||||
import Tango
|
import Tango
|
||||||
from matplotlib.textpath import TextPath
|
from matplotlib.textpath import TextPath
|
||||||
from matplotlib.transforms import offset_copy
|
from matplotlib.transforms import offset_copy
|
||||||
from ...kern import Linear
|
from .base_plots import ax_default
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -52,11 +52,7 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
|
||||||
pass '' to not print a title
|
pass '' to not print a title
|
||||||
pass None for a generic title
|
pass None for a generic title
|
||||||
"""
|
"""
|
||||||
if ax is None:
|
fig, ax = ax_default(fignum,ax)
|
||||||
fig = pb.figure(fignum)
|
|
||||||
ax = fig.add_subplot(111)
|
|
||||||
else:
|
|
||||||
fig = ax.figure
|
|
||||||
|
|
||||||
if title is None:
|
if title is None:
|
||||||
ax.set_title('ARD parameters, %s kernel' % kernel.name)
|
ax.set_title('ARD parameters, %s kernel' % kernel.name)
|
||||||
|
|
@ -70,13 +66,13 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
|
||||||
bottom = 0
|
bottom = 0
|
||||||
x = np.arange(kernel.input_dim)
|
x = np.arange(kernel.input_dim)
|
||||||
|
|
||||||
for i in range(ard_params.shape[-1]):
|
for i in range(ard_params.shape[0]):
|
||||||
c = Tango.nextMedium()
|
c = Tango.nextMedium()
|
||||||
bars.append(plot_bars(fig, ax, x, ard_params[:,i], c, kernel._parameters_[i].name, bottom=bottom))
|
bars.append(plot_bars(fig, ax, x, ard_params[i,:], c, kernel._parameters_[i].name, bottom=bottom))
|
||||||
bottom += ard_params[:,i]
|
bottom += ard_params[i,:]
|
||||||
|
|
||||||
ax.set_xlim(-.5, kernel.input_dim - .5)
|
ax.set_xlim(-.5, kernel.input_dim - .5)
|
||||||
add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[:,i])
|
add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[i,:])
|
||||||
|
|
||||||
if legend:
|
if legend:
|
||||||
if title is '':
|
if title is '':
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue