Plotting functions modified

This commit is contained in:
Ricardo 2014-01-28 13:40:24 +00:00
parent 58501dc301
commit 11f784d472
2 changed files with 13 additions and 58 deletions

View file

@ -3,7 +3,6 @@
import numpy as np
import pylab as pb
from scipy.special import gammaln, digamma
from ...util.linalg import pdinv
from domains import _REAL, _POSITIVE
@ -12,16 +11,14 @@ import weakref
class Prior:
domain = None
def pdf(self, x):
return np.exp(self.lnpdf(x))
def plot(self):
rvs = self.rvs(1000)
pb.hist(rvs, 100, normed=True)
xmin, xmax = pb.xlim()
xx = np.linspace(xmin, xmax, 1000)
pb.plot(xx, self.pdf(xx), 'r', linewidth=2)
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import priors_plots
priors_plots.univariate_plot(self)
class Gaussian(Prior):
@ -153,16 +150,9 @@ class MultivariateGaussian:
return np.random.multivariate_normal(self.mu, self.var, n)
def plot(self):
if self.input_dim == 2:
rvs = self.rvs(200)
pb.plot(rvs[:, 0], rvs[:, 1], 'kx', mew=1.5)
xmin, xmax = pb.xlim()
ymin, ymax = pb.ylim()
xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
xflat = np.vstack((xx.flatten(), yy.flatten())).T
zz = self.pdf(xflat).reshape(100, 100)
pb.contour(xx, yy, zz, linewidths=2)
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import priors_plots
priors_plots.multivariate_plot(self)
def gamma_from_EV(E, V):
warnings.warn("use Gamma.from_EV to create Gamma Prior", FutureWarning)

View file

@ -11,7 +11,7 @@ from ...util.misc import param_to_array
class Normal(Parameterized):
'''
Normal distribution for variational approximations.
holds the means and variances for a factorizing multivariate normal distribution
'''
def __init__(self, means, variances, name='latent space'):
@ -20,47 +20,12 @@ class Normal(Parameterized):
self.variances = Param('variance', variances)
self.add_parameters(self.means, self.variances)
def plot(self, fignum=None, ax=None, colors=None):
def plot(self, *args):
"""
Plot latent space X in 1D:
- if fig is given, create input_dim subplots in fig and plot in these
- if ax is given plot input_dim 1D latent space plots of X into each `axis`
- if neither fig nor ax is given create a figure with fignum and plot in there
colors:
colors of different latent space dimensions input_dim
See GPy.plotting.matplot_dep.variational_plots
"""
import pylab
if ax is None:
fig = pylab.figure(num=fignum, figsize=(8, min(12, (2 * self.means.shape[1]))))
if colors is None:
colors = pylab.gca()._get_lines.color_cycle
pylab.clf()
else:
colors = iter(colors)
plots = []
means, variances = param_to_array(self.means, self.variances)
x = np.arange(means.shape[0])
for i in range(means.shape[1]):
if ax is None:
a = fig.add_subplot(means.shape[1], 1, i + 1)
elif isinstance(ax, (tuple, list)):
a = ax[i]
else:
raise ValueError("Need one ax per latent dimension input_dim")
a.plot(means, c='k', alpha=.3)
plots.extend(a.plot(x, means.T[i], c=colors.next(), label=r"$\mathbf{{X_{{{}}}}}$".format(i)))
a.fill_between(x,
means.T[i] - 2 * np.sqrt(variances.T[i]),
means.T[i] + 2 * np.sqrt(variances.T[i]),
facecolor=plots[-1].get_color(),
alpha=.3)
a.legend(borderaxespad=0.)
a.set_xlim(x.min(), x.max())
if i < means.shape[1] - 1:
a.set_xticklabels('')
pylab.draw()
fig.tight_layout(h_pad=.01) # , rect=(0, 0, 1, .95))
return fig
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import variational_plots
return variational_plots.plot(self,*args)