diff --git a/GPy/plotting/matplot_dep/dim_reduction_plots.py b/GPy/plotting/matplot_dep/dim_reduction_plots.py index f8413671..0ba082df 100644 --- a/GPy/plotting/matplot_dep/dim_reduction_plots.py +++ b/GPy/plotting/matplot_dep/dim_reduction_plots.py @@ -31,7 +31,7 @@ def plot_latent(model, labels=None, which_indices=None, resolution=50, ax=None, marker='o', s=40, fignum=None, plot_inducing=False, legend=True, plot_limits=None, - aspect='auto', updates=False, **kwargs): + aspect='auto', updates=False, predict_kwargs={}, imshow_kwargs={}): """ :param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc) :param resolution: the resolution of the grid on which to evaluate the predictive variance @@ -60,7 +60,7 @@ def plot_latent(model, labels=None, which_indices=None, def plot_function(x): Xtest_full = np.zeros((x.shape[0], model.X.shape[1])) Xtest_full[:, [input_1, input_2]] = x - _, var = model.predict(Xtest_full) + _, var = model.predict(Xtest_full, **predict_kwargs) var = var[:, :1] return np.log(var) @@ -81,7 +81,7 @@ def plot_latent(model, labels=None, which_indices=None, view = ImshowController(ax, plot_function, (xmin, ymin, xmax, ymax), resolution, aspect=aspect, interpolation='bilinear', - cmap=pb.cm.binary, **kwargs) + cmap=pb.cm.binary, **imshow_kwargs) # make sure labels are in order of input: ulabels = [] diff --git a/GPy/plotting/matplot_dep/models_plots.py b/GPy/plotting/matplot_dep/models_plots.py index 84747d05..8f3e55b0 100644 --- a/GPy/plotting/matplot_dep/models_plots.py +++ b/GPy/plotting/matplot_dep/models_plots.py @@ -97,7 +97,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all', for d in which_data_ycols: plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol) - plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5) + if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5) #optionally plot some samples if samples: #NOTE not tested with fixed_inputs @@ -151,7 +151,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all', for d in which_data_ycols: m_d = m[:,d].reshape(resolution, resolution).T plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=pb.cm.jet) - plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.) + if not plot_raw: plots['dataplot'] = ax.scatter(X[which_data_rows, free_dims[0]], X[which_data_rows, free_dims[1]], 40, Y[which_data_rows, d], cmap=pb.cm.jet, vmin=m.min(), vmax=m.max(), linewidth=0.) #set the limits of the plot to some sensible values ax.set_xlim(xmin[0], xmax[0])