From 75f4e26b23d8bf50d432a4397ab6b1716579cd7c Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Tue, 4 Jun 2013 18:09:02 +0100 Subject: [PATCH 1/5] added priors behaviour as intended and issue #38 closed and fixed --- GPy/core/priors.py | 1 + GPy/inference/SCG.py | 2 +- GPy/inference/SGD.py | 8 ++++---- GPy/likelihoods/Gaussian.py | 8 ++++---- GPy/models/Bayesian_GPLVM.py | 13 ++++++------- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/GPy/core/priors.py b/GPy/core/priors.py index b86024d0..7b6379de 100644 --- a/GPy/core/priors.py +++ b/GPy/core/priors.py @@ -136,6 +136,7 @@ def gamma_from_EV(E, V): warnings.warn("use Gamma.from_EV to create Gamma Prior", FutureWarning) return Gamma.from_EV(E, V) + class Gamma(Prior): """ Implementation of the Gamma probability function, coupled with random variables. diff --git a/GPy/inference/SCG.py b/GPy/inference/SCG.py index 1cec0baa..5753be7f 100644 --- a/GPy/inference/SCG.py +++ b/GPy/inference/SCG.py @@ -63,7 +63,7 @@ def SCG(f, gradf, x, optargs=(), maxiters=500, max_f_eval=500, display=True, xto success = True # Force calculation of directional derivs. nsuccess = 0 # nsuccess counts number of successes. beta = 1.0 # Initial scale parameter. - betamin = 1.0e-15 # Lower bound on scale. + betamin = 1.0e-60 # Lower bound on scale. betamax = 1.0e100 # Upper bound on scale. 
status = "Not converged" diff --git a/GPy/inference/SGD.py b/GPy/inference/SGD.py index 7cb7566f..7543a819 100644 --- a/GPy/inference/SGD.py +++ b/GPy/inference/SGD.py @@ -192,7 +192,7 @@ class opt_SGD(Optimizer): if self.model.N == 0 or Y.std() == 0.0: return 0, step, self.model.N - self.model.likelihood._bias = Y.mean() + self.model.likelihood._offset = Y.mean() self.model.likelihood._scale = Y.std() self.model.likelihood.set_data(Y) # self.model.likelihood.V = self.model.likelihood.Y*self.model.likelihood.precision @@ -219,9 +219,9 @@ class opt_SGD(Optimizer): self.restore_constraints(ci) self.model.grads[j] = fp - # restore likelihood _bias and _scale, otherwise when we call set_data(y) on + # restore likelihood _offset and _scale, otherwise when we call set_data(y) on # the next feature, it will get normalized with the mean and std of this one. - self.model.likelihood._bias = 0 + self.model.likelihood._offset = 0 self.model.likelihood._scale = 1 return f, step, self.model.N @@ -266,7 +266,7 @@ class opt_SGD(Optimizer): self.model.likelihood.YYT = 0 self.model.likelihood.trYYT = 0 - self.model.likelihood._bias = 0.0 + self.model.likelihood._offset = 0.0 self.model.likelihood._scale = 1.0 N, Q = self.model.X.shape diff --git a/GPy/likelihoods/Gaussian.py b/GPy/likelihoods/Gaussian.py index d87b1b98..6c3c4edf 100644 --- a/GPy/likelihoods/Gaussian.py +++ b/GPy/likelihoods/Gaussian.py @@ -19,12 +19,12 @@ class Gaussian(likelihood): # normalization if normalize: - self._bias = data.mean(0)[None, :] + self._offset = data.mean(0)[None, :] self._scale = data.std(0)[None, :] # Don't scale outputs which have zero variance to zero. 
self._scale[np.nonzero(self._scale == 0.)] = 1.0e-3 else: - self._bias = np.zeros((1, self.D)) + self._offset = np.zeros((1, self.D)) self._scale = np.ones((1, self.D)) self.set_data(data) @@ -36,7 +36,7 @@ class Gaussian(likelihood): self.data = data self.N, D = data.shape assert D == self.D - self.Y = (self.data - self._bias) / self._scale + self.Y = (self.data - self._offset) / self._scale if D > self.N: self.YYT = np.dot(self.Y, self.Y.T) self.trYYT = np.trace(self.YYT) @@ -66,7 +66,7 @@ class Gaussian(likelihood): """ Un-normalize the prediction and add the likelihood variance, then return the 5%, 95% interval """ - mean = mu * self._scale + self._bias + mean = mu * self._scale + self._offset if full_cov: if self.D > 1: raise NotImplementedError, "TODO" diff --git a/GPy/models/Bayesian_GPLVM.py b/GPy/models/Bayesian_GPLVM.py index dcef5291..0b0797a5 100644 --- a/GPy/models/Bayesian_GPLVM.py +++ b/GPy/models/Bayesian_GPLVM.py @@ -218,20 +218,19 @@ class Bayesian_GPLVM(sparse_GP, GPLVM): return means, covars - def plot_X_1d(self, fig=None, axes=None, fig_num="LVM mu S 1d", colors=None): + def plot_X_1d(self, ax=None, fignum=None, colors=None): """ Plot latent space X in 1D: -if fig is given, create Q subplots in fig and plot in these - -if axes is given plot Q 1D latent space plots of X into each `axis` - -if neither fig nor axes is given create a figure with fig_num and plot in there + -if ax is given plot Q 1D latent space plots of X into each `axis` + -if neither fig nor ax is given create a figure with fignum and plot in there colors: colors of different latent space dimensions Q """ import pylab - if fig is None and axes is None: - fig = pylab.figure(num=fig_num, figsize=(8, min(12, (2 * self.X.shape[1])))) + fig = pylab.figure(num=fignum, figsize=(8, min(12, (2 * self.X.shape[1])))) if colors is None: colors = pylab.gca()._get_lines.color_cycle pylab.clf() @@ -240,10 +239,10 @@ class Bayesian_GPLVM(sparse_GP, GPLVM): plots = [] x = 
np.arange(self.X.shape[0]) for i in range(self.X.shape[1]): - if axes is None: + if ax is None: ax = fig.add_subplot(self.X.shape[1], 1, i + 1) else: - ax = axes[i] + ax = ax[i] ax.plot(self.X, c='k', alpha=.3) plots.extend(ax.plot(x, self.X.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i))) ax.fill_between(x, From cadf822292839b3a1635e49e474df108e579132b Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Tue, 4 Jun 2013 18:19:14 +0100 Subject: [PATCH 2/5] plotting behaviour adapted for kern and mrd --- GPy/kern/kern.py | 5 +++-- GPy/models/mrd.py | 37 ++++++++++++++++++++----------------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/GPy/kern/kern.py b/GPy/kern/kern.py index c9582ac8..28e33b4b 100644 --- a/GPy/kern/kern.py +++ b/GPy/kern/kern.py @@ -46,10 +46,11 @@ class kern(parameterised): parameterised.__init__(self) - def plot_ARD(self, ax=None): + def plot_ARD(self, fignum=None, ax=None): """If an ARD kernel is present, it bar-plots the ARD parameters""" if ax is None: - ax = pb.gca() + fig = pb.figure(fignum) + ax = fig.add_subplot(111) for p in self.parts: if hasattr(p, 'ARD') and p.ARD: ax.set_title('ARD parameters, %s kernel' % p.name) diff --git a/GPy/models/mrd.py b/GPy/models/mrd.py index 6f2a5f6a..5165f5f8 100644 --- a/GPy/models/mrd.py +++ b/GPy/models/mrd.py @@ -256,17 +256,20 @@ class MRD(model): self.Z = Z return Z - def _handle_plotting(self, fig_num, axes, plotf): - if axes is None: - fig = pylab.figure(num=fig_num, figsize=(4 * len(self.bgplvms), 3)) + def _handle_plotting(self, fignum, ax, plotf): + if ax is None: + fig = pylab.figure(num=fignum) + ax = fig.add_subplot(111) + if ax is None: + fig = pylab.figure(num=fignum, figsize=(4 * len(self.bgplvms), 3)) for i, g in enumerate(self.bgplvms): - if axes is None: + if ax is None: ax = fig.add_subplot(1, len(self.bgplvms), i + 1) else: - ax = axes[i] + ax = ax[i] plotf(i, g, ax) pylab.draw() - if axes is None: + if ax is None: fig.tight_layout() return fig else: 
@@ -275,20 +278,20 @@ class MRD(model): def plot_X_1d(self): return self.gref.plot_X_1d() - def plot_X(self, fig_num="MRD Predictions", axes=None): - fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: ax.imshow(g.X)) + def plot_X(self, fignum="MRD Predictions", ax=None): + fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g.X)) return fig - def plot_predict(self, fig_num="MRD Predictions", axes=None, **kwargs): - fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: ax.imshow(g.predict(g.X)[0], **kwargs)) + def plot_predict(self, fignum="MRD Predictions", ax=None, **kwargs): + fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g.predict(g.X)[0], **kwargs)) return fig - def plot_scales(self, fig_num="MRD Scales", axes=None, *args, **kwargs): - fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: g.kern.plot_ARD(ax=ax, *args, **kwargs)) + def plot_scales(self, fignum="MRD Scales", ax=None, *args, **kwargs): + fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.kern.plot_ARD(ax=ax, *args, **kwargs)) return fig - def plot_latent(self, fig_num="MRD Latent Spaces", axes=None, *args, **kwargs): - fig = self._handle_plotting(fig_num, axes, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs)) + def plot_latent(self, fignum="MRD Latent Spaces", ax=None, *args, **kwargs): + fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs)) return fig def _debug_plot(self): @@ -296,11 +299,11 @@ class MRD(model): fig = pylab.figure("MRD DEBUG PLOT", figsize=(4 * len(self.bgplvms), 9)) fig.clf() axes = [fig.add_subplot(3, len(self.bgplvms), i + 1) for i in range(len(self.bgplvms))] - self.plot_X(axes=axes) + self.plot_X(ax=axes) axes = [fig.add_subplot(3, len(self.bgplvms), i + len(self.bgplvms) + 1) for i in range(len(self.bgplvms))] - self.plot_latent(axes=axes) + self.plot_latent(ax=axes) axes = [fig.add_subplot(3, len(self.bgplvms), i + 2 * len(self.bgplvms) + 1) for i in 
range(len(self.bgplvms))] - self.plot_scales(axes=axes) + self.plot_scales(ax=axes) pylab.draw() fig.tight_layout() From 2b0858b697d62b74a57518c91f5ad9e63540b028 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Tue, 4 Jun 2013 18:25:28 +0100 Subject: [PATCH 3/5] plotting behaviour adapted for BGPLVM --- GPy/models/Bayesian_GPLVM.py | 9 ++++++--- GPy/models/mrd.py | 7 +++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/GPy/models/Bayesian_GPLVM.py b/GPy/models/Bayesian_GPLVM.py index 0b0797a5..b7f7b42b 100644 --- a/GPy/models/Bayesian_GPLVM.py +++ b/GPy/models/Bayesian_GPLVM.py @@ -218,7 +218,7 @@ class Bayesian_GPLVM(sparse_GP, GPLVM): return means, covars - def plot_X_1d(self, ax=None, fignum=None, colors=None): + def plot_X_1d(self, fignum=None, ax=None, colors=None): """ Plot latent space X in 1D: @@ -230,7 +230,8 @@ class Bayesian_GPLVM(sparse_GP, GPLVM): colors of different latent space dimensions Q """ import pylab - fig = pylab.figure(num=fignum, figsize=(8, min(12, (2 * self.X.shape[1])))) + if ax is None: + fig = pylab.figure(num=fignum, figsize=(8, min(12, (2 * self.X.shape[1])))) if colors is None: colors = pylab.gca()._get_lines.color_cycle pylab.clf() @@ -241,8 +242,10 @@ class Bayesian_GPLVM(sparse_GP, GPLVM): for i in range(self.X.shape[1]): if ax is None: ax = fig.add_subplot(self.X.shape[1], 1, i + 1) - else: + elif isinstance(ax, (tuple, list)): ax = ax[i] + else: + raise ValueError("Need one ax per latent dimension Q") ax.plot(self.X, c='k', alpha=.3) plots.extend(ax.plot(x, self.X.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i))) ax.fill_between(x, diff --git a/GPy/models/mrd.py b/GPy/models/mrd.py index 5165f5f8..eab5131f 100644 --- a/GPy/models/mrd.py +++ b/GPy/models/mrd.py @@ -257,16 +257,15 @@ class MRD(model): return Z def _handle_plotting(self, fignum, ax, plotf): - if ax is None: - fig = pylab.figure(num=fignum) - ax = fig.add_subplot(111) if ax is None: fig = pylab.figure(num=fignum, figsize=(4 * 
len(self.bgplvms), 3)) for i, g in enumerate(self.bgplvms): if ax is None: ax = fig.add_subplot(1, len(self.bgplvms), i + 1) - else: + elif isinstance(ax, (tuple, list)): ax = ax[i] + else: + raise ValueError("Need one ax per latent dimension Q") plotf(i, g, ax) pylab.draw() if ax is None: From a5e1985d9d852dd9cbcb447a3afdc6fe84271701 Mon Sep 17 00:00:00 2001 From: James Hensman Date: Tue, 4 Jun 2013 18:25:46 +0100 Subject: [PATCH 4/5] plotting in the model behaves better --- GPy/core/gp_base.py | 51 +++++++++++++++++++++++++------------------ GPy/core/sparse_GP.py | 20 ++++++++++------- 2 files changed, 42 insertions(+), 29 deletions(-) diff --git a/GPy/core/gp_base.py b/GPy/core/gp_base.py index 5d6e6f4c..26bed691 100644 --- a/GPy/core/gp_base.py +++ b/GPy/core/gp_base.py @@ -33,7 +33,7 @@ class GPBase(model.model): # All leaf nodes should call self._set_params(self._get_params()) at # the end - def plot_f(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, full_cov=False): + def plot_f(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, full_cov=False, fignum=None, ax=None): """ Plot the GP's view of the world, where the data is normalized and the likelihood is Gaussian. 
@@ -57,46 +57,55 @@ class GPBase(model.model): if which_data == 'all': which_data = slice(None) + if ax is None: + fig = pb.figure(num=fignum) + ax = fig.add_subplot(111) + if self.X.shape[1] == 1: Xnew, xmin, xmax = x_frame1D(self.X, plot_limits=plot_limits) if samples == 0: m, v = self._raw_predict(Xnew, which_parts=which_parts) - gpplot(Xnew, m, m - 2 * np.sqrt(v), m + 2 * np.sqrt(v)) - pb.plot(self.X[which_data], self.likelihood.Y[which_data], 'kx', mew=1.5) + gpplot(Xnew, m, m - 2 * np.sqrt(v), m + 2 * np.sqrt(v), axes=ax) + ax.plot(self.X[which_data], self.likelihood.Y[which_data], 'kx', mew=1.5) else: m, v = self._raw_predict(Xnew, which_parts=which_parts, full_cov=True) Ysim = np.random.multivariate_normal(m.flatten(), v, samples) - gpplot(Xnew, m, m - 2 * np.sqrt(np.diag(v)[:, None]), m + 2 * np.sqrt(np.diag(v))[:, None]) + gpplot(Xnew, m, m - 2 * np.sqrt(np.diag(v)[:, None]), m + 2 * np.sqrt(np.diag(v))[:, None,], axes=ax) for i in range(samples): - pb.plot(Xnew, Ysim[i, :], Tango.colorsHex['darkBlue'], linewidth=0.25) - pb.plot(self.X[which_data], self.likelihood.Y[which_data], 'kx', mew=1.5) - pb.xlim(xmin, xmax) + ax.plot(Xnew, Ysim[i, :], Tango.colorsHex['darkBlue'], linewidth=0.25) + ax.plot(self.X[which_data], self.likelihood.Y[which_data], 'kx', mew=1.5) + ax.set_xlim(xmin, xmax) ymin, ymax = min(np.append(self.likelihood.Y, m - 2 * np.sqrt(np.diag(v)[:, None]))), max(np.append(self.likelihood.Y, m + 2 * np.sqrt(np.diag(v)[:, None]))) ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin) - pb.ylim(ymin, ymax) + ax.set_ylim(ymin, ymax) elif self.X.shape[1] == 2: resolution = resolution or 50 Xnew, xmin, xmax, xx, yy = x_frame2D(self.X, plot_limits, resolution) m, v = self._raw_predict(Xnew, which_parts=which_parts) m = m.reshape(resolution, resolution).T - pb.contour(xx, yy, m, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet) - pb.scatter(self.X[:, 0], self.X[:, 1], 40, self.likelihood.Y, linewidth=0, cmap=pb.cm.jet, vmin=m.min(), 
vmax=m.max()) - pb.xlim(xmin[0], xmax[0]) - pb.ylim(xmin[1], xmax[1]) + ax.contour(xx, yy, m, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet) + ax.scatter(self.X[:, 0], self.X[:, 1], 40, self.likelihood.Y, linewidth=0, cmap=pb.cm.jet, vmin=m.min(), vmax=m.max()) + ax.set_xlim(xmin[0], xmax[0]) + ax.set_ylim(xmin[1], xmax[1]) else: raise NotImplementedError, "Cannot define a frame with more than two input dimensions" - def plot(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, levels=20): + def plot(self, plot_limits=None, which_data='all', which_parts='all', resolution=None, levels=20, samples=0, fignum=None, ax=None): """ TODO: Docstrings! :param levels: for 2D plotting, the number of contour levels to use + if ax is None, create a new figure """ # TODO include samples if which_data == 'all': which_data = slice(None) + if ax is None: + fig = pb.figure(num=fignum) + ax = fig.add_subplot(111) + if self.X.shape[1] == 1: Xu = self.X * self._Xstd + self._Xmean # NOTE self.X are the normalized values now @@ -104,12 +113,12 @@ Xnew, xmin, xmax = x_frame1D(Xu, plot_limits=plot_limits) m, var, lower, upper = self.predict(Xnew, which_parts=which_parts) for d in range(m.shape[1]): - gpplot(Xnew, m[:,d], lower[:,d], upper[:,d]) - pb.plot(Xu[which_data], self.likelihood.data[which_data,d], 'kx', mew=1.5) + gpplot(Xnew, m[:,d], lower[:,d], upper[:,d],axes=ax) + ax.plot(Xu[which_data], self.likelihood.data[which_data,d], 'kx', mew=1.5) ymin, ymax = min(np.append(self.likelihood.data, lower)), max(np.append(self.likelihood.data, upper)) ymin, ymax = ymin - 0.1 * (ymax - ymin), ymax + 0.1 * (ymax - ymin) - pb.xlim(xmin, xmax) - pb.ylim(ymin, ymax) + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) elif self.X.shape[1] == 2: # FIXME resolution = resolution or 50 @@ -117,11 +126,11 @@ x, y = np.linspace(xmin[0], xmax[0], resolution), np.linspace(xmin[1], xmax[1], resolution) m, var, lower, 
upper = self.predict(Xnew, which_parts=which_parts) m = m.reshape(resolution, resolution).T - pb.contour(x, y, m, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet) + ax.contour(x, y, m, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet) Yf = self.likelihood.Y.flatten() - pb.scatter(self.X[:, 0], self.X[:, 1], 40, Yf, cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.) - pb.xlim(xmin[0], xmax[0]) - pb.ylim(xmin[1], xmax[1]) + ax.scatter(self.X[:, 0], self.X[:, 1], 40, Yf, cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.) + ax.set_xlim(xmin[0], xmax[0]) + ax.set_ylim(xmin[1], xmax[1]) else: raise NotImplementedError, "Cannot define a frame with more than two input dimensions" diff --git a/GPy/core/sparse_GP.py b/GPy/core/sparse_GP.py index d913d31d..55cff38f 100644 --- a/GPy/core/sparse_GP.py +++ b/GPy/core/sparse_GP.py @@ -281,17 +281,21 @@ class sparse_GP(GPBase): return mean, var, _025pm, _975pm - def plot(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, levels=20): - GPBase.plot(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, levels=20) + def plot(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, levels=20, fignum=None, ax=None): + if ax is None: + fig = pb.figure(num=fignum) + ax = fig.add_subplot(111) + + GPBase.plot(self, samples=0, plot_limits=None, which_data='all', which_parts='all', resolution=None, levels=20, ax=ax) if self.X.shape[1] == 1: - Xu = self.X * self._Xstd + self._Xmean # NOTE self.X are the normalized values now if self.has_uncertain_inputs: - pb.errorbar(Xu[which_data, 0], self.likelihood.data[which_data, 0], + Xu = self.X * self._Xstd + self._Xmean # NOTE self.X are the normalized values now + ax.errorbar(Xu[which_data, 0], self.likelihood.data[which_data, 0], xerr=2 * np.sqrt(self.X_variance[which_data, 0]), ecolor='k', fmt=None, elinewidth=.5, alpha=.5) Zu = self.Z * self._Xstd + self._Xmean - pb.plot(Zu, 
Zu * 0 + pb.ylim()[0], 'r|', mew=1.5, markersize=12) - # pb.errorbar(self.X[:,0], pb.ylim()[0]+np.zeros(self.N), xerr=2*np.sqrt(self.X_variance.flatten())) + ax.plot(Zu, np.zeros_like(Zu) + ax.get_ylim()[0], 'r|', mew=1.5, markersize=12) - elif self.X.shape[1] == 2: # FIXME - pb.plot(self.Z[:, 0], self.Z[:, 1], 'wo') + elif self.X.shape[1] == 2: + Zu = self.Z * self._Xstd + self._Xmean + ax.plot(Zu[:, 0], Zu[:, 1], 'wo') From 8bdb14b0f938d951ded612e172b8c491ac71ef96 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Tue, 4 Jun 2013 18:26:16 +0100 Subject: [PATCH 5/5] plotting behaviour adapted for BGPLVM --- GPy/models/mrd.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/GPy/models/mrd.py b/GPy/models/mrd.py index eab5131f..e1bfa947 100644 --- a/GPy/models/mrd.py +++ b/GPy/models/mrd.py @@ -256,19 +256,19 @@ class MRD(model): self.Z = Z return Z - def _handle_plotting(self, fignum, ax, plotf): - if ax is None: + def _handle_plotting(self, fignum, axes, plotf): + if axes is None: fig = pylab.figure(num=fignum, figsize=(4 * len(self.bgplvms), 3)) for i, g in enumerate(self.bgplvms): - if ax is None: - ax = fig.add_subplot(1, len(self.bgplvms), i + 1) - elif isinstance(ax, (tuple, list)): - ax = ax[i] + if axes is None: + axes = fig.add_subplot(1, len(self.bgplvms), i + 1) + elif isinstance(axes, (tuple, list)): + axes = axes[i] else: - raise ValueError("Need one ax per latent dimension Q") - plotf(i, g, ax) + raise ValueError("Need one axes per latent dimension Q") + plotf(i, g, axes) pylab.draw() - if ax is None: + if axes is None: fig.tight_layout() return fig else: @@ -286,11 +286,11 @@ class MRD(model): return fig def plot_scales(self, fignum="MRD Scales", ax=None, *args, **kwargs): - fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.kern.plot_ARD(ax=ax, *args, **kwargs)) + fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.kern.plot_ARD(axes=ax, *args, **kwargs)) return fig def plot_latent(self, 
fignum="MRD Latent Spaces", ax=None, *args, **kwargs): - fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs)) + fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(axes=ax, *args, **kwargs)) return fig def _debug_plot(self):