[var plots] wrong return values

This commit is contained in:
mzwiessele 2015-04-08 08:24:55 +02:00
parent 1efa842130
commit e35999b24b

View file

@ -1,6 +1,6 @@
import pylab as pb, numpy as np
def plot(parameterized, fignum=None, ax=None, colors=None):
def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)):
"""
Plot latent space X in 1D:
@ -13,13 +13,15 @@ def plot(parameterized, fignum=None, ax=None, colors=None):
"""
if ax is None:
fig = pb.figure(num=fignum, figsize=(8, min(12, (2 * parameterized.mean.shape[1]))))
fig = pb.figure(num=fignum, figsize=figsize)
if colors is None:
colors = pb.gca()._get_lines.color_cycle
pb.clf()
else:
colors = iter(colors)
plots = []
lines = []
fills = []
bg_lines = []
means, variances = parameterized.mean, parameterized.variance
x = np.arange(means.shape[0])
for i in range(means.shape[1]):
@ -29,20 +31,20 @@ def plot(parameterized, fignum=None, ax=None, colors=None):
a = ax[i]
else:
raise ValueError("Need one ax per latent dimension input_dim")
a.plot(means, c='k', alpha=.3)
plots.extend(a.plot(x, means.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i)))
a.fill_between(x,
bg_lines.append(a.plot(means, c='k', alpha=.3))
lines.extend(a.plot(x, means.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i)))
fills.append(a.fill_between(x,
means.T[i] - 2 * np.sqrt(variances.T[i]),
means.T[i] + 2 * np.sqrt(variances.T[i]),
facecolor=plots[-1].get_color(),
alpha=.3)
facecolor=lines[-1].get_color(),
alpha=.3))
a.legend(borderaxespad=0.)
a.set_xlim(x.min(), x.max())
if i < means.shape[1] - 1:
a.set_xticklabels('')
pb.draw()
fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
return fig
return dict(lines=lines, fills=fills, bg_lines=bg_lines)
def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_side=True):
"""