[mrd] plot_scales and plot_latent added

This commit is contained in:
Max Zwiessele 2016-03-01 10:02:02 +00:00
parent c4020cd2eb
commit 885d3722cc
8 changed files with 112 additions and 148 deletions

View file

@ -5,14 +5,14 @@ import numpy as np
import itertools, logging
from ..kern import Kern
from GPy.core.parameterization.variational import NormalPrior
from ..core.parameterization.variational import NormalPrior
from ..core.parameterization import Param
from paramz import ObsAr
from ..inference.latent_function_inference.var_dtc import VarDTC
from ..inference.latent_function_inference import InferenceMethodList
from ..likelihoods import Gaussian
from ..util.initialization import initialize_latent
from GPy.models.bayesian_gplvm_minibatch import BayesianGPLVMMiniBatch
from ..models.bayesian_gplvm_minibatch import BayesianGPLVMMiniBatch
class MRD(BayesianGPLVMMiniBatch):
"""
@ -215,40 +215,6 @@ class MRD(BayesianGPLVMMiniBatch):
Z = np.random.randn(self.num_inducing, self.input_dim) * X.var()
return Z
def _handle_plotting(self, fignum, axes, plotf, sharex=False, sharey=False):
import matplotlib.pyplot as plt
if axes is None:
fig = plt.figure(num=fignum)
sharex_ax = None
sharey_ax = None
plots = []
for i, g in enumerate(self.bgplvms):
try:
if sharex:
sharex_ax = ax # @UndefinedVariable
sharex = False # dont set twice
if sharey:
sharey_ax = ax # @UndefinedVariable
sharey = False # dont set twice
except:
pass
if axes is None:
ax = fig.add_subplot(1, len(self.bgplvms), i + 1, sharex=sharex_ax, sharey=sharey_ax)
elif isinstance(axes, (tuple, list, np.ndarray)):
ax = axes[i]
else:
raise ValueError("Need one axes per latent dimension input_dim")
plots.append(plotf(i, g, ax))
if sharey_ax is not None:
plt.setp(ax.get_yticklabels(), visible=False)
plt.draw()
if axes is None:
try:
fig.tight_layout()
except:
pass
return plots
def predict(self, Xnew, full_cov=False, Y_metadata=None, kern=None, Yindex=0):
"""
Prediction for data set Yindex[default=0].
@ -270,59 +236,53 @@ class MRD(BayesianGPLVMMiniBatch):
# sharex=sharex, sharey=sharey)
# return fig
def plot_scales(self, fignum=None, ax=None, titles=None, sharex=False, sharey=True, *args, **kwargs):
def plot_scales(self, titles=None, fig_kwargs=dict(figsize=None, tight_layout=True), **kwargs):
"""
TODO: Explain other parameters
Plot input sensitivity for all datasets, to see which input dimensions are
significant for which dataset.
:param titles: titles for axes of datasets
kwargs go into plot_ARD for each kernel.
"""
from ..plotting import plotting_library as pl
if titles is None:
titles = [r'${}$'.format(name) for name in self.names]
ymax = reduce(max, [np.ceil(max(g.kern.input_sensitivity())) for g in self.bgplvms])
def plotf(i, g, ax):
#ax.set_ylim([0,ymax])
return g.kern.plot_ARD(ax=ax, title=titles[i], *args, **kwargs)
fig = self._handle_plotting(fignum, ax, plotf, sharex=sharex, sharey=sharey)
return fig
M = len(self.bgplvms)
fig = pl().figure(rows=1, cols=M, **fig_kwargs)
plots = {}
for c in range(M):
canvas = self.bgplvms[c].kern.plot_ARD(title=titles[c], figure=fig, col=c+1, **kwargs)
plots[titles[c]] = canvas
pl().show_canvas(canvas)
return plots
def plot_latent(self, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40,
fignum=None, plot_inducing=True, legend=True,
resolution=60, legend=True,
plot_limits=None,
aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}):
updates=False,
kern=None, marker='<>^vsd',
num_samples=1000, projection='2d',
predict_kwargs={},
scatter_kwargs=None, **imshow_kwargs):
"""
see plotting.matplot_dep.dim_reduction_plots.plot_latent
if predict_kwargs is None, will plot latent spaces for 0th dataset (and kernel), otherwise give
predict_kwargs=dict(Yindex='index') for plotting only the latent space of dataset with 'index'.
"""
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from matplotlib import pyplot as plt
from ..plotting.matplot_dep import dim_reduction_plots
from ..plotting.gpy_plot.latent_plots import plot_latent
if "Yindex" not in predict_kwargs:
predict_kwargs['Yindex'] = 0
Yindex = predict_kwargs['Yindex']
if ax is None:
fig = plt.figure(num=fignum)
ax = fig.add_subplot(111)
else:
fig = ax.figure
self.kern = self.bgplvms[Yindex].kern
self.likelihood = self.bgplvms[Yindex].likelihood
plot = dim_reduction_plots.plot_latent(self, labels, which_indices,
resolution, ax, marker, s,
fignum, plot_inducing, legend,
plot_limits, aspect, updates, predict_kwargs, imshow_kwargs)
ax.set_title(self.bgplvms[Yindex].name)
try:
fig.tight_layout()
except:
pass
return plot
return plot_latent(self, labels, which_indices, resolution, legend, plot_limits, updates, kern, marker, num_samples, projection, scatter_kwargs)
def __getstate__(self):
state = super(MRD, self).__getstate__()