From 2b0858b697d62b74a57518c91f5ad9e63540b028 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Tue, 4 Jun 2013 18:25:28 +0100 Subject: [PATCH] plotting behaviour adapted for BGPLVM --- GPy/models/Bayesian_GPLVM.py | 9 ++++++--- GPy/models/mrd.py | 7 +++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/GPy/models/Bayesian_GPLVM.py b/GPy/models/Bayesian_GPLVM.py index 0b0797a5..b7f7b42b 100644 --- a/GPy/models/Bayesian_GPLVM.py +++ b/GPy/models/Bayesian_GPLVM.py @@ -218,7 +218,7 @@ class Bayesian_GPLVM(sparse_GP, GPLVM): return means, covars - def plot_X_1d(self, ax=None, fignum=None, colors=None): + def plot_X_1d(self, fignum=None, ax=None, colors=None): """ Plot latent space X in 1D: @@ -230,7 +230,8 @@ class Bayesian_GPLVM(sparse_GP, GPLVM): colors of different latent space dimensions Q """ import pylab - fig = pylab.figure(num=fignum, figsize=(8, min(12, (2 * self.X.shape[1])))) + if ax is None: + fig = pylab.figure(num=fignum, figsize=(8, min(12, (2 * self.X.shape[1])))) if colors is None: colors = pylab.gca()._get_lines.color_cycle pylab.clf() @@ -241,8 +242,10 @@ class Bayesian_GPLVM(sparse_GP, GPLVM): for i in range(self.X.shape[1]): if ax is None: ax = fig.add_subplot(self.X.shape[1], 1, i + 1) - else: + elif isinstance(ax, (tuple, list)): ax = ax[i] + else: + raise ValueError("Need one ax per latent dimnesion Q") ax.plot(self.X, c='k', alpha=.3) plots.extend(ax.plot(x, self.X.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i))) ax.fill_between(x, diff --git a/GPy/models/mrd.py b/GPy/models/mrd.py index 5165f5f8..eab5131f 100644 --- a/GPy/models/mrd.py +++ b/GPy/models/mrd.py @@ -257,16 +257,15 @@ class MRD(model): return Z def _handle_plotting(self, fignum, ax, plotf): - if ax is None: - fig = pylab.figure(num=fignum) - ax = fig.add_subplot(111) if ax is None: fig = pylab.figure(num=fignum, figsize=(4 * len(self.bgplvms), 3)) for i, g in enumerate(self.bgplvms): if ax is None: ax = fig.add_subplot(1, len(self.bgplvms), i + 1) - else: + elif isinstance(ax, (tuple, list)): ax = ax[i] + else: + raise ValueError("Need one ax per latent dimension Q") plotf(i, g, ax) pylab.draw() if ax is None: