mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 04:22:38 +02:00
plot_latent enhancements
This commit is contained in:
parent
08a902ed6c
commit
7325e319b4
2 changed files with 32 additions and 20 deletions
|
|
@ -24,8 +24,7 @@ class BayesianGPLVM(SparseGP, GPLVM):
|
|||
|
||||
"""
|
||||
def __init__(self, likelihood_or_Y, input_dim, X=None, X_variance=None, init='PCA', num_inducing=10,
|
||||
Z=None, kernel=None, oldpsave=10, _debug=False,
|
||||
**kwargs):
|
||||
Z=None, kernel=None, **kwargs):
|
||||
if type(likelihood_or_Y) is np.ndarray:
|
||||
likelihood = Gaussian(likelihood_or_Y)
|
||||
else:
|
||||
|
|
@ -117,7 +116,7 @@ class BayesianGPLVM(SparseGP, GPLVM):
|
|||
return np.hstack((self.dbound_dmuS.flatten(), self.dbound_dZtheta))
|
||||
|
||||
def plot_latent(self, *args, **kwargs):
|
||||
return plot_latent.plot_latent_indices(self, *args, **kwargs)
|
||||
return plot_latent.plot_latent(self, *args, **kwargs)
|
||||
|
||||
def do_test_latents(self, Y):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class MRD(Model):
|
|||
kernels=None, initx='PCA',
|
||||
initz='permute', _debug=False, **kw):
|
||||
if names is None:
|
||||
self.names = ["{}".format(i + 1) for i in range(len(likelihood_or_Y_list))]
|
||||
self.names = ["{}".format(i) for i in range(len(likelihood_or_Y_list))]
|
||||
|
||||
# sort out the kernels
|
||||
if kernels is None:
|
||||
|
|
@ -281,12 +281,23 @@ class MRD(Model):
|
|||
self.Z = Z
|
||||
return Z
|
||||
|
||||
def _handle_plotting(self, fignum, axes, plotf):
|
||||
def _handle_plotting(self, fignum, axes, plotf, sharex=False, sharey=False):
|
||||
if axes is None:
|
||||
fig = pylab.figure(num=fignum)
|
||||
sharex_ax = None
|
||||
sharey_ax = None
|
||||
for i, g in enumerate(self.bgplvms):
|
||||
try:
|
||||
if sharex:
|
||||
sharex_ax = ax # @UndefinedVariable
|
||||
sharex = False # dont set twice
|
||||
if sharey:
|
||||
sharey_ax = ax # @UndefinedVariable
|
||||
sharey = False # dont set twice
|
||||
except:
|
||||
pass
|
||||
if axes is None:
|
||||
ax = fig.add_subplot(1, len(self.bgplvms), i + 1)
|
||||
ax = fig.add_subplot(1, len(self.bgplvms), i + 1, sharex=sharex_ax, sharey=sharey_ax)
|
||||
elif isinstance(axes, (tuple, list)):
|
||||
ax = axes[i]
|
||||
else:
|
||||
|
|
@ -306,16 +317,27 @@ class MRD(Model):
|
|||
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g.X))
|
||||
return fig
|
||||
|
||||
def plot_predict(self, fignum=None, ax=None, **kwargs):
|
||||
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: ax.imshow(g. predict(g.X)[0], **kwargs))
|
||||
def plot_predict(self, fignum=None, ax=None, sharex=False, sharey=False, **kwargs):
|
||||
fig = self._handle_plotting(fignum,
|
||||
ax,
|
||||
lambda i, g, ax: ax.imshow(g. predict(g.X)[0], **kwargs),
|
||||
sharex=sharex, sharey=sharey)
|
||||
return fig
|
||||
|
||||
def plot_scales(self, fignum=None, ax=None, *args, **kwargs):
|
||||
fig = self._handle_plotting(fignum, ax, lambda i, g, ax: g.kern.plot_ARD(ax=ax, title=r'$Y_{}$'.format(i), *args, **kwargs))
|
||||
def plot_scales(self, fignum=None, ax=None, titles=None, sharex=False, sharey=True, *args, **kwargs):
|
||||
"""
|
||||
:param:`titles` :
|
||||
titles for axes of datasets
|
||||
"""
|
||||
if titles is None:
|
||||
titles = [r'${}$'.format(name) for name in self.names]
|
||||
def plotf(i, g, ax):
|
||||
g.kern.plot_ARD(ax=ax, title=titles[i], *args, **kwargs)
|
||||
fig = self._handle_plotting(fignum, ax, plotf, sharex=sharex, sharey=sharey)
|
||||
return fig
|
||||
|
||||
def plot_latent(self, fignum=None, ax=None, *args, **kwargs):
|
||||
fig = self.gref.plot_latent(*args, **kwargs) # self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs))
|
||||
fig = self.gref.plot_latent(fignum=fignum, ax=ax, *args, **kwargs) # self._handle_plotting(fignum, ax, lambda i, g, ax: g.plot_latent(ax=ax, *args, **kwargs))
|
||||
return fig
|
||||
|
||||
def _debug_plot(self):
|
||||
|
|
@ -331,13 +353,4 @@ class MRD(Model):
|
|||
pylab.draw()
|
||||
fig.tight_layout()
|
||||
|
||||
def _debug_optimize(self, opt='scg', maxiters=5000, itersteps=10):
|
||||
iters = 0
|
||||
optstep = lambda: self.optimize(opt, messages=1, max_f_eval=itersteps)
|
||||
self._debug_plot()
|
||||
raw_input("enter to start debug")
|
||||
while iters < maxiters:
|
||||
optstep()
|
||||
self._debug_plot()
|
||||
iters += itersteps
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue