diff --git a/GPy/models/gplvm.py b/GPy/models/gplvm.py index b0034663..6416847c 100644 --- a/GPy/models/gplvm.py +++ b/GPy/models/gplvm.py @@ -6,7 +6,6 @@ import numpy as np from .. import kern from ..core import GP, Param from ..likelihoods import Gaussian -from .. import util class GPLVM(GP): @@ -42,18 +41,4 @@ class GPLVM(GP): def parameters_changed(self): super(GPLVM, self).parameters_changed() - self.X.gradient = self.kern.gradients_X(self.grad_dict['dL_dK'], self.X, None) - - def plot_latent(self, labels=None, which_indices=None, - resolution=50, ax=None, marker='o', s=40, - fignum=None, legend=True, - plot_limits=None, - aspect='auto', updates=False, **kwargs): - import sys - assert "matplotlib" in sys.modules, "matplotlib package has not been imported." - from ..plotting.matplot_dep import dim_reduction_plots - - return dim_reduction_plots.plot_latent(self, labels, which_indices, - resolution, ax, marker, s, - fignum, False, legend, - plot_limits, aspect, updates, **kwargs) + self.X.gradient = self.kern.gradients_X(self.grad_dict['dL_dK'], self.X, None) \ No newline at end of file diff --git a/GPy/plotting/__init__.py b/GPy/plotting/__init__.py index 4971620a..0b6c9c89 100644 --- a/GPy/plotting/__init__.py +++ b/GPy/plotting/__init__.py @@ -36,7 +36,7 @@ if config.get('plotting', 'library') is not 'none': GP.plot_samples = gpy_plot.gp_plots.plot_samples GP.plot = gpy_plot.gp_plots.plot GP.plot_f = gpy_plot.gp_plots.plot_f - #GP.plot_magnificaion = gpy_plot.latent_plots.plot_magnification + GP.plot_magnificaion = gpy_plot.latent_plots.plot_magnification from ..core import SparseGP SparseGP.plot_inducing = gpy_plot.data_plots.plot_inducing @@ -49,8 +49,9 @@ if config.get('plotting', 'library') is not 'none': #Kern.plot_covariance = gpy_plot.kern_plots.plot_kern # Variational plot! + - from . import matplot_dep + #from . import matplot_dep # Still to convert to new style: #GP.plot = matplot_dep.models_plots.plot_fit #GP.plot_f = matplot_dep.models_plots.plot_fit_f diff --git a/GPy/plotting/abstract_plotting_library.py b/GPy/plotting/abstract_plotting_library.py index f09f8d3c..a03d8591 100644 --- a/GPy/plotting/abstract_plotting_library.py +++ b/GPy/plotting/abstract_plotting_library.py @@ -57,7 +57,7 @@ class AbstractPlottingLibrary(object): return self.__defaults #=============================================================================== - def get_new_canvas(self, plot_3d=False, xlabel=None, ylabel=None, zlabel=None, title=None, legend=True, **kwargs): + def get_new_canvas(self, xlabel=None, ylabel=None, zlabel=None, title=None, projection='2d', legend=True, **kwargs): """ Return a canvas, kwargupdate for your plotting library. @@ -174,10 +174,17 @@ class AbstractPlottingLibrary(object): """ raise NotImplementedError("Implement all plot functions in AbstractPlottingLibrary in order to use your own plotting library") - def imshow(self, canvas, X, label=None, color=None, **kwargs): + def imshow(self, canvas, X, extent=None, label=None, plot_function=None, vmin=None, vmax=None, **kwargs): """ - Show the image stored in X on the canvas/ - + Show the image stored in X on the canvas. + + if X is a function, create an imshow controller to stream + the image. There is an imshow controller written for + mmatplotlib, which updates the imshow on changes in axis. + + Just ignore the plot_function, if you do not have the option + to have interactive changes. + the kwargs are plotting library specific kwargs! """ raise NotImplementedError("Implement all plot functions in AbstractPlottingLibrary in order to use your own plotting library") diff --git a/GPy/plotting/gpy_plot/latent_plots.py b/GPy/plotting/gpy_plot/latent_plots.py index 292cbd0e..23235c44 100644 --- a/GPy/plotting/gpy_plot/latent_plots.py +++ b/GPy/plotting/gpy_plot/latent_plots.py @@ -32,6 +32,7 @@ from . import pl from .plot_util import get_x_y_var, get_free_dims, get_which_data_ycols,\ get_which_data_rows, update_not_existing_kwargs, helper_predict_with_model,\ helper_for_plot_data +import itertools def plot_prediction_fit(self, plot_limits=None, which_data_rows='all', which_data_ycols='all', @@ -102,6 +103,164 @@ def _plot_prediction_fit(self, canvas, plot_limits=None, raise NotImplementedError("Cannot plot in more then one dimension.") return plots +def plot_magnification(self, labels=None, which_indices=None, + resolution=60, legend=True, + plot_limits=None, + updates=False, + mean=True, covariance=True, + kern=None, marker='<>^vsd', imshow_kwargs=None, **kwargs): + """ + Plot the magnification factor of the GP on the inputs. This is the + density of the GP as a gray scale. + + :param array-like labels: a label for each data point (row) of the inputs + :param (int, int) which_indices: which input dimensions to plot against each other + :param int resolution: the resolution at which we predict the magnification factor + :param bool legend: whether to plot the legend on the figure + :param plot_limits: the plot limits for the plot + :type plot_limits: (xmin, xmax, ymin, ymax) or ((xmin, xmax), (ymin, ymax)) + :param bool updates: if possible, make interactive updates using the specific library you are using + :param bool mean: use the mean of the Wishart embedding for the magnification factor + :param bool covariance: use the covariance of the Wishart embedding for the magnification factor + :param :py:class:`~GPy.kern.Kern` kern: the kernel to use for prediction + :param str marker: markers to use - cycle if more labels then markers are given + :param imshow_kwargs: the kwargs for the imshow (magnification factor) + :param kwargs: the kwargs for the scatter plots + """ + input_1, input_2 = self.get_most_significant_input_dimensions(which_indices) + + #fethch the data points X that we'd like to plot + X, _, _ = get_x_y_var(self) + + if plot_limits is None: + xmin, ymin = X[:, [input_1, input_2]].min(0) + xmax, ymax = X[:, [input_1, input_2]].max(0) + x_r, y_r = xmax-xmin, ymax-ymin + xmin -= .1*x_r + xmax += .1*x_r + ymin -= .1*y_r + ymax += .1*y_r + else: + try: + xmin, xmax, ymin, ymax = plot_limits + except (TypeError, ValueError) as e: + try: + xmin, xmax = plot_limits + ymin, ymax = xmin, xmax + except (TypeError, ValueError) as e: + raise e.__class__("Wrong plot limits: {} given -> need (xmin, xmax, ymin, ymax)".format(plot_limits)) + xlim = (xmin, xmax) + ylim = (ymin, ymax) + + from .. import Tango + Tango.reset() + + if labels is None: + labels = np.ones(self.num_data) + + if X.shape[0] > 1000: + print("Warning: subsampling X, as it has more samples then 1000. X.shape={!s}".format(X.shape)) + subsample = np.random.choice(X.shape[0], size=1000, replace=False) + X = X[subsample] + labels = labels[subsample] + #======================================================================= + # <<>> + # <<>> + # plt.close('all') + # fig, ax = plt.subplots(1,1) + # from GPy.plotting.matplot_dep.dim_reduction_plots import most_significant_input_dimensions + # import matplotlib.patches as mpatches + # i1, i2 = most_significant_input_dimensions(m, None) + # xmin, xmax = 100, -100 + # ymin, ymax = 100, -100 + # legend_handles = [] + # + # X = m.X.mean[:, [i1, i2]] + # X = m.X.variance[:, [i1, i2]] + # + # xmin = X[:,0].min(); xmax = X[:,0].max() + # ymin = X[:,1].min(); ymax = X[:,1].max() + # range_ = [[xmin, xmax], [ymin, ymax]] + # ul = np.unique(labels) + # + # for i, l in enumerate(ul): + # #cdict = dict(red =[(0., colors[i][0], colors[i][0]), (1., colors[i][0], colors[i][0])], + # # green=[(0., colors[i][0], colors[i][1]), (1., colors[i][1], colors[i][1])], + # # blue =[(0., colors[i][0], colors[i][2]), (1., colors[i][2], colors[i][2])], + # # alpha=[(0., 0., .0), (.5, .5, .5), (1., .5, .5)]) + # #cmap = LinearSegmentedColormap('{}'.format(l), cdict) + # cmap = LinearSegmentedColormap.from_list('cmap_{}'.format(str(l)), [colors[i], colors[i]], 255) + # cmap._init() + # #alphas = .5*(1+scipy.special.erf(np.linspace(-2,2, cmap.N+3)))#np.log(np.linspace(np.exp(0), np.exp(1.), cmap.N+3)) + # alphas = (scipy.special.erf(np.linspace(0,2.4, cmap.N+3)))#np.log(np.linspace(np.exp(0), np.exp(1.), cmap.N+3)) + # cmap._lut[:, -1] = alphas + # print l + # x, y = X[labels==l].T + # + # heatmap, xedges, yedges = np.histogram2d(x, y, bins=300, range=range_) + # #heatmap, xedges, yedges = np.histogram2d(x, y, bins=100) + # + # im = ax.imshow(heatmap, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], cmap=cmap, aspect='auto', interpolation='nearest', label=str(l)) + # legend_handles.append(mpatches.Patch(color=colors[i], label=l)) + # ax.set_xlim(xmin, xmax) + # ax.set_ylim(ymin, ymax) + # plt.legend(legend_handles, [l.get_label() for l in legend_handles]) + # plt.draw() + # plt.show() + #======================================================================= - \ No newline at end of file + canvas, kwargs = pl.get_new_canvas(xlabel='latent dimension %i' % input_1, ylabel='latent dimension %i' % input_2, **kwargs) + + _, _, _, _, _, Xgrid, _, _, _, _, resolution = helper_for_plot_data(self, ((xmin, ymin), (xmax, ymax)), (input_1, input_2), None, resolution) + + def plot_function(x): + Xtest_full = np.zeros((x.shape[0], X.shape[1])) + Xtest_full[:, [input_1, input_2]] = x + mf = self.predict_magnification(Xtest_full, kern=kern, mean=mean, covariance=covariance) + return mf + + imshow_kwargs = update_not_existing_kwargs(imshow_kwargs, pl.defaults.magnification) + Y = plot_function(Xgrid[:, [input_1, input_2]]).reshape(resolution, resolution).T[::-1, :] + view = pl.imshow(canvas, Y, + (xmin, ymin, xmax, ymax), + None, plot_function, resolution, + vmin=Y.min(), vmax=Y.max(), + **imshow_kwargs) + + # make sure labels are in order of input: + ulabels = [] + for lab in labels: + if not lab in ulabels: + ulabels.append(lab) + + marker = itertools.cycle(list(marker)) + scatters = [] + + for ul in ulabels: + if type(ul) is np.string_: + this_label = ul + elif type(ul) is np.int64: + this_label = 'class %i' % ul + else: + this_label = unicode(ul) + m = marker.next() + + index = np.nonzero(labels == ul)[0] + if self.input_dim == 1: + x = X[index, input_1] + y = np.zeros(index.size) + else: + x = X[index, input_1] + y = X[index, input_2] + update_not_existing_kwargs(kwargs, pl.defaults.latent_scatter) + scatters.append(pl.scatter(canvas, x, y, marker=m, color=Tango.nextMedium(), label=this_label, **kwargs)) + + plots = pl.show_canvas(canvas, dict(scatter=scatters, imshow=view), legend=legend, xlim=xlim, ylim=ylim) + if updates: + clear = raw_input('yes or enter to deactivate updates - otherwise still do updates - use plots[imshow].deactivate() to clear') + if clear.lower() in 'yes' or clear == '': + view.deactivate() + else: + view.deactivate() + return plots diff --git a/GPy/plotting/gpy_plot/plot_util.py b/GPy/plotting/gpy_plot/plot_util.py index 1a9b2a92..0defc9df 100644 --- a/GPy/plotting/gpy_plot/plot_util.py +++ b/GPy/plotting/gpy_plot/plot_util.py @@ -125,7 +125,10 @@ def update_not_existing_kwargs(to_update, update_from): This is used for updated kwargs from the default dicts. """ - return to_update.update({k:v for k,v in update_from.items() if k not in to_update}) + if to_update is None: + to_update = {} + to_update.update({k:v for k,v in update_from.items() if k not in to_update}) + return to_update def get_x_y_var(model): """ @@ -208,12 +211,14 @@ def x_frame2D(X,plot_limits=None,resolution=None): """ Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits """ - assert X.shape[1] ==2, "x_frame2D is defined for two-dimensional inputs" + assert X.shape[1]==2, "x_frame2D is defined for two-dimensional inputs" if plot_limits is None: - xmin,xmax = X.min(0),X.max(0) + xmin, xmax = X.min(0),X.max(0) xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin) elif len(plot_limits)==2: xmin, xmax = plot_limits + elif len(plot_limits)==4: + xmin, xmax = (plot_limits[0], plot_limits[2]), (plot_limits[1], plot_limits[3]) else: raise ValueError("Bad limits for plotting") diff --git a/GPy/plotting/matplot_dep/latent_space_visualizations/controllers/axis_event_controller.py b/GPy/plotting/matplot_dep/controllers/axis_event_controller.py similarity index 100% rename from GPy/plotting/matplot_dep/latent_space_visualizations/controllers/axis_event_controller.py rename to GPy/plotting/matplot_dep/controllers/axis_event_controller.py diff --git a/GPy/plotting/matplot_dep/latent_space_visualizations/controllers/imshow_controller.py b/GPy/plotting/matplot_dep/controllers/imshow_controller.py similarity index 93% rename from GPy/plotting/matplot_dep/latent_space_visualizations/controllers/imshow_controller.py rename to GPy/plotting/matplot_dep/controllers/imshow_controller.py index 87f7df7b..9d941073 100644 --- a/GPy/plotting/matplot_dep/latent_space_visualizations/controllers/imshow_controller.py +++ b/GPy/plotting/matplot_dep/controllers/imshow_controller.py @@ -22,11 +22,10 @@ class ImshowController(BufferedAxisChangedController): """ super(ImshowController, self).__init__(ax, plot_function, plot_limits, resolution, update_lim, **kwargs) - def _init_view(self, canvas, X, xmin, xmax, ymin, ymax, **kwargs): - return pl.imshow(canvas, X, extent=(xmin, xmax, + def _init_view(self, canvas, X, xmin, xmax, ymin, ymax, vmin=None, vmax=None, **kwargs): + return canvas.imshow(X, extent=(xmin, xmax, ymin, ymax), - vmin=X.min(), - vmax=X.max(), + vmin=vmin, vmax=vmax, **kwargs) def update_view(self, view, X, xmin, xmax, ymin, ymax): diff --git a/GPy/plotting/matplot_dep/defaults.py b/GPy/plotting/matplot_dep/defaults.py index f074fc55..5e8d84e9 100644 --- a/GPy/plotting/matplot_dep/defaults.py +++ b/GPy/plotting/matplot_dep/defaults.py @@ -65,4 +65,9 @@ data_y_1d = dict(linewidth=0, cmap='RdBu', s=40) data_y_1d_plot = dict(color='k', linewidth=1.5) # Kernel plots: -ard = dict(edgecolor='k', linewidth=1.2) \ No newline at end of file +ard = dict(edgecolor='k', linewidth=1.2) + +# Input plots: +latent = dict(aspect='auto', cmap='Greys', interpolation='bilinear') +magnification = dict(aspect='auto', cmap='Greys', interpolation='bilinear') +latent_scatter = dict(s=40, linewidth=.2, edgecolor='k', alpha=.9) \ No newline at end of file diff --git a/GPy/plotting/matplot_dep/dim_reduction_plots.py b/GPy/plotting/matplot_dep/dim_reduction_plots.py index be7d032f..25b5d45b 100644 --- a/GPy/plotting/matplot_dep/dim_reduction_plots.py +++ b/GPy/plotting/matplot_dep/dim_reduction_plots.py @@ -7,7 +7,7 @@ from ...core.parameterization.variational import VariationalPosterior from .base_plots import x_frame2D import itertools try: -from GPy.plotting import Tango + from GPy.plotting import Tango from matplotlib.cm import get_cmap from matplotlib import pyplot as pb from matplotlib import cm diff --git a/GPy/plotting/matplot_dep/latent_space_visualizations/__init__.py b/GPy/plotting/matplot_dep/latent_space_visualizations/__init__.py deleted file mode 100644 index 4644261c..00000000 --- a/GPy/plotting/matplot_dep/latent_space_visualizations/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .import controllers diff --git a/GPy/plotting/matplot_dep/latent_space_visualizations/controllers/__init__.py b/GPy/plotting/matplot_dep/latent_space_visualizations/controllers/__init__.py deleted file mode 100644 index f59b71ba..00000000 --- a/GPy/plotting/matplot_dep/latent_space_visualizations/controllers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import axis_event_controller, imshow_controller diff --git a/GPy/plotting/matplot_dep/plot_definitions.py b/GPy/plotting/matplot_dep/plot_definitions.py index 174d0cd3..f1fca71a 100644 --- a/GPy/plotting/matplot_dep/plot_definitions.py +++ b/GPy/plotting/matplot_dep/plot_definitions.py @@ -33,13 +33,14 @@ from ..abstract_plotting_library import AbstractPlottingLibrary from .. import Tango from . import defaults from matplotlib.colors import LinearSegmentedColormap +from .controllers.imshow_controller import ImshowController class MatplotlibPlots(AbstractPlottingLibrary): def __init__(self): super(MatplotlibPlots, self).__init__() self._defaults = defaults.__dict__ - def get_new_canvas(self, xlabel=None, ylabel=None, zlabel=None, title=None, legend=True, projection='2d', **kwargs): + def get_new_canvas(self, xlabel=None, ylabel=None, zlabel=None, title=None, projection='2d', **kwargs): if projection == '3d': from mpl_toolkits.mplot3d import Axes3D elif projection == '2d': @@ -61,22 +62,23 @@ class MatplotlibPlots(AbstractPlottingLibrary): if title is not None: ax.set_title(title) return ax, kwargs - def show_canvas(self, ax, plots, xlim=None, ylim=None, zlim=None, **kwargs): + def show_canvas(self, ax, plots, xlim=None, ylim=None, zlim=None, legend=True, **kwargs): try: ax.autoscale_view() ax.set_xlim(xlim) ax.set_ylim(ylim) + if legend: + ax.legend() if zlim is not None: ax.set_zlim(zlim) ax.figure.canvas.draw() - #ax.figure.tight_layout() except: pass return plots - def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs): + def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, marker='o', **kwargs): if Z is not None: - return ax.scatter(X, Y, c=color, zs=Z, label=label, **kwargs) + return ax.scatter(X, Y, c=color, zs=Z, label=label, marker=marker, **kwargs) return ax.scatter(X, Y, c=color, label=label, **kwargs) def plot(self, ax, X, Y, color=None, label=None, **kwargs): @@ -116,8 +118,11 @@ class MatplotlibPlots(AbstractPlottingLibrary): return ax.errorbar(X, Y, Z, yerr=error, ecolor=color, label=label, **kwargs) return ax.errorbar(X, Y, yerr=error, ecolor=color, label=label, **kwargs) - def imshow(self, ax, X, label=None, **kwargs): - return ax.imshow(X, label=label, **kwargs) + def imshow(self, ax, X, extent=None, label=None, plot_function=None, resolution=None, vmin=None, vmax=None, **kwargs): + if plot_function is not None: + self.controller = ImshowController(ax, plot_function, extent, resolution=resolution, vmin=vmin, vmax=vmax, **kwargs) + return self.controller + return ax.imshow(X, label=label, extent=extent, vmin=vmin, vmax=vmax, **kwargs) def contour(self, ax, X, Y, C, levels=20, label=None, **kwargs): return ax.contour(X, Y, C, levels=np.linspace(C.min(), C.max(), levels), label=label, **kwargs) diff --git a/GPy/testing/plotting_tests/baseline/gp_class_-failed-diff.png b/GPy/testing/plotting_tests/baseline/gp_class_-failed-diff.png index 6ec4e856..e6f3f308 100644 Binary files a/GPy/testing/plotting_tests/baseline/gp_class_-failed-diff.png and b/GPy/testing/plotting_tests/baseline/gp_class_-failed-diff.png differ diff --git a/GPy/testing/plotting_tests/baseline/gp_class_link-failed-diff.png b/GPy/testing/plotting_tests/baseline/gp_class_link-failed-diff.png index 937b5dae..26a94894 100644 Binary files a/GPy/testing/plotting_tests/baseline/gp_class_link-failed-diff.png and b/GPy/testing/plotting_tests/baseline/gp_class_link-failed-diff.png differ diff --git a/GPy/testing/plotting_tests/baseline/gp_class_raw_link-failed-diff.png b/GPy/testing/plotting_tests/baseline/gp_class_raw_link-failed-diff.png index b8c957bd..6222d49f 100644 Binary files a/GPy/testing/plotting_tests/baseline/gp_class_raw_link-failed-diff.png and b/GPy/testing/plotting_tests/baseline/gp_class_raw_link-failed-diff.png differ diff --git a/GPy/testing/plotting_tests/baseline/sparse_gp_class_-failed-diff.png b/GPy/testing/plotting_tests/baseline/sparse_gp_class_-failed-diff.png index 49d4224e..b43d0363 100644 Binary files a/GPy/testing/plotting_tests/baseline/sparse_gp_class_-failed-diff.png and b/GPy/testing/plotting_tests/baseline/sparse_gp_class_-failed-diff.png differ diff --git a/GPy/testing/plotting_tests/baseline/sparse_gp_class_link-failed-diff.png b/GPy/testing/plotting_tests/baseline/sparse_gp_class_link-failed-diff.png index 35abbd30..0e6e5adc 100644 Binary files a/GPy/testing/plotting_tests/baseline/sparse_gp_class_link-failed-diff.png and b/GPy/testing/plotting_tests/baseline/sparse_gp_class_link-failed-diff.png differ diff --git a/GPy/testing/plotting_tests/baseline/sparse_gp_class_raw_link-failed-diff.png b/GPy/testing/plotting_tests/baseline/sparse_gp_class_raw_link-failed-diff.png index 11988e85..05f6e64f 100644 Binary files a/GPy/testing/plotting_tests/baseline/sparse_gp_class_raw_link-failed-diff.png and b/GPy/testing/plotting_tests/baseline/sparse_gp_class_raw_link-failed-diff.png differ