fix deprecated import style

This commit is contained in:
Martin Bubel 2024-05-19 20:21:53 +02:00
parent f980f4a5b3
commit 741bb0b0da

View file

@ -1,4 +1,6 @@
from matplotlib import pyplot as pb, numpy as np
from matplotlib import pyplot as pb
import numpy as np
def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)):
"""
@ -17,6 +19,7 @@ def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)):
if colors is None:
from ..Tango import mediumList
from itertools import cycle
colors = cycle(mediumList)
pb.clf()
else:
@ -33,21 +36,30 @@ def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)):
a = ax[i]
else:
raise ValueError("Need one ax per latent dimension input_dim")
bg_lines.append(a.plot(means, c='k', alpha=.3))
lines.extend(a.plot(x, means.T[i], c=next(colors), label=r"$\mathbf{{X_{{{}}}}}$".format(i)))
fills.append(a.fill_between(x,
means.T[i] - 2 * np.sqrt(variances.T[i]),
means.T[i] + 2 * np.sqrt(variances.T[i]),
facecolor=lines[-1].get_color(),
alpha=.3))
a.legend(borderaxespad=0.)
bg_lines.append(a.plot(means, c="k", alpha=0.3))
lines.extend(
a.plot(
x, means.T[i], c=next(colors), label=r"$\mathbf{{X_{{{}}}}}$".format(i)
)
)
fills.append(
a.fill_between(
x,
means.T[i] - 2 * np.sqrt(variances.T[i]),
means.T[i] + 2 * np.sqrt(variances.T[i]),
facecolor=lines[-1].get_color(),
alpha=0.3,
)
)
a.legend(borderaxespad=0.0)
a.set_xlim(x.min(), x.max())
if i < means.shape[1] - 1:
a.set_xticklabels('')
a.set_xticklabels("")
pb.draw()
a.figure.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
a.figure.tight_layout(h_pad=0.01) # , rect=(0, 0, 1, .95))
return dict(lines=lines, fills=fills, bg_lines=bg_lines)
def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_side=True):
"""
Plot latent space X in 1D:
@ -62,45 +74,60 @@ def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_sid
"""
if ax is None:
if side_by_side:
fig = pb.figure(num=fignum, figsize=(16, min(12, (2 * parameterized.mean.shape[1]))))
fig = pb.figure(
num=fignum, figsize=(16, min(12, (2 * parameterized.mean.shape[1])))
)
else:
fig = pb.figure(num=fignum, figsize=(8, min(12, (2 * parameterized.mean.shape[1]))))
fig = pb.figure(
num=fignum, figsize=(8, min(12, (2 * parameterized.mean.shape[1])))
)
if colors is None:
from ..Tango import mediumList
from itertools import cycle
colors = cycle(mediumList)
pb.clf()
else:
colors = iter(colors)
plots = []
means, variances, gamma = parameterized.mean, parameterized.variance, parameterized.binary_prob
means, variances, gamma = (
parameterized.mean,
parameterized.variance,
parameterized.binary_prob,
)
x = np.arange(means.shape[0])
for i in range(means.shape[1]):
if side_by_side:
sub1 = (means.shape[1],2,2*i+1)
sub2 = (means.shape[1],2,2*i+2)
sub1 = (means.shape[1], 2, 2 * i + 1)
sub2 = (means.shape[1], 2, 2 * i + 2)
else:
sub1 = (means.shape[1]*2,1,2*i+1)
sub2 = (means.shape[1]*2,1,2*i+2)
sub1 = (means.shape[1] * 2, 1, 2 * i + 1)
sub2 = (means.shape[1] * 2, 1, 2 * i + 2)
# mean and variance plot
a = fig.add_subplot(*sub1)
a.plot(means, c='k', alpha=.3)
plots.extend(a.plot(x, means.T[i], c=next(colors), label=r"$\mathbf{{X_{{{}}}}}$".format(i)))
a.fill_between(x,
means.T[i] - 2 * np.sqrt(variances.T[i]),
means.T[i] + 2 * np.sqrt(variances.T[i]),
facecolor=plots[-1].get_color(),
alpha=.3)
a.legend(borderaxespad=0.)
a.plot(means, c="k", alpha=0.3)
plots.extend(
a.plot(
x, means.T[i], c=next(colors), label=r"$\mathbf{{X_{{{}}}}}$".format(i)
)
)
a.fill_between(
x,
means.T[i] - 2 * np.sqrt(variances.T[i]),
means.T[i] + 2 * np.sqrt(variances.T[i]),
facecolor=plots[-1].get_color(),
alpha=0.3,
)
a.legend(borderaxespad=0.0)
a.set_xlim(x.min(), x.max())
if i < means.shape[1] - 1:
a.set_xticklabels('')
a.set_xticklabels("")
# binary prob plot
a = fig.add_subplot(*sub2)
a.bar(x,gamma[:,i],bottom=0.,linewidth=1.,width=1.0,align='center')
a.bar(x, gamma[:, i], bottom=0.0, linewidth=1.0, width=1.0, align="center")
a.set_xlim(x.min(), x.max())
a.set_ylim([0.,1.])
a.set_ylim([0.0, 1.0])
pb.draw()
fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
fig.tight_layout(h_pad=0.01) # , rect=(0, 0, 1, .95))
return fig