plot_latent enhancements

This commit is contained in:
Max Zwiessele 2013-06-28 11:00:42 +01:00
parent 08a902ed6c
commit 7325e319b4
2 changed files with 32 additions and 20 deletions

View file

@@ -24,8 +24,7 @@ class BayesianGPLVM(SparseGP, GPLVM):
""" """
def __init__(self, likelihood_or_Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10, def __init__(self, likelihood_or_Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
Z=None, kernel=None, oldpsave=10, _debug=False, Z=None, kernel=None, **kwargs):
**kwargs):
if type(likelihood_or_Y) is np.ndarray: if type(likelihood_or_Y) is np.ndarray:
likelihood = Gaussian(likelihood_or_Y) likelihood = Gaussian(likelihood_or_Y)
else: else:
@@ -117,7 +116,7 @@ class BayesianGPLVM(SparseGP, GPLVM):
return np.hstack((self.dbound_dmuS.flatten(), self.dbound_dZtheta)) return np.hstack((self.dbound_dmuS.flatten(), self.dbound_dZtheta))
def plot_latent(self, *args, **kwargs): def plot_latent(self, *args, **kwargs):
return plot_latent.plot_latent_indices(self, *args, **kwargs) return plot_latent.plot_latent(self, *args, **kwargs)
def do_test_latents(self, Y): def do_test_latents(self, Y):
""" """

View file

@@ -48,7 +48,7 @@ class MRD(Model):
kernels=None, initx='PCA', kernels=None, initx='PCA',
initz='permute', _debug=False, **kw): initz='permute', _debug=False, **kw):
if names is None: if names is None:
self.names = ["{}".format(i + 1) for i in range(len(likelihood_or_Y_list))] self.names = ["{}".format(i) for i in range(len(likelihood_or_Y_list))]
# sort out the kernels # sort out the kernels
if kernels is None: if kernels is None:
@@ -281,12 +281,23 @@ class MRD(Model):
self.Z = Z self.Z = Z
return Z return Z
def _handle_plotting(self, fignum, axes, plotf): def _handle_plotting(self, fignum, axes, plotf, sharex=False, sharey=False):
if axes is None: if axes is None:
fig = pylab.figure(num=fignum) fig = pylab.figure(num=fignum)
sharex_ax = None
sharey_ax = None
for i, g in enumerate(self.bgplvms): for i, g in enumerate(self.bgplvms):
try:
if sharex:
sharex_ax = ax # @UndefinedVariable
sharex = False # dont set twice
if sharey:
sharey_ax = ax # @UndefinedVariable
sharey = False # dont set twice
except:
pass
if axes is None: if axes is None:
ax = fig.add_subplot(1, len(self.bgplvms), i + 1) ax = fig.add_subplot(1, len(self.bgplvms), i + 1, sharex=sharex_ax, sharey=sharey_ax)
elif isinstance(axes, (tuple, list)): elif isinstance(axes, (tuple, list)):
ax = axes[i] ax = axes[i]
else: else:
@@ -306,16 +317,27 @@ class MRD(Model):
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g.X)) fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g.X))
return fig return fig
def plot_predict(self, fignum=None, ax=None, **kwargs): def plot_predict(self, fignum=None, ax=None, sharex=False, sharey=False, **kwargs):
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g. predict(g.X)[0], **kwargs)) fig = self._handle_plotting(fignum,
ax,
lambda i, g, ax: ax.imshow(g. predict(g.X)[0], **kwargs),
sharex=sharex, sharey=sharey)
return fig return fig
def plot_scales(self, fignum=None, ax=None, *args, **kwargs): def plot_scales(self, fignum=None, ax=None, titles=None, sharex=False, sharey=True, *args, **kwargs):
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.kern.plot_ARD(ax=ax, title=r'$Y_{}$'.format(i), *args, **kwargs)) """
:param:`titles` :
titles for axes of datasets
"""
if titles is None:
titles = [r'${}$'.format(name) for name in self.names]
def plotf(i, g, ax):
g.kern.plot_ARD(ax=ax, title=titles[i], *args, **kwargs)
fig = self._handle_plotting(fignum, ax, plotf, sharex=sharex, sharey=sharey)
return fig return fig
def plot_latent(self, fignum=None, ax=None, *args, **kwargs): def plot_latent(self, fignum=None, ax=None, *args, **kwargs):
fig = self.gref.plot_latent(*args, **kwargs) # self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs)) fig = self.gref.plot_latent(fignum=fignum, ax=ax, *args, **kwargs) # self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs))
return fig return fig
def _debug_plot(self): def _debug_plot(self):
@@ -331,13 +353,4 @@ class MRD(Model):
pylab.draw() pylab.draw()
fig.tight_layout() fig.tight_layout()
def _debug_optimize(self, opt='scg', maxiters=5000, itersteps=10):
iters = 0
optstep = lambda: self.optimize(opt, messages=1, max_f_eval=itersteps)
self._debug_plot()
raw_input("enter to start debug")
while iters < maxiters:
optstep()
self._debug_plot()
iters += itersteps