mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 03:22:38 +02:00
[mrd] missing data implemented, and plotting better
This commit is contained in:
parent
94c84a23a3
commit
01d6b91f90
3 changed files with 41 additions and 26 deletions
|
|
@ -288,7 +288,6 @@ class Parameterized(Parameterizable):
|
||||||
self._connect_parameters()
|
self._connect_parameters()
|
||||||
self._connect_fixes()
|
self._connect_fixes()
|
||||||
self._notify_parent_change()
|
self._notify_parent_change()
|
||||||
|
|
||||||
self.parameters_changed()
|
self.parameters_changed()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print "WARNING: caught exception {!s}, trying to continue".format(e)
|
print "WARNING: caught exception {!s}, trying to continue".format(e)
|
||||||
|
|
|
||||||
|
|
@ -202,6 +202,17 @@ class VarDTCMissingData(LatentFunctionInference):
|
||||||
def set_limit(self, limit):
|
def set_limit(self, limit):
|
||||||
self._Y.limit = limit
|
self._Y.limit = limit
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
# has to be overridden, as Cacher objects cannot be pickled.
|
||||||
|
return self._Y.limit, self._inan
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
# has to be overridden, as Cacher objects cannot be pickled.
|
||||||
|
from ...util.caching import Cacher
|
||||||
|
self.limit = state[0]
|
||||||
|
self._inan = state[1]
|
||||||
|
self._Y = Cacher(self._subarray_computations, self.limit)
|
||||||
|
|
||||||
def _subarray_computations(self, Y):
|
def _subarray_computations(self, Y):
|
||||||
if self._inan is None:
|
if self._inan is None:
|
||||||
inan = np.isnan(Y)
|
inan = np.isnan(Y)
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,7 @@ class MRD(SparseGP):
|
||||||
fig = pylab.figure(num=fignum)
|
fig = pylab.figure(num=fignum)
|
||||||
sharex_ax = None
|
sharex_ax = None
|
||||||
sharey_ax = None
|
sharey_ax = None
|
||||||
|
plots = []
|
||||||
for i, g in enumerate(self.bgplvms):
|
for i, g in enumerate(self.bgplvms):
|
||||||
try:
|
try:
|
||||||
if sharex:
|
if sharex:
|
||||||
|
|
@ -219,15 +220,16 @@ class MRD(SparseGP):
|
||||||
ax = axes[i]
|
ax = axes[i]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Need one axes per latent dimension input_dim")
|
raise ValueError("Need one axes per latent dimension input_dim")
|
||||||
plotf(i, g, ax)
|
plots.append(plotf(i, g, ax))
|
||||||
if sharey_ax is not None:
|
if sharey_ax is not None:
|
||||||
pylab.setp(ax.get_yticklabels(), visible=False)
|
pylab.setp(ax.get_yticklabels(), visible=False)
|
||||||
pylab.draw()
|
pylab.draw()
|
||||||
if axes is None:
|
if axes is None:
|
||||||
|
try:
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
return fig
|
except:
|
||||||
else:
|
pass
|
||||||
return pylab.gcf()
|
return plots
|
||||||
|
|
||||||
def predict(self, Xnew, full_cov=False, Y_metadata=None, kern=None, Yindex=0):
|
def predict(self, Xnew, full_cov=False, Y_metadata=None, kern=None, Yindex=0):
|
||||||
"""
|
"""
|
||||||
|
|
@ -259,10 +261,10 @@ class MRD(SparseGP):
|
||||||
"""
|
"""
|
||||||
if titles is None:
|
if titles is None:
|
||||||
titles = [r'${}$'.format(name) for name in self.names]
|
titles = [r'${}$'.format(name) for name in self.names]
|
||||||
ymax = reduce(max, [np.ceil(max(g.kernels.input_sensitivity())) for g in self.bgplvms])
|
ymax = reduce(max, [np.ceil(max(g.kern.input_sensitivity())) for g in self.bgplvms])
|
||||||
def plotf(i, g, ax):
|
def plotf(i, g, ax):
|
||||||
ax.set_ylim([0,ymax])
|
ax.set_ylim([0,ymax])
|
||||||
g.kernels.plot_ARD(ax=ax, title=titles[i], *args, **kwargs)
|
return g.kern.plot_ARD(ax=ax, title=titles[i], *args, **kwargs)
|
||||||
fig = self._handle_plotting(fignum, ax, plotf, sharex=sharex, sharey=sharey)
|
fig = self._handle_plotting(fignum, ax, plotf, sharex=sharex, sharey=sharey)
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
@ -270,30 +272,33 @@ class MRD(SparseGP):
|
||||||
resolution=50, ax=None, marker='o', s=40,
|
resolution=50, ax=None, marker='o', s=40,
|
||||||
fignum=None, plot_inducing=True, legend=True,
|
fignum=None, plot_inducing=True, legend=True,
|
||||||
plot_limits=None,
|
plot_limits=None,
|
||||||
aspect='auto', updates=False, predict_kwargs=dict(Yindex=0), imshow_kwargs={}):
|
aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}):
|
||||||
"""
|
"""
|
||||||
|
see plotting.matplot_dep.dim_reduction_plots.plot_latent
|
||||||
|
if predict_kwargs is None, will plot latent spaces for 0th dataset (and kernel), otherwise give
|
||||||
|
predict_kwargs=dict(Yindex='index') for plotting only the latent space of dataset with 'index'.
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||||
from ..plotting.matplot_dep import dim_reduction_plots
|
from ..plotting.matplot_dep import dim_reduction_plots
|
||||||
|
if "Yindex" not in predict_kwargs:
|
||||||
return dim_reduction_plots.plot_latent(self, labels, which_indices,
|
predict_kwargs['Yindex'] = 0
|
||||||
|
if ax is None:
|
||||||
|
fig = pylab.figure(num=fignum)
|
||||||
|
ax = fig.add_subplot(111)
|
||||||
|
else:
|
||||||
|
fig = ax.figure
|
||||||
|
plot = dim_reduction_plots.plot_latent(self, labels, which_indices,
|
||||||
resolution, ax, marker, s,
|
resolution, ax, marker, s,
|
||||||
fignum, plot_inducing, legend,
|
fignum, plot_inducing, legend,
|
||||||
plot_limits, aspect, updates, predict_kwargs, imshow_kwargs)
|
plot_limits, aspect, updates, predict_kwargs, imshow_kwargs)
|
||||||
|
ax.set_title(self.bgplvms[predict_kwargs['Yindex']].name)
|
||||||
def _debug_plot(self):
|
try:
|
||||||
self.plot_X_1d()
|
|
||||||
fig = pylab.figure("MRD DEBUG PLOT", figsize=(4 * len(self.bgplvms), 9))
|
|
||||||
fig.clf()
|
|
||||||
axes = [fig.add_subplot(3, len(self.bgplvms), i + 1) for i in range(len(self.bgplvms))]
|
|
||||||
self.plot_X(ax=axes)
|
|
||||||
axes = [fig.add_subplot(3, len(self.bgplvms), i + len(self.bgplvms) + 1) for i in range(len(self.bgplvms))]
|
|
||||||
self.plot_latent(ax=axes)
|
|
||||||
axes = [fig.add_subplot(3, len(self.bgplvms), i + 2 * len(self.bgplvms) + 1) for i in range(len(self.bgplvms))]
|
|
||||||
self.plot_scales(ax=axes)
|
|
||||||
pylab.draw()
|
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return plot
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
# TODO:
|
# TODO:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue