From 01d6b91f9079dfc4aab01c6531d5e9ff4a6b326e Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Fri, 16 May 2014 15:12:19 +0100 Subject: [PATCH] [mrd] missing data implemented, and plotting better --- GPy/core/parameterization/parameterized.py | 1 - .../latent_function_inference/var_dtc.py | 11 ++++ GPy/models/mrd.py | 55 ++++++++++--------- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/GPy/core/parameterization/parameterized.py b/GPy/core/parameterization/parameterized.py index 48eb2ddc..dd9a07c4 100644 --- a/GPy/core/parameterization/parameterized.py +++ b/GPy/core/parameterization/parameterized.py @@ -288,7 +288,6 @@ class Parameterized(Parameterizable): self._connect_parameters() self._connect_fixes() self._notify_parent_change() - self.parameters_changed() except Exception as e: print "WARNING: caught exception {!s}, trying to continue".format(e) diff --git a/GPy/inference/latent_function_inference/var_dtc.py b/GPy/inference/latent_function_inference/var_dtc.py index 3043a7e8..b5e6787b 100644 --- a/GPy/inference/latent_function_inference/var_dtc.py +++ b/GPy/inference/latent_function_inference/var_dtc.py @@ -202,6 +202,17 @@ class VarDTCMissingData(LatentFunctionInference): def set_limit(self, limit): self._Y.limit = limit + def __getstate__(self): + # has to be overridden, as Cacher objects cannot be pickled. + return self._Y.limit, self._inan + + def __setstate__(self, state): + # has to be overridden, as Cacher objects cannot be pickled. + from ...util.caching import Cacher + self.limit = state[0] + self._inan = state[1] + self._Y = Cacher(self._subarray_computations, self.limit) + def _subarray_computations(self, Y): if self._inan is None: inan = np.isnan(Y) diff --git a/GPy/models/mrd.py b/GPy/models/mrd.py index 73b267ba..c8067c01 100644 --- a/GPy/models/mrd.py +++ b/GPy/models/mrd.py @@ -203,6 +203,7 @@ class MRD(SparseGP): fig = pylab.figure(num=fignum) sharex_ax = None sharey_ax = None + plots = [] for i, g in enumerate(self.bgplvms): try: if sharex: @@ -219,15 +220,16 @@ class MRD(SparseGP): ax = axes[i] else: raise ValueError("Need one axes per latent dimension input_dim") - plotf(i, g, ax) + plots.append(plotf(i, g, ax)) if sharey_ax is not None: pylab.setp(ax.get_yticklabels(), visible=False) pylab.draw() if axes is None: - fig.tight_layout() - return fig - else: - return pylab.gcf() + try: + fig.tight_layout() + except: + pass + return plots def predict(self, Xnew, full_cov=False, Y_metadata=None, kern=None, Yindex=0): """ @@ -259,10 +261,10 @@ class MRD(SparseGP): """ if titles is None: titles = [r'${}$'.format(name) for name in self.names] - ymax = reduce(max, [np.ceil(max(g.kernels.input_sensitivity())) for g in self.bgplvms]) + ymax = reduce(max, [np.ceil(max(g.kern.input_sensitivity())) for g in self.bgplvms]) def plotf(i, g, ax): ax.set_ylim([0,ymax]) - g.kernels.plot_ARD(ax=ax, title=titles[i], *args, **kwargs) + return g.kern.plot_ARD(ax=ax, title=titles[i], *args, **kwargs) fig = self._handle_plotting(fignum, ax, plotf, sharex=sharex, sharey=sharey) return fig @@ -270,30 +272,33 @@ class MRD(SparseGP): resolution=50, ax=None, marker='o', s=40, fignum=None, plot_inducing=True, legend=True, plot_limits=None, - aspect='auto', updates=False, predict_kwargs=dict(Yindex=0), imshow_kwargs={}): + aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}): """ + see plotting.matplot_dep.dim_reduction_plots.plot_latent + if predict_kwargs is None, will plot latent spaces for 0th dataset (and kernel), otherwise give + predict_kwargs=dict(Yindex='index') for plotting only the latent space of dataset with 'index'. """ import sys assert "matplotlib" in sys.modules, "matplotlib package has not been imported." from ..plotting.matplot_dep import dim_reduction_plots + if "Yindex" not in predict_kwargs: + predict_kwargs['Yindex'] = 0 + if ax is None: + fig = pylab.figure(num=fignum) + ax = fig.add_subplot(111) + else: + fig = ax.figure + plot = dim_reduction_plots.plot_latent(self, labels, which_indices, + resolution, ax, marker, s, + fignum, plot_inducing, legend, + plot_limits, aspect, updates, predict_kwargs, imshow_kwargs) + ax.set_title(self.bgplvms[predict_kwargs['Yindex']].name) + try: + fig.tight_layout() + except: + pass - return dim_reduction_plots.plot_latent(self, labels, which_indices, - resolution, ax, marker, s, - fignum, plot_inducing, legend, - plot_limits, aspect, updates, predict_kwargs, imshow_kwargs) - - def _debug_plot(self): - self.plot_X_1d() - fig = pylab.figure("MRD DEBUG PLOT", figsize=(4 * len(self.bgplvms), 9)) - fig.clf() - axes = [fig.add_subplot(3, len(self.bgplvms), i + 1) for i in range(len(self.bgplvms))] - self.plot_X(ax=axes) - axes = [fig.add_subplot(3, len(self.bgplvms), i + len(self.bgplvms) + 1) for i in range(len(self.bgplvms))] - self.plot_latent(ax=axes) - axes = [fig.add_subplot(3, len(self.bgplvms), i + 2 * len(self.bgplvms) + 1) for i in range(len(self.bgplvms))] - self.plot_scales(ax=axes) - pylab.draw() - fig.tight_layout() + return plot def __getstate__(self): # TODO: