diff --git a/GPy/plotting/matplot_dep/variational_plots.py b/GPy/plotting/matplot_dep/variational_plots.py index 3f20efeb..17481a80 100644 --- a/GPy/plotting/matplot_dep/variational_plots.py +++ b/GPy/plotting/matplot_dep/variational_plots.py @@ -1,4 +1,6 @@ -from matplotlib import pyplot as pb, numpy as np +from matplotlib import pyplot as pb +import numpy as np + def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)): """ @@ -17,6 +19,7 @@ def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)): if colors is None: from ..Tango import mediumList from itertools import cycle + colors = cycle(mediumList) pb.clf() else: @@ -33,21 +36,30 @@ def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)): a = ax[i] else: raise ValueError("Need one ax per latent dimension input_dim") - bg_lines.append(a.plot(means, c='k', alpha=.3)) - lines.extend(a.plot(x, means.T[i], c=next(colors), label=r"$\mathbf{{X_{{{}}}}}$".format(i))) - fills.append(a.fill_between(x, - means.T[i] - 2 * np.sqrt(variances.T[i]), - means.T[i] + 2 * np.sqrt(variances.T[i]), - facecolor=lines[-1].get_color(), - alpha=.3)) - a.legend(borderaxespad=0.) + bg_lines.append(a.plot(means, c="k", alpha=0.3)) + lines.extend( + a.plot( + x, means.T[i], c=next(colors), label=r"$\mathbf{{X_{{{}}}}}$".format(i) + ) + ) + fills.append( + a.fill_between( + x, + means.T[i] - 2 * np.sqrt(variances.T[i]), + means.T[i] + 2 * np.sqrt(variances.T[i]), + facecolor=lines[-1].get_color(), + alpha=0.3, + ) + ) + a.legend(borderaxespad=0.0) a.set_xlim(x.min(), x.max()) if i < means.shape[1] - 1: - a.set_xticklabels('') + a.set_xticklabels("") pb.draw() - a.figure.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95)) + a.figure.tight_layout(h_pad=0.01) # , rect=(0, 0, 1, .95)) return dict(lines=lines, fills=fills, bg_lines=bg_lines) + def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_side=True): """ Plot latent space X in 1D: @@ -62,45 +74,60 @@ def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_sid """ if ax is None: if side_by_side: - fig = pb.figure(num=fignum, figsize=(16, min(12, (2 * parameterized.mean.shape[1])))) + fig = pb.figure( + num=fignum, figsize=(16, min(12, (2 * parameterized.mean.shape[1]))) + ) else: - fig = pb.figure(num=fignum, figsize=(8, min(12, (2 * parameterized.mean.shape[1])))) + fig = pb.figure( + num=fignum, figsize=(8, min(12, (2 * parameterized.mean.shape[1]))) + ) if colors is None: from ..Tango import mediumList from itertools import cycle + colors = cycle(mediumList) pb.clf() else: colors = iter(colors) plots = [] - means, variances, gamma = parameterized.mean, parameterized.variance, parameterized.binary_prob + means, variances, gamma = ( + parameterized.mean, + parameterized.variance, + parameterized.binary_prob, + ) x = np.arange(means.shape[0]) for i in range(means.shape[1]): if side_by_side: - sub1 = (means.shape[1],2,2*i+1) - sub2 = (means.shape[1],2,2*i+2) + sub1 = (means.shape[1], 2, 2 * i + 1) + sub2 = (means.shape[1], 2, 2 * i + 2) else: - sub1 = (means.shape[1]*2,1,2*i+1) - sub2 = (means.shape[1]*2,1,2*i+2) + sub1 = (means.shape[1] * 2, 1, 2 * i + 1) + sub2 = (means.shape[1] * 2, 1, 2 * i + 2) # mean and variance plot a = fig.add_subplot(*sub1) - a.plot(means, c='k', alpha=.3) - plots.extend(a.plot(x, means.T[i], c=next(colors), label=r"$\mathbf{{X_{{{}}}}}$".format(i))) - a.fill_between(x, - means.T[i] - 2 * np.sqrt(variances.T[i]), - means.T[i] + 2 * np.sqrt(variances.T[i]), - facecolor=plots[-1].get_color(), - alpha=.3) - a.legend(borderaxespad=0.) + a.plot(means, c="k", alpha=0.3) + plots.extend( + a.plot( + x, means.T[i], c=next(colors), label=r"$\mathbf{{X_{{{}}}}}$".format(i) + ) + ) + a.fill_between( + x, + means.T[i] - 2 * np.sqrt(variances.T[i]), + means.T[i] + 2 * np.sqrt(variances.T[i]), + facecolor=plots[-1].get_color(), + alpha=0.3, + ) + a.legend(borderaxespad=0.0) a.set_xlim(x.min(), x.max()) if i < means.shape[1] - 1: - a.set_xticklabels('') + a.set_xticklabels("") # binary prob plot a = fig.add_subplot(*sub2) - a.bar(x,gamma[:,i],bottom=0.,linewidth=1.,width=1.0,align='center') + a.bar(x, gamma[:, i], bottom=0.0, linewidth=1.0, width=1.0, align="center") a.set_xlim(x.min(), x.max()) - a.set_ylim([0.,1.]) + a.set_ylim([0.0, 1.0]) pb.draw() - fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95)) + fig.tight_layout(h_pad=0.01) # , rect=(0, 0, 1, .95)) return fig