[mrd] plot_scales and plot_latent added

This commit is contained in:
Max Zwiessele 2016-03-01 10:02:02 +00:00
parent c4020cd2eb
commit 885d3722cc
8 changed files with 112 additions and 148 deletions

View file

@ -42,10 +42,11 @@ class MatplotlibPlots(AbstractPlottingLibrary):
super(MatplotlibPlots, self).__init__()
self._defaults = defaults.__dict__
def figure(self, rows=1, cols=1, **kwargs):
fig = plt.figure(**kwargs)
def figure(self, rows=1, cols=1, gridspec_kwargs={}, tight_layout=True, **kwargs):
fig = plt.figure(tight_layout=tight_layout, **kwargs)
fig.rows = rows
fig.cols = cols
fig.gridspec = plt.GridSpec(rows, cols, **gridspec_kwargs)
return fig
def new_canvas(self, figure=None, row=1, col=1, projection='2d', xlabel=None, ylabel=None, zlabel=None, title=None, xlim=None, ylim=None, zlim=None, **kwargs):
@ -56,7 +57,9 @@ class MatplotlibPlots(AbstractPlottingLibrary):
if 'ax' in kwargs:
ax = kwargs.pop('ax')
else:
if 'num' in kwargs and 'figsize' in kwargs:
if figure is not None:
fig = figure
elif 'num' in kwargs and 'figsize' in kwargs:
fig = self.figure(num=kwargs.pop('num'), figsize=kwargs.pop('figsize'))
elif 'num' in kwargs:
fig = self.figure(num=kwargs.pop('num'))
@ -66,7 +69,7 @@ class MatplotlibPlots(AbstractPlottingLibrary):
fig = self.figure()
#if hasattr(fig, 'rows') and hasattr(fig, 'cols'):
ax = fig.add_subplot(fig.rows, fig.cols, (col,row), projection=projection)
ax = fig.add_subplot(fig.gridspec[row-1, col-1], projection=projection)
if xlim is not None: ax.set_xlim(xlim)
if ylim is not None: ax.set_ylim(ylim)
@ -79,7 +82,7 @@ class MatplotlibPlots(AbstractPlottingLibrary):
return ax, kwargs
def add_to_canvas(self, ax, plots, legend=False, title=None, **kwargs):
ax.autoscale_view()
#ax.autoscale_view()
fontdict=dict(family='sans-serif', weight='light', size=9)
if legend is True:
ax.legend(*ax.get_legend_handles_labels())
@ -89,9 +92,7 @@ class MatplotlibPlots(AbstractPlottingLibrary):
if title is not None: ax.figure.suptitle(title)
return ax
def show_canvas(self, ax, tight_layout=False, **kwargs):
if tight_layout:
ax.figure.tight_layout()
def show_canvas(self, ax):
ax.figure.canvas.draw()
return ax.figure

View file

@ -13,16 +13,16 @@ class SSGPLVM_plot(object):
self.model = model
self.imgsize= imgsize
assert model.Y.shape[1] == imgsize[0]*imgsize[1]
def plot_inducing(self):
fig1 = pylab.figure()
mean = self.model.posterior.mean
arr = mean.reshape(*(mean.shape[0],self.imgsize[1],self.imgsize[0]))
plot_2D_images(fig1, arr)
fig1.gca().set_title('The mean of inducing points')
fig2 = pylab.figure()
covar = self.model.posterior.covariance
plot_2D_images(fig2, covar)
fig2.gca().set_title('The variance of inducing points')