From 11f784d47275733debaa0e2cc5c6fb869dfa972f Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 28 Jan 2014 13:40:24 +0000 Subject: [PATCH] Plotting functions modified --- GPy/core/parameterization/priors.py | 24 ++++-------- GPy/core/parameterization/variational.py | 47 +++--------------------- 2 files changed, 13 insertions(+), 58 deletions(-) diff --git a/GPy/core/parameterization/priors.py b/GPy/core/parameterization/priors.py index 9614ca53..f1208f18 100644 --- a/GPy/core/parameterization/priors.py +++ b/GPy/core/parameterization/priors.py @@ -3,7 +3,6 @@ import numpy as np -import pylab as pb from scipy.special import gammaln, digamma from ...util.linalg import pdinv from domains import _REAL, _POSITIVE @@ -12,16 +11,14 @@ import weakref class Prior: domain = None - + def pdf(self, x): return np.exp(self.lnpdf(x)) def plot(self): - rvs = self.rvs(1000) - pb.hist(rvs, 100, normed=True) - xmin, xmax = pb.xlim() - xx = np.linspace(xmin, xmax, 1000) - pb.plot(xx, self.pdf(xx), 'r', linewidth=2) + assert "matplotlib" in sys.modules, "matplotlib package has not been imported." + from ..plotting.matplot_dep import priors_plots + priors_plots.univariate_plot(self) class Gaussian(Prior): @@ -153,16 +150,9 @@ class MultivariateGaussian: return np.random.multivariate_normal(self.mu, self.var, n) def plot(self): - if self.input_dim == 2: - rvs = self.rvs(200) - pb.plot(rvs[:, 0], rvs[:, 1], 'kx', mew=1.5) - xmin, xmax = pb.xlim() - ymin, ymax = pb.ylim() - xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j] - xflat = np.vstack((xx.flatten(), yy.flatten())).T - zz = self.pdf(xflat).reshape(100, 100) - pb.contour(xx, yy, zz, linewidths=2) - + assert "matplotlib" in sys.modules, "matplotlib package has not been imported." + from ..plotting.matplot_dep import priors_plots + priors_plots.multivariate_plot(self) def gamma_from_EV(E, V): warnings.warn("use Gamma.from_EV to create Gamma Prior", FutureWarning) diff --git a/GPy/core/parameterization/variational.py b/GPy/core/parameterization/variational.py index 25718fbf..e9868b82 100644 --- a/GPy/core/parameterization/variational.py +++ b/GPy/core/parameterization/variational.py @@ -11,7 +11,7 @@ from ...util.misc import param_to_array class Normal(Parameterized): ''' Normal distribution for variational approximations. - + holds the means and variances for a factorizing multivariate normal distribution ''' def __init__(self, means, variances, name='latent space'): @@ -20,47 +20,12 @@ class Normal(Parameterized): self.variances = Param('variance', variances) self.add_parameters(self.means, self.variances) - def plot(self, fignum=None, ax=None, colors=None): + def plot(self, *args): """ Plot latent space X in 1D: - - if fig is given, create input_dim subplots in fig and plot in these - - if ax is given plot input_dim 1D latent space plots of X into each `axis` - - if neither fig nor ax is given create a figure with fignum and plot in there - - colors: - colors of different latent space dimensions input_dim - + See GPy.plotting.matplot_dep.variational_plots """ - import pylab - if ax is None: - fig = pylab.figure(num=fignum, figsize=(8, min(12, (2 * self.means.shape[1])))) - if colors is None: - colors = pylab.gca()._get_lines.color_cycle - pylab.clf() - else: - colors = iter(colors) - plots = [] - means, variances = param_to_array(self.means, self.variances) - x = np.arange(means.shape[0]) - for i in range(means.shape[1]): - if ax is None: - a = fig.add_subplot(means.shape[1], 1, i + 1) - elif isinstance(ax, (tuple, list)): - a = ax[i] - else: - raise ValueError("Need one ax per latent dimension input_dim") - a.plot(means, c='k', alpha=.3) - plots.extend(a.plot(x, means.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i))) - a.fill_between(x, - means.T[i] - 2 * np.sqrt(variances.T[i]), - means.T[i] + 2 * np.sqrt(variances.T[i]), - facecolor=plots[-1].get_color(), - alpha=.3) - a.legend(borderaxespad=0.) - a.set_xlim(x.min(), x.max()) - if i < means.shape[1] - 1: - a.set_xticklabels('') - pylab.draw() - fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95)) - return fig + assert "matplotlib" in sys.modules, "matplotlib package has not been imported." + from ..plotting.matplot_dep import variational_plots + return variational_plots.plot(self,*args)