[mrd] missing data implemented, and plotting better

This commit is contained in:
Max Zwiessele 2014-05-16 15:12:19 +01:00
parent 94c84a23a3
commit 01d6b91f90
3 changed files with 41 additions and 26 deletions

View file

@ -288,7 +288,6 @@ class Parameterized(Parameterizable):
self._connect_parameters() self._connect_parameters()
self._connect_fixes() self._connect_fixes()
self._notify_parent_change() self._notify_parent_change()
self.parameters_changed() self.parameters_changed()
except Exception as e: except Exception as e:
print "WARNING: caught exception {!s}, trying to continue".format(e) print "WARNING: caught exception {!s}, trying to continue".format(e)

View file

@ -202,6 +202,17 @@ class VarDTCMissingData(LatentFunctionInference):
def set_limit(self, limit): def set_limit(self, limit):
self._Y.limit = limit self._Y.limit = limit
def __getstate__(self):
# has to be overridden, as Cacher objects cannot be pickled.
return self._Y.limit, self._inan
def __setstate__(self, state):
# has to be overridden, as Cacher objects cannot be pickled.
from ...util.caching import Cacher
self.limit = state[0]
self._inan = state[1]
self._Y = Cacher(self._subarray_computations, self.limit)
def _subarray_computations(self, Y): def _subarray_computations(self, Y):
if self._inan is None: if self._inan is None:
inan = np.isnan(Y) inan = np.isnan(Y)

View file

@ -203,6 +203,7 @@ class MRD(SparseGP):
fig = pylab.figure(num=fignum) fig = pylab.figure(num=fignum)
sharex_ax = None sharex_ax = None
sharey_ax = None sharey_ax = None
plots = []
for i, g in enumerate(self.bgplvms): for i, g in enumerate(self.bgplvms):
try: try:
if sharex: if sharex:
@ -219,15 +220,16 @@ class MRD(SparseGP):
ax = axes[i] ax = axes[i]
else: else:
raise ValueError("Need one axes per latent dimension input_dim") raise ValueError("Need one axes per latent dimension input_dim")
plotf(i, g, ax) plots.append(plotf(i, g, ax))
if sharey_ax is not None: if sharey_ax is not None:
pylab.setp(ax.get_yticklabels(), visible=False) pylab.setp(ax.get_yticklabels(), visible=False)
pylab.draw() pylab.draw()
if axes is None: if axes is None:
try:
fig.tight_layout() fig.tight_layout()
return fig except:
else: pass
return pylab.gcf() return plots
def predict(self, Xnew, full_cov=False, Y_metadata=None, kern=None, Yindex=0): def predict(self, Xnew, full_cov=False, Y_metadata=None, kern=None, Yindex=0):
""" """
@ -259,10 +261,10 @@ class MRD(SparseGP):
""" """
if titles is None: if titles is None:
titles = [r'${}$'.format(name) for name in self.names] titles = [r'${}$'.format(name) for name in self.names]
ymax = reduce(max, [np.ceil(max(g.kernels.input_sensitivity())) for g in self.bgplvms]) ymax = reduce(max, [np.ceil(max(g.kern.input_sensitivity())) for g in self.bgplvms])
def plotf(i, g, ax): def plotf(i, g, ax):
ax.set_ylim([0,ymax]) ax.set_ylim([0,ymax])
g.kernels.plot_ARD(ax=ax, title=titles[i], *args, **kwargs) return g.kern.plot_ARD(ax=ax, title=titles[i], *args, **kwargs)
fig = self._handle_plotting(fignum, ax, plotf, sharex=sharex, sharey=sharey) fig = self._handle_plotting(fignum, ax, plotf, sharex=sharex, sharey=sharey)
return fig return fig
@ -270,30 +272,33 @@ class MRD(SparseGP):
resolution=50, ax=None, marker='o', s=40, resolution=50, ax=None, marker='o', s=40,
fignum=None, plot_inducing=True, legend=True, fignum=None, plot_inducing=True, legend=True,
plot_limits=None, plot_limits=None,
aspect='auto', updates=False, predict_kwargs=dict(Yindex=0), imshow_kwargs={}): aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}):
""" """
see plotting.matplot_dep.dim_reduction_plots.plot_latent
if predict_kwargs is None, will plot latent spaces for 0th dataset (and kernel), otherwise give
predict_kwargs=dict(Yindex='index') for plotting only the latent space of dataset with 'index'.
""" """
import sys import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported." assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import dim_reduction_plots from ..plotting.matplot_dep import dim_reduction_plots
if "Yindex" not in predict_kwargs:
return dim_reduction_plots.plot_latent(self, labels, which_indices, predict_kwargs['Yindex'] = 0
if ax is None:
fig = pylab.figure(num=fignum)
ax = fig.add_subplot(111)
else:
fig = ax.figure
plot = dim_reduction_plots.plot_latent(self, labels, which_indices,
resolution, ax, marker, s, resolution, ax, marker, s,
fignum, plot_inducing, legend, fignum, plot_inducing, legend,
plot_limits, aspect, updates, predict_kwargs, imshow_kwargs) plot_limits, aspect, updates, predict_kwargs, imshow_kwargs)
ax.set_title(self.bgplvms[predict_kwargs['Yindex']].name)
def _debug_plot(self): try:
self.plot_X_1d()
fig = pylab.figure("MRD DEBUG PLOT", figsize=(4 * len(self.bgplvms), 9))
fig.clf()
axes = [fig.add_subplot(3, len(self.bgplvms), i + 1) for i in range(len(self.bgplvms))]
self.plot_X(ax=axes)
axes = [fig.add_subplot(3, len(self.bgplvms), i + len(self.bgplvms) + 1) for i in range(len(self.bgplvms))]
self.plot_latent(ax=axes)
axes = [fig.add_subplot(3, len(self.bgplvms), i + 2 * len(self.bgplvms) + 1) for i in range(len(self.bgplvms))]
self.plot_scales(ax=axes)
pylab.draw()
fig.tight_layout() fig.tight_layout()
except:
pass
return plot
def __getstate__(self): def __getstate__(self):
# TODO: # TODO: