mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 02:24:17 +02:00
plotting, allot of plotting
This commit is contained in:
parent
6f9c97ee72
commit
ce728d8465
9 changed files with 97 additions and 39 deletions
|
|
@ -30,7 +30,8 @@ def most_significant_input_dimensions(model, which_indices):
|
|||
def plot_latent(model, labels=None, which_indices=None,
|
||||
resolution=50, ax=None, marker='o', s=40,
|
||||
fignum=None, plot_inducing=False, legend=True,
|
||||
aspect='auto', updates=False):
|
||||
plot_limits=None,
|
||||
aspect='auto', updates=False, **kwargs):
|
||||
"""
|
||||
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
|
||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||
|
|
@ -38,6 +39,8 @@ def plot_latent(model, labels=None, which_indices=None,
|
|||
if ax is None:
|
||||
fig = pb.figure(num=fignum)
|
||||
ax = fig.add_subplot(111)
|
||||
else:
|
||||
fig = ax.figure
|
||||
Tango.reset()
|
||||
|
||||
if labels is None:
|
||||
|
|
@ -57,15 +60,28 @@ def plot_latent(model, labels=None, which_indices=None,
|
|||
def plot_function(x):
|
||||
Xtest_full = np.zeros((x.shape[0], model.X.shape[1]))
|
||||
Xtest_full[:, [input_1, input_2]] = x
|
||||
mu, var, low, up = model.predict(Xtest_full)
|
||||
_, var = model.predict(Xtest_full)
|
||||
var = var[:, :1]
|
||||
return np.log(var)
|
||||
|
||||
#Create an IMshow controller that can re-plot the latent space shading at a good resolution
|
||||
if plot_limits is None:
|
||||
xmin, ymin = X[:, [input_1, input_2]].min(0)
|
||||
xmax, ymax = X[:, [input_1, input_2]].max(0)
|
||||
x_r, y_r = xmax-xmin, ymax-ymin
|
||||
xmin -= .1*x_r
|
||||
xmax += .1*x_r
|
||||
ymin -= .1*y_r
|
||||
ymax += .1*y_r
|
||||
else:
|
||||
try:
|
||||
xmin, xmax, ymin, ymax = plot_limits
|
||||
except (TypeError, ValueError) as e:
|
||||
raise e.__class__, "Wrong plot limits: {} given -> need (xmin, xmax, ymin, ymax)".format(plot_limits)
|
||||
view = ImshowController(ax, plot_function,
|
||||
tuple(X[:, [input_1, input_2]].min(0)) + tuple(X[:, [input_1, input_2]].max(0)),
|
||||
(xmin, ymin, xmax, ymax),
|
||||
resolution, aspect=aspect, interpolation='bilinear',
|
||||
cmap=pb.cm.binary)
|
||||
cmap=pb.cm.binary, **kwargs)
|
||||
|
||||
# make sure labels are in order of input:
|
||||
ulabels = []
|
||||
|
|
@ -99,18 +115,31 @@ def plot_latent(model, labels=None, which_indices=None,
|
|||
if not np.all(labels == 1.) and legend:
|
||||
ax.legend(loc=0, numpoints=1)
|
||||
|
||||
#ax.set_xlim(xmin[0], xmax[0])
|
||||
#ax.set_ylim(xmin[1], xmax[1])
|
||||
ax.grid(b=False) # remove the grid if present, it doesn't look good
|
||||
ax.set_aspect('auto') # set a nice aspect ratio
|
||||
|
||||
if plot_inducing:
|
||||
Z = param_to_array(model.Z)
|
||||
ax.plot(Z[:, input_1], Z[:, input_2], '^w')
|
||||
|
||||
ax.set_xlim((xmin, xmax))
|
||||
ax.set_ylim((ymin, ymax))
|
||||
|
||||
try:
|
||||
fig.canvas.draw()
|
||||
fig.tight_layout()
|
||||
fig.canvas.draw()
|
||||
except Exception as e:
|
||||
print "Could not invoke tight layout: {}".format(e)
|
||||
pass
|
||||
|
||||
if updates:
|
||||
ax.figure.canvas.show()
|
||||
try:
|
||||
ax.figure.canvas.show()
|
||||
except Exception as e:
|
||||
print "Could not invoke show: {}".format(e)
|
||||
raw_input('Enter to continue')
|
||||
view.deactivate()
|
||||
return ax
|
||||
|
||||
def plot_magnification(model, labels=None, which_indices=None,
|
||||
|
|
@ -186,7 +215,7 @@ def plot_magnification(model, labels=None, which_indices=None,
|
|||
ax.plot(model.Z[:, input_1], model.Z[:, input_2], '^w')
|
||||
|
||||
if updates:
|
||||
ax.figure.canvas.show()
|
||||
fig.canvas.show()
|
||||
raw_input('Enter to continue')
|
||||
|
||||
pb.title('Magnification Factor')
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class AxisChangedController(AxisEventController):
|
|||
Constructor
|
||||
'''
|
||||
super(AxisChangedController, self).__init__(ax)
|
||||
self._lim_ratio_threshold = update_lim or .8
|
||||
self._lim_ratio_threshold = update_lim or .95
|
||||
self._x_lim = self.ax.get_xlim()
|
||||
self._y_lim = self.ax.get_ylim()
|
||||
|
||||
|
|
@ -80,6 +80,10 @@ class AxisChangedController(AxisEventController):
|
|||
class BufferedAxisChangedController(AxisChangedController):
|
||||
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=None, **kwargs):
|
||||
"""
|
||||
Buffered axis changed controller. Controls the buffer and handles update events for when the axes changed.
|
||||
|
||||
Updated plotting will be after first reload (first time will be within plot limits, after that the limits will be buffered)
|
||||
|
||||
:param plot_function:
|
||||
function to use for creating image for plotting (return ndarray-like)
|
||||
plot_function gets called with (2D!) Xtest grid if replotting required
|
||||
|
|
@ -91,11 +95,13 @@ class BufferedAxisChangedController(AxisChangedController):
|
|||
"""
|
||||
super(BufferedAxisChangedController, self).__init__(ax, update_lim=update_lim)
|
||||
self.plot_function = plot_function
|
||||
xmin, xmax = self._x_lim # self._compute_buffered(*self._x_lim)
|
||||
ymin, ymax = self._y_lim # self._compute_buffered(*self._y_lim)
|
||||
xmin, ymin, xmax, ymax = plot_limits#self._x_lim # self._compute_buffered(*self._x_lim)
|
||||
# imshow acts on the limits of the plot, this is why we need to override the limits here, to make sure the right plot limits are used:
|
||||
self._x_lim = xmin, xmax
|
||||
self._y_lim = ymin, ymax
|
||||
self.resolution = resolution
|
||||
self._not_init = False
|
||||
self.view = self._init_view(self.ax, self.recompute_X(), xmin, xmax, ymin, ymax, **kwargs)
|
||||
self.view = self._init_view(self.ax, self.recompute_X(buffered=False), xmin, xmax, ymin, ymax, **kwargs)
|
||||
self._not_init = True
|
||||
|
||||
def update(self, ax):
|
||||
|
|
@ -111,14 +117,16 @@ class BufferedAxisChangedController(AxisChangedController):
|
|||
def update_view(self, view, X, xmin, xmax, ymin, ymax):
|
||||
raise NotImplementedError('update view given in here')
|
||||
|
||||
def get_grid(self):
|
||||
xmin, xmax = self._compute_buffered(*self._x_lim)
|
||||
ymin, ymax = self._compute_buffered(*self._y_lim)
|
||||
def get_grid(self, buffered=True):
|
||||
if buffered: comp = self._compute_buffered
|
||||
else: comp = lambda a,b: (a,b)
|
||||
xmin, xmax = comp(*self._x_lim)
|
||||
ymin, ymax = comp(*self._y_lim)
|
||||
x, y = numpy.mgrid[xmin:xmax:1j * self.resolution, ymin:ymax:1j * self.resolution]
|
||||
return numpy.hstack((x.flatten()[:, None], y.flatten()[:, None]))
|
||||
|
||||
def recompute_X(self):
|
||||
X = self.plot_function(self.get_grid())
|
||||
def recompute_X(self, buffered=True):
|
||||
X = self.plot_function(self.get_grid(buffered))
|
||||
if isinstance(X, (tuple, list)):
|
||||
for x in X:
|
||||
x.shape = [self.resolution, self.resolution]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import numpy
|
|||
|
||||
|
||||
class ImshowController(BufferedAxisChangedController):
|
||||
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=.5, **kwargs):
|
||||
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=.8, **kwargs):
|
||||
"""
|
||||
:param plot_function:
|
||||
function to use for creating image for plotting (return ndarray-like)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import GPy
|
|||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
import time
|
||||
from ...util.misc import param_to_array
|
||||
from GPy.core.parameterization.variational import VariationalPosterior
|
||||
try:
|
||||
import visual
|
||||
visual_available = True
|
||||
|
|
@ -72,12 +74,13 @@ class vector_show(matplotlib_show):
|
|||
"""
|
||||
def __init__(self, vals, axes=None):
|
||||
matplotlib_show.__init__(self, vals, axes)
|
||||
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals.T)[0]
|
||||
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals)
|
||||
|
||||
def modify(self, vals):
|
||||
self.vals = vals.copy()
|
||||
xdata, ydata = self.handle.get_data()
|
||||
self.handle.set_data(xdata, self.vals.T)
|
||||
for handle, vals in zip(self.handle, self.vals.T):
|
||||
xdata, ydata = handle.get_data()
|
||||
handle.set_data(xdata, vals)
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
|
||||
|
|
@ -91,8 +94,12 @@ class lvm(matplotlib_show):
|
|||
:param latent_axes: the axes where the latent visualization should be plotted.
|
||||
"""
|
||||
if vals == None:
|
||||
vals = model.X[0]
|
||||
|
||||
if isinstance(model.X, VariationalPosterior):
|
||||
vals = param_to_array(model.X.mean)
|
||||
else:
|
||||
vals = param_to_array(model.X)
|
||||
|
||||
vals = param_to_array(vals)
|
||||
matplotlib_show.__init__(self, vals, axes=latent_axes)
|
||||
|
||||
if isinstance(latent_axes,mpl.axes.Axes):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue