merge devel branch in

This commit is contained in:
Zhenwen Dai 2014-05-21 10:38:34 +01:00
commit 52c0be1848
21 changed files with 595 additions and 134 deletions

View file

@ -31,7 +31,7 @@ def plot_latent(model, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40,
fignum=None, plot_inducing=False, legend=True,
plot_limits=None,
aspect='auto', updates=False, **kwargs):
aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}):
"""
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance
@ -60,7 +60,7 @@ def plot_latent(model, labels=None, which_indices=None,
def plot_function(x):
Xtest_full = np.zeros((x.shape[0], model.X.shape[1]))
Xtest_full[:, [input_1, input_2]] = x
_, var = model.predict(Xtest_full)
_, var = model.predict(Xtest_full, **predict_kwargs)
var = var[:, :1]
return np.log(var)
@ -81,7 +81,7 @@ def plot_latent(model, labels=None, which_indices=None,
view = ImshowController(ax, plot_function,
(xmin, ymin, xmax, ymax),
resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.binary, **kwargs)
cmap=pb.cm.binary, **imshow_kwargs)
# make sure labels are in order of input:
ulabels = []

View file

@ -97,7 +97,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
for d in which_data_ycols:
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
#optionally plot some samples
if samples: #NOTE not tested with fixed_inputs
@ -151,7 +151,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
for d in which_data_ycols:
m_d = m[:,d].reshape(resolution, resolution).T
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)
plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
#set the limits of the plot to some sensible values
ax.set_xlim(xmin[0], xmax[0])

View file

@ -88,7 +88,6 @@ class vector_show(matplotlib_show):
class lvm(matplotlib_show):
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1], disable_drag=False):
"""Visualize a latent variable model
@ -150,7 +149,6 @@ class lvm(matplotlib_show):
pass
def on_click(self, event):
print 'click!'
if event.inaxes!=self.latent_axes: return
if self.disable_drag:
self.move_on = True
@ -228,11 +226,10 @@ class lvm_dimselect(lvm):
self.labels = labels
lvm.__init__(self,vals,model,data_visualize,latent_axes,sense_axes,latent_index)
self.show_sensitivities()
print "use left and right mouse butons to select dimensions"
print "use left and right mouse buttons to select dimensions"
def on_click(self, event):
if event.inaxes==self.sense_axes:
new_index = max(0,min(int(np.round(event.xdata-0.5)),self.model.input_dim-1))
if event.button == 1: