mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 11:32:39 +02:00
[ploting] dim reduction
This commit is contained in:
parent
382645ff37
commit
efc1f4413c
2 changed files with 5 additions and 5 deletions
|
|
@ -31,7 +31,7 @@ def plot_latent(model, labels=None, which_indices=None,
|
||||||
resolution=50, ax=None, marker='o', s=40,
|
resolution=50, ax=None, marker='o', s=40,
|
||||||
fignum=None, plot_inducing=False, legend=True,
|
fignum=None, plot_inducing=False, legend=True,
|
||||||
plot_limits=None,
|
plot_limits=None,
|
||||||
aspect='auto', updates=False, **kwargs):
|
aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}):
|
||||||
"""
|
"""
|
||||||
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
|
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
|
||||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||||
|
|
@ -60,7 +60,7 @@ def plot_latent(model, labels=None, which_indices=None,
|
||||||
def plot_function(x):
|
def plot_function(x):
|
||||||
Xtest_full = np.zeros((x.shape[0], model.X.shape[1]))
|
Xtest_full = np.zeros((x.shape[0], model.X.shape[1]))
|
||||||
Xtest_full[:, [input_1, input_2]] = x
|
Xtest_full[:, [input_1, input_2]] = x
|
||||||
_, var = model.predict(Xtest_full)
|
_, var = model.predict(Xtest_full, **predict_kwargs)
|
||||||
var = var[:, :1]
|
var = var[:, :1]
|
||||||
return np.log(var)
|
return np.log(var)
|
||||||
|
|
||||||
|
|
@ -81,7 +81,7 @@ def plot_latent(model, labels=None, which_indices=None,
|
||||||
view = ImshowController(ax, plot_function,
|
view = ImshowController(ax, plot_function,
|
||||||
(xmin, ymin, xmax, ymax),
|
(xmin, ymin, xmax, ymax),
|
||||||
resolution, aspect=aspect, interpolation='bilinear',
|
resolution, aspect=aspect, interpolation='bilinear',
|
||||||
cmap=pb.cm.binary, **kwargs)
|
cmap=pb.cm.binary, **imshow_kwargs)
|
||||||
|
|
||||||
# make sure labels are in order of input:
|
# make sure labels are in order of input:
|
||||||
ulabels = []
|
ulabels = []
|
||||||
|
|
|
||||||
|
|
@ -97,7 +97,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
|
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
|
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
|
||||||
plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
|
if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
|
||||||
|
|
||||||
#optionally plot some samples
|
#optionally plot some samples
|
||||||
if samples: #NOTE not tested with fixed_inputs
|
if samples: #NOTE not tested with fixed_inputs
|
||||||
|
|
@ -151,7 +151,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
m_d = m[:,d].reshape(resolution, resolution).T
|
m_d = m[:,d].reshape(resolution, resolution).T
|
||||||
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)
|
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet)
|
||||||
plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
|
if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.)
|
||||||
|
|
||||||
#set the limits of the plot to some sensible values
|
#set the limits of the plot to some sensible values
|
||||||
ax.set_xlim(xmin[0], xmax[0])
|
ax.set_xlim(xmin[0], xmax[0])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue