From 7d96266714a9d356f1966f24d590c6e9923ce447 Mon Sep 17 00:00:00 2001 From: Max Zwiessele Date: Mon, 29 Jul 2013 15:33:04 +0100 Subject: [PATCH] added imshow controller and label controller --- .../controllers/__init__.py | 0 .../controllers/axis_event_controller.py | 142 ++++++++++++++++++ .../controllers/imshow_controller.py | 71 +++++++++ 3 files changed, 213 insertions(+) create mode 100644 GPy/util/latent_space_visualizations/controllers/__init__.py create mode 100644 GPy/util/latent_space_visualizations/controllers/axis_event_controller.py create mode 100644 GPy/util/latent_space_visualizations/controllers/imshow_controller.py diff --git a/GPy/util/latent_space_visualizations/controllers/__init__.py b/GPy/util/latent_space_visualizations/controllers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/GPy/util/latent_space_visualizations/controllers/axis_event_controller.py b/GPy/util/latent_space_visualizations/controllers/axis_event_controller.py new file mode 100644 index 00000000..acb1ac8d --- /dev/null +++ b/GPy/util/latent_space_visualizations/controllers/axis_event_controller.py @@ -0,0 +1,142 @@ +''' +Created on 24 Jul 2013 + +@author: maxz +''' +import numpy + +class AxisEventController(object): + def __init__(self, ax): + self.ax = ax + self.activate() + def deactivate(self): + for cb_class in self.ax.callbacks.callbacks.values(): + for cb_num in cb_class.keys(): + self.ax.callbacks.disconnect(cb_num) + def activate(self): + self.ax.callbacks.connect('xlim_changed', self.xlim_changed) + self.ax.callbacks.connect('ylim_changed', self.ylim_changed) + def xlim_changed(self, ax): + pass + def ylim_changed(self, ax): + pass + + +class AxisChangedController(AxisEventController): + ''' + Buffered control of axis limit changes + ''' + _changing = False + + def __init__(self, ax, update_lim=None): + ''' + Constructor + ''' + super(AxisChangedController, self).__init__(ax) + self._lim_ratio_threshold = update_lim or .8 + self._x_lim = self.ax.get_xlim() + self._y_lim = self.ax.get_ylim() + + def update(self, ax): + pass + + def xlim_changed(self, ax): + super(AxisChangedController, self).xlim_changed(ax) + if not self._changing and self.lim_changed(ax.get_xlim(), self._x_lim): + self._changing = True + self._x_lim = ax.get_xlim() + self.update(ax) + self._changing = False + + def ylim_changed(self, ax): + super(AxisChangedController, self).ylim_changed(ax) + if not self._changing and self.lim_changed(ax.get_ylim(), self._y_lim): + self._changing = True + self._y_lim = ax.get_ylim() + self.update(ax) + self._changing = False + + def extent(self, lim): + return numpy.subtract(*lim) + + def lim_changed(self, axlim, savedlim): + axextent = self.extent(axlim) + extent = self.extent(savedlim) + lim_changed = ((axextent / extent) < self._lim_ratio_threshold ** 2 + or (extent / axextent) < self._lim_ratio_threshold ** 2 + or ((1 - (self.extent((axlim[0], savedlim[0])) / self.extent((savedlim[0], axlim[1])))) + < self._lim_ratio_threshold) + or ((1 - (self.extent((savedlim[0], axlim[0])) / self.extent((axlim[0], savedlim[1])))) + < self._lim_ratio_threshold) + ) + return lim_changed + + def _buffer_lim(self, lim): + # buffer_size = 1 - self._lim_ratio_threshold + # extent = self.extent(lim) + return lim + + +class BufferedAxisChangedController(AxisChangedController): + def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=None, **kwargs): + """ + :param plot_function: + function to use for creating image for plotting (return ndarray-like) + plot_function gets called with (2D!) Xtest grid if replotting required + :type plot_function: function + :param plot_limits: + beginning plot limits [xmin, ymin, xmax, ymax] + + :param kwargs: additional kwargs are for pyplot.imshow(**kwargs) + """ + super(BufferedAxisChangedController, self).__init__(ax, update_lim=update_lim) + self.plot_function = plot_function + xmin, xmax = self._x_lim # self._compute_buffered(*self._x_lim) + ymin, ymax = self._y_lim # self._compute_buffered(*self._y_lim) + self.resolution = resolution + self._not_init = False + self.view = self._init_view(self.ax, self.recompute_X(), xmin, xmax, ymin, ymax, **kwargs) + self._not_init = True + + def update(self, ax): + super(BufferedAxisChangedController, self).update(ax) + if self._not_init: + xmin, xmax = self._compute_buffered(*self._x_lim) + ymin, ymax = self._compute_buffered(*self._y_lim) + self.update_view(self.view, self.recompute_X(), xmin, xmax, ymin, ymax) + + def _init_view(self, ax, X, xmin, xmax, ymin, ymax): + raise NotImplementedError('return view for this controller') + + def update_view(self, view, X, xmin, xmax, ymin, ymax): + raise NotImplementedError('update view given in here') + + def get_grid(self): + xmin, xmax = self._compute_buffered(*self._x_lim) + ymin, ymax = self._compute_buffered(*self._y_lim) + x, y = numpy.mgrid[xmin:xmax:1j * self.resolution, ymin:ymax:1j * self.resolution] + return numpy.hstack((x.flatten()[:, None], y.flatten()[:, None])) + + def recompute_X(self): + X = self.plot_function(self.get_grid()) + if isinstance(X, (tuple, list)): + for x in X: + x.shape = [self.resolution, self.resolution] + x[:, :] = x.T[::-1, :] + return X + return X.reshape(self.resolution, self.resolution).T[::-1, :] + + def _compute_buffered(self, mi, ma): + buffersize = self._buffersize() + size = ma - mi + return mi - (buffersize * size), ma + (buffersize * size) + + def _buffersize(self): + try: + buffersize = 1. - self._lim_ratio_threshold + except: + buffersize = .4 + return buffersize + + + diff --git a/GPy/util/latent_space_visualizations/controllers/imshow_controller.py b/GPy/util/latent_space_visualizations/controllers/imshow_controller.py new file mode 100644 index 00000000..fa6682e9 --- /dev/null +++ b/GPy/util/latent_space_visualizations/controllers/imshow_controller.py @@ -0,0 +1,71 @@ +''' +Created on 24 Jul 2013 + +@author: maxz +''' +from GPy.util.latent_space_visualizations.controllers.axis_event_controller import BufferedAxisChangedController +import itertools +import numpy + + +class ImshowController(BufferedAxisChangedController): + def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=.5, **kwargs): + """ + :param plot_function: + function to use for creating image for plotting (return ndarray-like) + plot_function gets called with (2D!) Xtest grid if replotting required + :type plot_function: function + :param plot_limits: + beginning plot limits [xmin, ymin, xmax, ymax] + + :param kwargs: additional kwargs are for pyplot.imshow(**kwargs) + """ + super(ImshowController, self).__init__(ax, plot_function, plot_limits, resolution, update_lim, **kwargs) + + def _init_view(self, ax, X, xmin, xmax, ymin, ymax, **kwargs): + return ax.imshow(X, extent=(xmin, xmax, + ymin, ymax), + vmin=X.min(), + vmax=X.max(), + **kwargs) + + def update_view(self, view, X, xmin, xmax, ymin, ymax): + view.set_data(X) + view.set_extent((xmin, xmax, ymin, ymax)) + +class ImAnnotateController(ImshowController): + def __init__(self, ax, plot_function, plot_limits, resolution=20, update_lim=.99, **kwargs): + """ + :param plot_function: + function to use for creating image for plotting (return ndarray-like) + plot_function gets called with (2D!) Xtest grid if replotting required + :type plot_function: function + :param plot_limits: + beginning plot limits [xmin, ymin, xmax, ymax] + :param text_props: kwargs for pyplot.text(**text_props) + :param kwargs: additional kwargs are for pyplot.imshow(**kwargs) + """ + super(ImAnnotateController, self).__init__(ax, plot_function, plot_limits, resolution, update_lim, **kwargs) + + def _init_view(self, ax, X, xmin, xmax, ymin, ymax, text_props={}, **kwargs): + view = [super(ImAnnotateController, self)._init_view(ax, X[0], xmin, xmax, ymin, ymax, **kwargs)] + xoffset, yoffset = self._offsets(xmin, xmax, ymin, ymax) + xlin = numpy.linspace(xmin, xmax, self.resolution, endpoint=False) + ylin = numpy.linspace(ymin, ymax, self.resolution, endpoint=False) + for [i, x], [j, y] in itertools.product(enumerate(xlin), enumerate(ylin[::-1])): + view.append(ax.text(x + xoffset, y + yoffset, "{}".format(X[1][j, i]), ha='center', va='center', **text_props)) + return view + + def update_view(self, view, X, xmin, xmax, ymin, ymax): + super(ImAnnotateController, self).update_view(view[0], X[0], xmin, xmax, ymin, ymax) + xoffset, yoffset = self._offsets(xmin, xmax, ymin, ymax) + xlin = numpy.linspace(xmin, xmax, self.resolution, endpoint=False) + ylin = numpy.linspace(ymin, ymax, self.resolution, endpoint=False) + for [[i, x], [j, y]], text in itertools.izip(itertools.product(enumerate(xlin), enumerate(ylin[::-1])), view[1:]): + text.set_x(x + xoffset) + text.set_y(y + yoffset) + text.set_text("{}".format(X[1][j, i])) + return view + + def _offsets(self, xmin, xmax, ymin, ymax): + return (xmax - xmin) / (2 * self.resolution), (ymax - ymin) / (2 * self.resolution)