diff --git a/GPy/models/mrd.py b/GPy/models/mrd.py index 4e7f2f3b..547f096f 100644 --- a/GPy/models/mrd.py +++ b/GPy/models/mrd.py @@ -236,7 +236,7 @@ class MRD(BayesianGPLVMMiniBatch): # sharex=sharex, sharey=sharey) # return fig - def plot_scales(self, titles=None, fig_kwargs=dict(figsize=None, tight_layout=True), **kwargs): + def plot_scales(self, titles=None, fig_kwargs={}, **kwargs): """ Plot input sensitivity for all datasets, to see which input dimensions are significant for which dataset. @@ -252,12 +252,9 @@ class MRD(BayesianGPLVMMiniBatch): M = len(self.bgplvms) fig = pl().figure(rows=1, cols=M, **fig_kwargs) - plots = {} for c in range(M): canvas = self.bgplvms[c].kern.plot_ARD(title=titles[c], figure=fig, col=c+1, **kwargs) - plots[titles[c]] = canvas - pl().show_canvas(canvas) - return plots + return canvas def plot_latent(self, labels=None, which_indices=None, resolution=60, legend=True,