mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 10:32:39 +02:00
[visualize] some adjustments to vector_show
This commit is contained in:
parent
5826ac6b73
commit
22221565bb
2 changed files with 13 additions and 9 deletions
|
|
@ -161,6 +161,7 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40,
|
||||||
import GPy
|
import GPy
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from ..util.misc import param_to_array
|
from ..util.misc import param_to_array
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
_np.random.seed(0)
|
_np.random.seed(0)
|
||||||
data = GPy.util.datasets.oil()
|
data = GPy.util.datasets.oil()
|
||||||
|
|
@ -174,11 +175,10 @@ def bgplvm_oil(optimize=True, verbose=1, plot=True, N=200, Q=7, num_inducing=40,
|
||||||
m.optimize('scg', messages=verbose, max_iters=max_iters, gtol=.05)
|
m.optimize('scg', messages=verbose, max_iters=max_iters, gtol=.05)
|
||||||
|
|
||||||
if plot:
|
if plot:
|
||||||
y = m.Y
|
|
||||||
fig, (latent_axes, sense_axes) = plt.subplots(1, 2)
|
fig, (latent_axes, sense_axes) = plt.subplots(1, 2)
|
||||||
m.plot_latent(ax=latent_axes, labels=m.data_labels)
|
m.plot_latent(ax=latent_axes, labels=m.data_labels)
|
||||||
data_show = GPy.plotting.matplot_dep.visualize.vector_show(y)
|
data_show = GPy.plotting.matplot_dep.visualize.vector_show(np.zeros((m.Y.shape[1], 1)))
|
||||||
lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean), # @UnusedVariable
|
lvm_visualizer = GPy.plotting.matplot_dep.visualize.lvm_dimselect(param_to_array(m.X.mean)[0:1,:], # @UnusedVariable
|
||||||
m, data_show, latent_axes=latent_axes, sense_axes=sense_axes)
|
m, data_show, latent_axes=latent_axes, sense_axes=sense_axes)
|
||||||
raw_input('Press enter to finish')
|
raw_input('Press enter to finish')
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
|
||||||
|
|
@ -74,13 +74,17 @@ class vector_show(matplotlib_show):
|
||||||
"""
|
"""
|
||||||
def __init__(self, vals, axes=None):
|
def __init__(self, vals, axes=None):
|
||||||
matplotlib_show.__init__(self, vals, axes)
|
matplotlib_show.__init__(self, vals, axes)
|
||||||
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals)
|
#assert vals.ndim == 2, "Please give a vector in [n x 1] to plot"
|
||||||
|
#assert vals.shape[1] == 1, "only showing a vector in one dimension"
|
||||||
|
self.size = vals.size
|
||||||
|
|
||||||
|
self.handle = self.axes.plot(np.arange(0, vals.size)[:, None], self.vals)[0]
|
||||||
|
|
||||||
def modify(self, vals):
|
def modify(self, vals):
|
||||||
self.vals = vals.copy()
|
self.vals = vals.copy()
|
||||||
for handle, vals in zip(self.handle, self.vals.T):
|
xdata, ydata = self.handle.get_data()
|
||||||
xdata, ydata = handle.get_data()
|
assert vals.size == self.size, "values passed into modify changed size! vals:{} != in:{}".format(vals.size, self.size)
|
||||||
handle.set_data(xdata, vals)
|
self.handle.set_data(xdata, self.vals)
|
||||||
self.axes.figure.canvas.draw()
|
self.axes.figure.canvas.draw()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -94,7 +98,7 @@ class lvm(matplotlib_show):
|
||||||
:type data_visualize: visualize.data_show type.
|
:type data_visualize: visualize.data_show type.
|
||||||
:param latent_axes: the axes where the latent visualization should be plotted.
|
:param latent_axes: the axes where the latent visualization should be plotted.
|
||||||
"""
|
"""
|
||||||
if vals == None:
|
if vals is None:
|
||||||
if isinstance(model.X, VariationalPosterior):
|
if isinstance(model.X, VariationalPosterior):
|
||||||
vals = param_to_array(model.X.mean)
|
vals = param_to_array(model.X.mean)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue