ard plotting

This commit is contained in:
Max Zwiessele 2014-02-26 08:21:14 +00:00
parent 2f3e0611f8
commit 51dca0fcbc
4 changed files with 41 additions and 23 deletions

View file

@ -196,9 +196,9 @@ class Add(Kern):
kernel_plots.plot(self,*args) kernel_plots.plot(self,*args)
def input_sensitivity(self): def input_sensitivity(self):
in_sen = np.zeros((self.input_dim, self.num_params)) in_sen = np.zeros((self.num_params, self.input_dim))
for i, [p, i_s] in enumerate(zip(self._parameters_, self.input_slices)): for i, [p, i_s] in enumerate(zip(self._parameters_, self.input_slices)):
in_sen[i_s, i] = p.input_sensitivity() in_sen[i, i_s] = p.input_sensitivity()
return in_sen return in_sen
def _getstate(self): def _getstate(self):

View file

@ -62,9 +62,10 @@ class Kern(Parameterized):
raise NotImplementedError raise NotImplementedError
def plot_ARD(self, *args, **kw): def plot_ARD(self, *args, **kw):
if "matplotlib" in sys.modules: """
from ...plotting.matplot_dep import kernel_plots See :class:`~GPy.plotting.matplot_dep.kernel_plots`
self.plot_ARD.__doc__ += kernel_plots.plot_ARD.__doc__ """
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported." assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ...plotting.matplot_dep import kernel_plots from ...plotting.matplot_dep import kernel_plots
return kernel_plots.plot_ARD(self,*args,**kw) return kernel_plots.plot_ARD(self,*args,**kw)

View file

@ -6,27 +6,48 @@ import Tango
import pylab as pb import pylab as pb
import numpy as np import numpy as np
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],axes=None,**kwargs): def ax_default(fignum, ax):
if axes is None: if ax is None:
axes = pb.gca() fig = pb.figure(fignum)
ax = fig.add_subplot(111)
else:
fig = ax.figure
return fig, ax
def meanplot(x, mu, color=Tango.colorsHex['darkBlue'], ax=None, fignum=None, linewidth=2,**kw):
_, axes = ax_default(fignum, ax)
#here's the mean
return axes.plot(x,mu,color=color,linewidth=linewidth,**kw)
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],ax=None,fignum=None,xlabel='x',ylabel='y',**kwargs):
_, axes = ax_default(ax, fignum)
mu = mu.flatten() mu = mu.flatten()
x = x.flatten() x = x.flatten()
lower = lower.flatten() lower = lower.flatten()
upper = upper.flatten() upper = upper.flatten()
plots = []
#here's the mean #here's the mean
axes.plot(x,mu,color=edgecol,linewidth=2) plots.append(meanplot(x, mu, edgecol, axes))
#here's the box #here's the box
kwargs['linewidth']=0.5 kwargs['linewidth']=0.5
if not 'alpha' in kwargs.keys(): if not 'alpha' in kwargs.keys():
kwargs['alpha'] = 0.3 kwargs['alpha'] = 0.3
axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs) plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
#this is the edge: #this is the edge:
axes.plot(x,upper,color=edgecol,linewidth=0.2) plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,axes=axes))
axes.plot(x,lower,color=edgecol,linewidth=0.2) plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,axes=axes))
axes.set_xlabel(xlabel)
axes.set_ylabel(ylabel)
return plots
def removeRightTicks(ax=None): def removeRightTicks(ax=None):
ax = ax or pb.gca() ax = ax or pb.gca()
for i, line in enumerate(ax.get_yticklines()): for i, line in enumerate(ax.get_yticklines()):

View file

@ -6,7 +6,7 @@ import pylab as pb
import Tango import Tango
from matplotlib.textpath import TextPath from matplotlib.textpath import TextPath
from matplotlib.transforms import offset_copy from matplotlib.transforms import offset_copy
from ...kern import Linear from .base_plots import ax_default
@ -52,11 +52,7 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
pass '' to not print a title pass '' to not print a title
pass None for a generic title pass None for a generic title
""" """
if ax is None: fig, ax = ax_default(fignum,ax)
fig = pb.figure(fignum)
ax = fig.add_subplot(111)
else:
fig = ax.figure
if title is None: if title is None:
ax.set_title('ARD parameters, %s kernel' % kernel.name) ax.set_title('ARD parameters, %s kernel' % kernel.name)
@ -70,13 +66,13 @@ def plot_ARD(kernel, fignum=None, ax=None, title='', legend=False):
bottom = 0 bottom = 0
x = np.arange(kernel.input_dim) x = np.arange(kernel.input_dim)
for i in range(ard_params.shape[-1]): for i in range(ard_params.shape[0]):
c = Tango.nextMedium() c = Tango.nextMedium()
bars.append(plot_bars(fig, ax, x, ard_params[:,i], c, kernel._parameters_[i].name, bottom=bottom)) bars.append(plot_bars(fig, ax, x, ard_params[i,:], c, kernel._parameters_[i].name, bottom=bottom))
bottom += ard_params[:,i] bottom += ard_params[i,:]
ax.set_xlim(-.5, kernel.input_dim - .5) ax.set_xlim(-.5, kernel.input_dim - .5)
add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[:,i]) add_bar_labels(fig, ax, [bars[-1]], bottom=bottom-ard_params[i,:])
if legend: if legend:
if title is '': if title is '':