diff --git a/GPy/plotting/__init__.py b/GPy/plotting/__init__.py index 00c4fa82..ad62a00f 100644 --- a/GPy/plotting/__init__.py +++ b/GPy/plotting/__init__.py @@ -2,7 +2,7 @@ # Licensed under the BSD 3-clause license (see LICENSE.txt) current_lib = [None] -supported_libraries = ['matplotlib', 'plotly', 'none'] +supported_libraries = ['matplotlib', 'plotly', 'plotly_online', 'plotly_offline', 'none'] error_suggestion = "Please make sure you specify your plotting library in your configuration file (/.config/GPy/user.cfg).\n\n[plotting]\nlibrary = \n\nCurrently supported libraries: {}".format(", ".join(supported_libraries)) def change_plotting_library(lib, **kwargs): @@ -19,10 +19,14 @@ def change_plotting_library(lib, **kwargs): from .matplot_dep.plot_definitions import MatplotlibPlots from .matplot_dep import visualize, mapping_plots, priors_plots, ssgplvm, svig_plots, variational_plots, img_plots current_lib[0] = MatplotlibPlots() - if lib == 'plotly': + if lib in ['plotly', 'plotly_online']: import plotly - from .plotly_dep.plot_definitions import PlotlyPlots - current_lib[0] = PlotlyPlots(**kwargs) + from .plotly_dep.plot_definitions import PlotlyPlotsOnline + current_lib[0] = PlotlyPlotsOnline(**kwargs) + if lib == 'plotly_offline': + import plotly + from .plotly_dep.plot_definitions import PlotlyPlotsOffline + current_lib[0] = PlotlyPlotsOffline(**kwargs) if lib == 'none': current_lib[0] = None inject_plotting() diff --git a/GPy/plotting/plotly_dep/plot_definitions.py b/GPy/plotting/plotly_dep/plot_definitions.py index f7fa2054..b85f540f 100644 --- a/GPy/plotting/plotly_dep/plot_definitions.py +++ b/GPy/plotting/plotly_dep/plot_definitions.py @@ -31,7 +31,6 @@ import numpy as np from ..abstract_plotting_library import AbstractPlottingLibrary from .. import Tango from . import defaults -OFFLINE=False import plotly from plotly import tools from plotly.graph_objs import Scatter, Scatter3d, Line,\ @@ -53,13 +52,11 @@ SYMBOL_MAP = { 'd': 'diamond', } -class PlotlyPlots(AbstractPlottingLibrary): - def __init__(self, offline=False): - super(PlotlyPlots, self).__init__() +class PlotlyPlotsBase(AbstractPlottingLibrary): + def __init__(self): + super(PlotlyPlotsBase, self).__init__() self._defaults = defaults.__dict__ self.current_states = dict() - global OFFLINE - OFFLINE=offline def figure(self, rows=1, cols=1, specs=None, is_3d=False, **kwargs): if specs is None: @@ -101,8 +98,8 @@ class PlotlyPlots(AbstractPlottingLibrary): append_annotation(a, xref, yref) # elif isinstance(traces, (Trace)): # doesn't work # elif type(traces) in [v for k,v in go.__dict__.iteritems()]: - elif isinstance(traces, (Scatter, Scatter3d, Line, Marker, ErrorX, - ErrorY, Bar, Heatmap, Trace, Contour, Font, Surface)): + elif isinstance(traces, (Scatter, Scatter3d, ErrorX, + ErrorY, Bar, Heatmap, Trace, Contour, Surface)): try: append_trace(traces, row, col) except PlotlyDictKeyError: @@ -120,22 +117,7 @@ class PlotlyPlots(AbstractPlottingLibrary): return canvas def show_canvas(self, canvas, filename=None, **kwargs): - figure, _, _ = canvas - if len(figure.data) == 0: - # add mock data - figure.append_trace(Scatter(x=[], y=[], name='', showlegend=False), 1, 1) - from ..gpy_plot.plot_util import in_ipynb - if in_ipynb(): - if OFFLINE: - plotly.offline.init_notebook_mode(connected=True) - return plotly.offline.iplot(figure, filename=filename, **kwargs)#self.current_states[hex(id(figure))]['filename']) - else: - return plotly.plotly.iplot(figure, filename=filename, **kwargs) - else: - if OFFLINE: - return plotly.offline.plot(figure, filename=filename, **kwargs) - else: - return plotly.plotly.plot(figure, filename=filename, **kwargs)#self.current_states[hex(id(figure))]['filename']) + return NotImplementedError def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], cmap=None, label=None, marker='o', marker_kwargs=None, **kwargs): try: @@ -245,7 +227,7 @@ class PlotlyPlots(AbstractPlottingLibrary): def imshow_interact(self, ax, plot_function, extent=None, label=None, resolution=None, vmin=None, vmax=None, **imshow_kwargs): # TODO stream interaction? - super(PlotlyPlots, self).imshow_interact(ax, plot_function) + super(PlotlyPlotsBase, self).imshow_interact(ax, plot_function) def annotation_heatmap(self, ax, X, annotation, extent=None, label='Gradient', imshow_kwargs=None, **annotation_kwargs): imshow_kwargs.setdefault('label', label) @@ -272,7 +254,7 @@ class PlotlyPlots(AbstractPlottingLibrary): return imshow, annotations def annotation_heatmap_interact(self, ax, plot_function, extent, label=None, resolution=15, imshow_kwargs=None, **annotation_kwargs): - super(PlotlyPlots, self).annotation_heatmap_interact(ax, plot_function, extent) + super(PlotlyPlotsBase, self).annotation_heatmap_interact(ax, plot_function, extent) def contour(self, ax, X, Y, C, levels=20, label=None, **kwargs): return Contour(x=X, y=Y, z=C, @@ -325,3 +307,35 @@ class PlotlyPlots(AbstractPlottingLibrary): name=None, line=Line(width=1, smoothing=0, color=fcolor), mode='none', fill='tonextx', legendgroup='density', hoverinfo='none', **kwargs)) return polycol + + +class PlotlyPlotsOnline(PlotlyPlotsBase): + def __init__(self): + super(PlotlyPlotsOnline, self).__init__() + + def show_canvas(self, canvas, filename=None, **kwargs): + figure, _, _ = canvas + if len(figure.data) == 0: + # add mock data + figure.append_trace(Scatter(x=[], y=[], name='', showlegend=False), 1, 1) + from ..gpy_plot.plot_util import in_ipynb + if in_ipynb(): + return plotly.plotly.iplot(figure, filename=filename, **kwargs) + else: + return plotly.plotly.plot(figure, filename=filename, **kwargs)#self.current_states[hex(id(figure))]['filename']) + +class PlotlyPlotsOffline(PlotlyPlotsBase): + def __init__(self): + super(PlotlyPlotsOffline, self).__init__() + + def show_canvas(self, canvas, filename=None, **kwargs): + figure, _, _ = canvas + if len(figure.data) == 0: + # add mock data + figure.append_trace(Scatter(x=[], y=[], name='', showlegend=False), 1, 1) + from ..gpy_plot.plot_util import in_ipynb + if in_ipynb(): + plotly.offline.init_notebook_mode(connected=True) + return plotly.offline.iplot(figure, filename=filename, **kwargs)#self.current_states[hex(id(figure))]['filename']) + else: + return plotly.offline.plot(figure, filename=filename, **kwargs)