From e35999b24ba3e070dc245f775c82d5adddebc116 Mon Sep 17 00:00:00 2001 From: mzwiessele Date: Wed, 8 Apr 2015 08:24:55 +0200 Subject: [PATCH] [var plots] wrong return values --- GPy/plotting/matplot_dep/variational_plots.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/GPy/plotting/matplot_dep/variational_plots.py b/GPy/plotting/matplot_dep/variational_plots.py index 5cced10d..55128ec7 100644 --- a/GPy/plotting/matplot_dep/variational_plots.py +++ b/GPy/plotting/matplot_dep/variational_plots.py @@ -1,6 +1,6 @@ import pylab as pb, numpy as np -def plot(parameterized, fignum=None, ax=None, colors=None): +def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)): """ Plot latent space X in 1D: @@ -13,13 +13,15 @@ def plot(parameterized, fignum=None, ax=None, colors=None): """ if ax is None: - fig = pb.figure(num=fignum, figsize=(8, min(12, (2 * parameterized.mean.shape[1])))) + fig = pb.figure(num=fignum, figsize=figsize) if colors is None: colors = pb.gca()._get_lines.color_cycle pb.clf() else: colors = iter(colors) - plots = [] + lines = [] + fills = [] + bg_lines = [] means, variances = parameterized.mean, parameterized.variance x = np.arange(means.shape[0]) for i in range(means.shape[1]): @@ -29,20 +31,20 @@ def plot(parameterized, fignum=None, ax=None, colors=None): a = ax[i] else: raise ValueError("Need one ax per latent dimension input_dim") - a.plot(means, c='k', alpha=.3) - plots.extend(a.plot(x, means.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i))) - a.fill_between(x, + bg_lines.append(a.plot(means, c='k', alpha=.3)) + lines.extend(a.plot(x, means.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i))) + fills.append(a.fill_between(x, means.T[i] - 2 * np.sqrt(variances.T[i]), means.T[i] + 2 * np.sqrt(variances.T[i]), - facecolor=plots[-1].get_color(), - alpha=.3) + facecolor=lines[-1].get_color(), + alpha=.3)) a.legend(borderaxespad=0.) a.set_xlim(x.min(), x.max()) if i < means.shape[1] - 1: a.set_xticklabels('') pb.draw() fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95)) - return fig + return dict(lines=lines, fills=fills, bg_lines=bg_lines) def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_side=True): """