[visualize] some adjustments to vector_show

This commit is contained in:
Max Zwiessele 2014-05-12 12:05:06 +01:00
parent 5826ac6b73
commit 22221565bb
2 changed files with 13 additions and 9 deletions

View file

@ -161,6 +161,7 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40,
import GPy import GPy
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from ..util.misc import param_to_array from ..util.misc import param_to_array
import numpy as np
_np.random.seed(0) _np.random.seed(0)
data = GPy.util.datasets.oil() data = GPy.util.datasets.oil()
@ -174,11 +175,10 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40,
m.optimize('scg', messages=verbose, max_iters=max_iters, gtol=.05) m.optimize('scg', messages=verbose, max_iters=max_iters, gtol=.05)
if plot: if plot:
y = m.Y
fig, (latent_axes, sense_axes) = plt.subplots(1, 2) fig, (latent_axes, sense_axes) = plt.subplots(1, 2)
m.plot_latent(ax=latent_axes, labels=m.data_labels) m.plot_latent(ax=latent_axes, labels=m.data_labels)
data_show = GPy.plotting.matplot_dep.visualize.vector_show(y) data_show = GPy.plotting.matplot_dep.visualize.vector_show(np.zeros((m.Y.shape[1], 1)))
lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean), # @UnusedVariable lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean)[0:1,:], # @UnusedVariable
m, data_show, latent_axes=latent_axes, sense_axes=sense_axes) m, data_show, latent_axes=latent_axes, sense_axes=sense_axes)
raw_input('Press enter to finish') raw_input('Press enter to finish')
plt.close(fig) plt.close(fig)

View file

@ -74,13 +74,17 @@ class vector_show(matplotlib_show):
""" """
def __init__(self, vals, axes=None): def __init__(self, vals, axes=None):
matplotlib_show.__init__(self, vals, axes) matplotlib_show.__init__(self, vals, axes)
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals) #assert vals.ndim == 2, "Please give a vector in [n x 1] to plot"
#assert vals.shape[1] == 1, "only showing a vector in one dimension"
self.size = vals.size
self.handle = self.axes.plot(np.arange(0, vals.size)[:, None], self.vals)[0]
def modify(self, vals): def modify(self, vals):
self.vals = vals.copy() self.vals = vals.copy()
for handle, vals in zip(self.handle, self.vals.T): xdata, ydata = self.handle.get_data()
xdata, ydata = handle.get_data() assert vals.size == self.size, "values passed into modify changed size! vals:{} != in:{}".format(vals.size, self.size)
handle.set_data(xdata, vals) self.handle.set_data(xdata, self.vals)
self.axes.figure.canvas.draw() self.axes.figure.canvas.draw()
@ -94,12 +98,12 @@ class lvm(matplotlib_show):
:type data_visualize: visualize.data_show type. :type data_visualize: visualize.data_show type.
:param latent_axes: the axes where the latent visualization should be plotted. :param latent_axes: the axes where the latent visualization should be plotted.
""" """
if vals == None: if vals is None:
if isinstance(model.X, VariationalPosterior): if isinstance(model.X, VariationalPosterior):
vals = param_to_array(model.X.mean) vals = param_to_array(model.X.mean)
else: else:
vals = param_to_array(model.X) vals = param_to_array(model.X)
vals = param_to_array(vals) vals = param_to_array(vals)
matplotlib_show.__init__(self, vals, axes=latent_axes) matplotlib_show.__init__(self, vals, axes=latent_axes)