[tests working now?]

This commit is contained in:
mzwiessele 2015-10-07 00:52:47 +01:00
parent 5290e4bf0e
commit 7ebdc698f6
34 changed files with 42 additions and 33 deletions

View file

@ -43,20 +43,20 @@ if config.get('plotting', 'library') is not 'none':
from ..models import GPLVM, BayesianGPLVM, bayesian_gplvm_minibatch, SSGPLVM, SSMRD
GPLVM.plot_latent = gpy_plot.latent_plots.plot_latent
GPLVM.plot_latent_scatter = gpy_plot.latent_plots.plot_latent_scatter
GPLVM.plot_latent_inducing = gpy_plot.latent_plots.plot_latent_inducing
GPLVM.plot_scatter = gpy_plot.latent_plots.plot_latent_scatter
GPLVM.plot_inducing = gpy_plot.latent_plots.plot_latent_inducing
GPLVM.plot_steepest_gradient_map = gpy_plot.latent_plots.plot_steepest_gradient_map
BayesianGPLVM.plot_latent = gpy_plot.latent_plots.plot_latent
BayesianGPLVM.plot_latent_scatter = gpy_plot.latent_plots.plot_latent_scatter
BayesianGPLVM.plot_latent_inducing = gpy_plot.latent_plots.plot_latent_inducing
BayesianGPLVM.plot_scatter = gpy_plot.latent_plots.plot_latent_scatter
BayesianGPLVM.plot_inducing = gpy_plot.latent_plots.plot_latent_inducing
BayesianGPLVM.plot_steepest_gradient_map = gpy_plot.latent_plots.plot_steepest_gradient_map
bayesian_gplvm_minibatch.BayesianGPLVMMiniBatch.plot_latent = gpy_plot.latent_plots.plot_latent
bayesian_gplvm_minibatch.BayesianGPLVMMiniBatch.plot_latent_scatter = gpy_plot.latent_plots.plot_latent_scatter
bayesian_gplvm_minibatch.BayesianGPLVMMiniBatch.plot_latent_inducing = gpy_plot.latent_plots.plot_latent_inducing
bayesian_gplvm_minibatch.BayesianGPLVMMiniBatch.plot_scatter = gpy_plot.latent_plots.plot_latent_scatter
bayesian_gplvm_minibatch.BayesianGPLVMMiniBatch.plot_inducing = gpy_plot.latent_plots.plot_latent_inducing
bayesian_gplvm_minibatch.BayesianGPLVMMiniBatch.plot_steepest_gradient_map = gpy_plot.latent_plots.plot_steepest_gradient_map
SSGPLVM.plot_latent = gpy_plot.latent_plots.plot_latent
SSGPLVM.plot_latent_scatter = gpy_plot.latent_plots.plot_latent_scatter
SSGPLVM.plot_latent_inducing = gpy_plot.latent_plots.plot_latent_inducing
SSGPLVM.plot_scatter = gpy_plot.latent_plots.plot_latent_scatter
SSGPLVM.plot_inducing = gpy_plot.latent_plots.plot_latent_inducing
SSGPLVM.plot_steepest_gradient_map = gpy_plot.latent_plots.plot_steepest_gradient_map
from ..kern import Kern

View file

@ -52,7 +52,7 @@ def _plot_latent_scatter(canvas, X, visible_dims, labels, marker, num_samples, p
Tango.reset()
X, labels = subsample_X(X, labels, num_samples)
scatters = []
generate_colors = 'color' not in kwargs
generate_colors = 'color' not in kwargs
for x, y, z, this_label, _, m in scatter_label_generator(labels, X, visible_dims, marker):
update_not_existing_kwargs(kwargs, pl.defaults.latent_scatter)
if generate_colors:
@ -89,7 +89,7 @@ def plot_latent_scatter(self, labels=None,
labels = np.ones(self.num_data)
legend = False
else:
legend = find_best_layout_for_subplots(len(np.unique(labels)))
legend = find_best_layout_for_subplots(len(np.unique(labels)))[1]
scatters = _plot_latent_scatter(canvas, X, sig_dims, labels, marker, num_samples, projection=projection, **kwargs)
if projection == '3d':
return pl.show_canvas(canvas, dict(scatter=scatters), legend=legend,
@ -126,9 +126,9 @@ def plot_latent_inducing(self,
if 'color' not in kwargs:
kwargs['color'] = 'white'
canvas, kwargs = pl.get_new_canvas(projection=projection, **kwargs)
X, _, _ = get_x_y_var(self)
labels = np.ones(self.num_data)
scatters = _plot_latent_scatter(canvas, X, sig_dims, labels, marker, num_samples, projection=projection, **kwargs)
Z = self.Z.values
labels = np.array(['inducing'] * Z.shape[0])
scatters = _plot_latent_scatter(canvas, Z, sig_dims, labels, marker, num_samples, projection=projection, **kwargs)
if projection == '3d':
return pl.show_canvas(canvas, dict(scatter=scatters), legend=legend,
xlabel='latent dimension %i' % input_1,

View file

@ -138,7 +138,6 @@ def scatter_label_generator(labels, X, visible_dims, marker=None):
for lab in labels:
if not lab in ulabels:
ulabels.append(lab)
if marker is not None:
marker = itertools.cycle(list(marker))
else:
@ -154,19 +153,20 @@ def scatter_label_generator(labels, X, visible_dims, marker=None):
except:
input_1 = visible_dims
input_2 = input_3 = None
for ul in ulabels:
if type(ul) is np.string_:
this_label = ul
elif type(ul) is np.int64:
this_label = 'class %i' % ul
else:
from numbers import Number
if isinstance(ul, str):
try:
this_label = unicode(ul)
except NameError:
#python3
this_label = ul
elif isinstance(ul, Number):
this_label = 'class {!s}'.format(ul)
else:
this_label = ul
if marker is not None:
m = next(marker)

View file

@ -76,7 +76,13 @@ class MatplotlibPlots(AbstractPlottingLibrary):
legend_ontop(ax, ncol=legend, fontdict=fontdict)
if zlim is not None:
ax.set_zlim(zlim)
#ax.figure.show()
ax.figure.canvas.draw()
ax.figure.show()
#try:
# ax.figure.tight_layout()
#except:
# # couldnt do tight layout, python 2.7 on MacOSX
# pass
ax.figure.canvas.draw()
return plots

View file

@ -35,7 +35,7 @@ def legend_ontop(ax, mode='expand', ncol=3, fontdict=None):
from mpl_toolkits.axes_grid1 import make_axes_locatable
handles, labels = ax.get_legend_handles_labels()
divider = make_axes_locatable(ax)
cax = divider.append_axes("top", "5%", pad="1%")
cax = divider.append_axes("top", "5%", pad=0)
lgd = cax.legend(handles, labels, bbox_to_anchor=(0., 0., 1., 1.), loc=3,
ncol=ncol, mode=mode, borderaxespad=0., prop=fontdict or {})
cax.set_axis_off()