plotting behaviour adapted for kern and mrd

This commit is contained in:
Max Zwiessele 2013-06-04 18:19:14 +01:00
parent a32f9bf9dd
commit cadf822292
2 changed files with 23 additions and 19 deletions

View file

@ -46,10 +46,11 @@ class kern(parameterised):
parameterised.__init__(self) parameterised.__init__(self)
def plot_ARD(self, ax=None): def plot_ARD(self, fignum=None, ax=None):
"""If an ARD kernel is present, it bar-plots the ARD parameters""" """If an ARD kernel is present, it bar-plots the ARD parameters"""
if ax is None: if ax is None:
ax = pb.gca() fig = pb.figure(fignum)
ax = fig.add_subplot(111)
for p in self.parts: for p in self.parts:
if hasattr(p, 'ARD') and p.ARD: if hasattr(p, 'ARD') and p.ARD:
ax.set_title('ARD parameters, %s kernel' % p.name) ax.set_title('ARD parameters, %s kernel' % p.name)

View file

@ -256,17 +256,20 @@ class MRD(model):
self.Z = Z self.Z = Z
return Z return Z
def _handle_plotting(self, fig_num, axes, plotf): def _handle_plotting(self, fignum, ax, plotf):
if axes is None: if ax is None:
fig = pylab.figure(num=fig_num, figsize=(4 * len(self.bgplvms), 3)) fig = pylab.figure(num=fignum)
ax = fig.add_subplot(111)
if ax is None:
fig = pylab.figure(num=fignum, figsize=(4 * len(self.bgplvms), 3))
for i, g in enumerate(self.bgplvms): for i, g in enumerate(self.bgplvms):
if axes is None: if ax is None:
ax = fig.add_subplot(1, len(self.bgplvms), i + 1) ax = fig.add_subplot(1, len(self.bgplvms), i + 1)
else: else:
ax = axes[i] ax = ax[i]
plotf(i, g, ax) plotf(i, g, ax)
pylab.draw() pylab.draw()
if axes is None: if ax is None:
fig.tight_layout() fig.tight_layout()
return fig return fig
else: else:
@ -275,20 +278,20 @@ class MRD(model):
def plot_X_1d(self): def plot_X_1d(self):
return self.gref.plot_X_1d() return self.gref.plot_X_1d()
def plot_X(self, fig_num="MRD Predictions", axes=None): def plot_X(self, fignum="MRD Predictions", ax=None):
fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: ax.imshow(g.X)) fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g.X))
return fig return fig
def plot_predict(self, fig_num="MRD Predictions", axes=None, **kwargs): def plot_predict(self, fignum="MRD Predictions", ax=None, **kwargs):
fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: ax.imshow(g.predict(g.X)[0], **kwargs)) fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g.predict(g.X)[0], **kwargs))
return fig return fig
def plot_scales(self, fig_num="MRD Scales", axes=None, *args, **kwargs): def plot_scales(self, fignum="MRD Scales", ax=None, *args, **kwargs):
fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: g.kern.plot_ARD(ax=ax, *args, **kwargs)) fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.kern.plot_ARD(ax=ax, *args, **kwargs))
return fig return fig
def plot_latent(self, fig_num="MRD Latent Spaces", axes=None, *args, **kwargs): def plot_latent(self, fignum="MRD Latent Spaces", ax=None, *args, **kwargs):
fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs)) fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs))
return fig return fig
def _debug_plot(self): def _debug_plot(self):
@ -296,11 +299,11 @@ class MRD(model):
fig = pylab.figure("MRD DEBUG PLOT", figsize=(4 * len(self.bgplvms), 9)) fig = pylab.figure("MRD DEBUG PLOT", figsize=(4 * len(self.bgplvms), 9))
fig.clf() fig.clf()
axes = [fig.add_subplot(3, len(self.bgplvms), i + 1) for i in range(len(self.bgplvms))] axes = [fig.add_subplot(3, len(self.bgplvms), i + 1) for i in range(len(self.bgplvms))]
self.plot_X(axes=axes) self.plot_X(ax=axes)
axes = [fig.add_subplot(3, len(self.bgplvms), i + len(self.bgplvms) + 1) for i in range(len(self.bgplvms))] axes = [fig.add_subplot(3, len(self.bgplvms), i + len(self.bgplvms) + 1) for i in range(len(self.bgplvms))]
self.plot_latent(axes=axes) self.plot_latent(ax=axes)
axes = [fig.add_subplot(3, len(self.bgplvms), i + 2 * len(self.bgplvms) + 1) for i in range(len(self.bgplvms))] axes = [fig.add_subplot(3, len(self.bgplvms), i + 2 * len(self.bgplvms) + 1) for i in range(len(self.bgplvms))]
self.plot_scales(axes=axes) self.plot_scales(ax=axes)
pylab.draw() pylab.draw()
fig.tight_layout() fig.tight_layout()