mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 01:02:39 +02:00
313 lines
15 KiB
Python
313 lines
15 KiB
Python
#===============================================================================
|
|
# Copyright (c) 2015, Max Zwiessele
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# * Neither the name of GPy.plotting.matplot_dep.plot_definitions nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#===============================================================================
|
|
import numpy as np
|
|
from ..abstract_plotting_library import AbstractPlottingLibrary
|
|
from .. import Tango
|
|
from . import defaults
|
|
from plotly import tools
|
|
from plotly import plotly as py
|
|
from plotly.graph_objs import Scatter, Scatter3d, Line,\
|
|
Marker, ErrorX, ErrorY, Bar, Heatmap, Trace,\
|
|
Annotations, Annotation, Contour, Font, Surface
|
|
from plotly.exceptions import PlotlyDictKeyError
|
|
|
|
# Translation of matplotlib marker codes to plotly marker symbol names.
# Codes not listed here are passed through unchanged by the plotting
# functions (assumed to already be plotly symbol names).
SYMBOL_MAP = {
    'o': 'circle',  # plotly has no bare 'dot' symbol ('-dot' is only a suffix)
    'v': 'triangle-down',
    '^': 'triangle-up',
    '<': 'triangle-left',
    '>': 'triangle-right',
    's': 'square',
    '+': 'cross',
    'x': 'x',
    '*': 'x',  # no star symbol used (original note: no star yet in plotly)
    'D': 'diamond',
    'd': 'diamond',
}
|
|
|
|
class PlotlyPlots(AbstractPlottingLibrary):
    """Plotly backend of GPy's abstract plotting library.

    The plotting primitives build and return plotly graph objects (traces or
    annotations); ``add_to_canvas`` attaches them to a figure and
    ``show_canvas`` renders it. A *canvas* is the tuple
    ``(figure, row, col)`` returned by ``new_canvas``.
    """

    def __init__(self):
        super(PlotlyPlots, self).__init__()
        self._defaults = defaults.__dict__
        self.current_states = dict()

    def figure(self, rows=1, cols=1, specs=None, is_3d=False, **kwargs):
        """Create a figure with a ``rows`` x ``cols`` subplot grid.

        When ``specs`` is None every cell gets its own ``{'is_3d': is_3d}``
        spec. Remaining kwargs are forwarded to ``tools.make_subplots``.
        """
        if specs is None:
            # Build independent rows/dicts rather than aliasing one list.
            specs = [[{'is_3d': is_3d} for _ in range(cols)] for _ in range(rows)]
        return tools.make_subplots(rows, cols, specs=specs, **kwargs)

    def new_canvas(self, figure=None, row=1, col=1, projection='2d',
                   xlabel=None, ylabel=None, zlabel=None,
                   title=None, xlim=None,
                   ylim=None, zlim=None, **kwargs):
        """Return ``((figure, row, col), kwargs)`` for subsequent plotting.

        A new single-cell figure is created when ``figure`` is None (3d if
        ``projection == '3d'``). The figure font is set and, for 3d, the
        legend is moved and given a grey background.

        NOTE(review): the label, title and limit arguments are accepted for
        interface compatibility but are not applied by this backend.
        """
        if figure is None:
            figure = self.figure(is_3d=projection == '3d')
        figure.layout.font = Font(family="Raleway, sans-serif")
        if projection == '3d':
            figure.layout.legend.x = .5
            figure.layout.legend.bgcolor = '#DCDCDC'
        return (figure, row, col), kwargs

    def add_to_canvas(self, canvas, traces, legend=False, **kwargs):
        """Attach ``traces`` to the subplot addressed by ``canvas``.

        ``traces`` may be an ``Annotations`` collection, a single ``Trace``,
        or nested dicts/lists/tuples of these. Annotations receive the
        subplot's axis references unless they already define ``xref`` or
        ``yref``. ``legend`` sets the layout's ``showlegend`` flag.
        Returns ``canvas`` unchanged.
        """
        figure, row, col = canvas

        def append_annotation(annotation, xref, yref):
            # Respect explicitly given references (e.g. yref='paper').
            if 'xref' not in annotation:
                annotation['xref'] = xref
            if 'yref' not in annotation:
                annotation['yref'] = yref
            figure.layout.annotations.append(annotation)

        def recursive_append(item):
            # Order matters: Annotations must be tested before list/tuple,
            # and Trace before plain dict.
            if isinstance(item, Annotations):
                xref, yref = figure._grid_ref[row - 1][col - 1]
                for annotation in item:
                    append_annotation(annotation, xref, yref)
            elif isinstance(item, Trace):
                try:
                    figure.append_trace(item, row, col)
                except PlotlyDictKeyError:
                    # It is a dictionary of plots rather than a real trace.
                    for key in item:
                        recursive_append(item[key])
            elif isinstance(item, dict):
                for key in item:
                    recursive_append(item[key])
            elif isinstance(item, (tuple, list)):
                for sub in item:
                    recursive_append(sub)

        recursive_append(traces)
        figure.layout['showlegend'] = legend
        return canvas

    def show_canvas(self, canvas, filename=None, **kwargs):
        """Render the canvas figure with plotly.

        Uses ``py.iplot`` inside a notebook and ``py.plot`` otherwise, and
        returns that call's result.
        """
        figure, _, _ = canvas
        if len(figure.data) == 0:
            # Add an empty placeholder trace so the figure has some data.
            figure.append_trace(Scatter(x=[], y=[], name='', showlegend=False), 1, 1)
        from ..gpy_plot.plot_util import in_ipynb
        if in_ipynb():
            return py.iplot(figure, filename=filename)
        return py.plot(figure, filename=filename)

    def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], cmap=None, label=None, marker='o', marker_kwargs=None, **kwargs):
        """Return a marker-only Scatter (or Scatter3d when ``Z`` is given).

        ``marker`` may be a matplotlib marker code (translated through
        ``SYMBOL_MAP``) or a plotly symbol name (passed through). An explicit
        ``symbol`` in ``marker_kwargs`` takes precedence over ``marker``.
        """
        try:
            marker = SYMBOL_MAP[marker]
        except (KeyError, TypeError):
            # Not a matplotlib marker code (or unhashable): use as given.
            pass
        # Copy so that setdefault never mutates the caller's dict.
        marker_kwargs = dict(marker_kwargs or {})
        marker_kwargs.setdefault('symbol', marker)
        marker_obj = Marker(color=color, colorscale=cmap, **marker_kwargs)
        if Z is not None:
            return Scatter3d(x=X, y=Y, z=Z, mode='markers',
                             showlegend=label is not None,
                             marker=marker_obj,
                             name=label, **kwargs)
        return Scatter(x=X, y=Y, mode='markers', showlegend=label is not None,
                       marker=marker_obj,
                       name=label, **kwargs)

    def plot(self, ax, X, Y, Z=None, color=None, label=None, line_kwargs=None, **kwargs):
        """Return a line Scatter (or Scatter3d when ``Z`` is given).

        ``mode`` defaults to ``'lines'`` unless supplied in kwargs.
        """
        kwargs.setdefault('mode', 'lines')
        line = Line(color=color, **(line_kwargs or {}))
        if Z is not None:
            return Scatter3d(x=X, y=Y, z=Z, showlegend=label is not None,
                             line=line, name=label, **kwargs)
        return Scatter(x=X, y=Y, showlegend=label is not None,
                       line=line, name=label, **kwargs)

    def plot_axis_lines(self, ax, X, color=Tango.colorsHex['mediumBlue'], label=None, marker_kwargs=None, **kwargs):
        """Mark input locations ``X`` along the axes.

        * ``X`` with one column: returns arrow ``Annotations`` pointing at
          each x position at the bottom of the plot (``yref='paper'``).
        * ``X`` with two columns: returns a Scatter3d of markers on the
          z=0 plane, with a z-projection using ``opacity`` (default .8).
        * Any other column count: returns None.
        """
        if X.shape[1] == 1:
            annotations = Annotations()
            for row in X:
                annotations.append(
                    Annotation(
                        text='',
                        x=row[0], y=0,
                        yref='paper',
                        ax=0, ay=20,
                        arrowhead=2,
                        arrowsize=1,
                        arrowwidth=2,
                        arrowcolor=color,
                        showarrow=True,
                    ))
            return annotations
        elif X.shape[1] == 2:
            # Copy: marker_kwargs defaults to None and must not be mutated.
            marker_kwargs = dict(marker_kwargs or {})
            marker_kwargs.setdefault('symbol', 'diamond')
            opacity = kwargs.pop('opacity', .8)
            return Scatter3d(x=X[:, 0], y=X[:, 1], z=np.zeros(X.shape[0]),
                             mode='markers',
                             projection=dict(z=dict(show=True, opacity=opacity)),
                             marker=Marker(color=color, **marker_kwargs),
                             opacity=0,
                             name=label,
                             showlegend=label is not None, **kwargs)
        return None

    def barplot(self, canvas, x, height, width=0.8, bottom=0, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
        """Return a Bar trace; a ``barmode`` kwarg is applied to the layout.

        NOTE(review): ``width``, ``bottom`` and any other kwargs are
        currently ignored by this backend.
        """
        figure, _, _ = canvas
        if 'barmode' in kwargs:
            figure.layout['barmode'] = kwargs.pop('barmode')
        return Bar(x=x, y=height, marker=Marker(color=color), name=label)

    @staticmethod
    def _error_kwargs(error, error_kwargs):
        """Return a new dict of ErrorX/ErrorY settings for ``error``.

        A 2-D ``error`` with two rows is read as ``(minus, plus)`` and gives
        asymmetric error bars; anything else is a symmetric error array.
        The caller's ``error_kwargs`` is copied, never mutated.
        """
        settings = dict(error_kwargs or {})
        if error.ndim == 2 and error.shape[0] == 2:
            settings.update(array=error[1], arrayminus=error[0], symmetric=False)
        else:
            settings.update(array=error, symmetric=True)
        return settings

    def xerrorbar(self, ax, X, Y, error, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, error_kwargs=None, **kwargs):
        """Return invisible markers carrying horizontal error bars."""
        error_x = ErrorX(color=color, **self._error_kwargs(error, error_kwargs))
        if Z is not None:
            return Scatter3d(x=X, y=Y, z=Z, mode='markers',
                             error_x=error_x,
                             marker=Marker(size='0'), name=label,
                             showlegend=label is not None, **kwargs)
        return Scatter(x=X, y=Y, mode='markers',
                       error_x=error_x,
                       marker=Marker(size='0'), name=label,
                       showlegend=label is not None,
                       **kwargs)

    def yerrorbar(self, ax, X, Y, error, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, error_kwargs=None, **kwargs):
        """Return invisible markers carrying vertical error bars."""
        error_y = ErrorY(color=color, **self._error_kwargs(error, error_kwargs))
        if Z is not None:
            return Scatter3d(x=X, y=Y, z=Z, mode='markers',
                             error_y=error_y,
                             marker=Marker(size='0'), name=label,
                             showlegend=label is not None, **kwargs)
        return Scatter(x=X, y=Y, mode='markers',
                       error_y=error_y,
                       marker=Marker(size='0'), name=label,
                       showlegend=label is not None,
                       **kwargs)

    def imshow(self, ax, X, extent=None, label=None, vmin=None, vmax=None, **imshow_kwargs):
        """Return a Heatmap of ``X`` placed over ``extent``.

        ``extent`` is ``(xmin, xmax, ymin, ymax)``; ``dx`` is derived from
        ``X.shape[0]`` and ``dy`` from ``X.shape[1]``. When ``extent`` is
        None, unit spacing starting at 0 is used. The colour scale is hidden
        unless ``showscale`` is passed explicitly.
        """
        if extent is None:
            extent = (0, X.shape[0] - 1, 0, X.shape[1] - 1)
        imshow_kwargs.setdefault('showscale', False)
        # max(..., 1) avoids a division by zero for a single row/column.
        dx = float(extent[1] - extent[0]) / max(X.shape[0] - 1, 1)
        dy = float(extent[3] - extent[2]) / max(X.shape[1] - 1, 1)
        return Heatmap(z=X, name=label,
                       x0=extent[0], dx=dx,
                       y0=extent[2], dy=dy,
                       zmin=vmin, zmax=vmax,
                       showlegend=label is not None,
                       hoverinfo='z',
                       **imshow_kwargs)

    def imshow_interact(self, ax, plot_function, extent=None, label=None, resolution=None, vmin=None, vmax=None, **imshow_kwargs):
        """Delegate to the base-class implementation (no plotly-specific version)."""
        # TODO stream interaction?
        return super(PlotlyPlots, self).imshow_interact(ax, plot_function)

    def annotation_heatmap(self, ax, X, annotation, extent=None, label='Gradient', imshow_kwargs=None, **annotation_kwargs):
        """Return ``(heatmap, annotations)`` writing ``annotation`` over ``X``.

        Each cell's text colour is white where ``X`` is near its extremes
        (|scaled value| > 0.8 after rescaling to [-1, 1]) and black otherwise.
        NOTE(review): ``annotation_kwargs`` is currently ignored.
        """
        # Copy: imshow_kwargs defaults to None and must not be mutated.
        imshow_kwargs = dict(imshow_kwargs or {})
        imshow_kwargs.setdefault('label', label)
        imshow_kwargs.setdefault('showscale', True)
        if extent is None:
            extent = (0, X.shape[0] - 1, 0, X.shape[1] - 1)
        imshow = self.imshow(ax, X, extent, **imshow_kwargs)

        # Rescale to [-1, 1] on a float copy: in-place division would fail
        # for integer input, and a constant X would give 0/0.
        shifted = np.asarray(X, dtype=float) - np.min(X)
        half_range = shifted.max() / 2.
        if half_range > 0:
            scaled = shifted / half_range - 1
        else:
            scaled = np.zeros_like(shifted)

        x = np.linspace(extent[0], extent[1], X.shape[0])
        y = np.linspace(extent[2], extent[3], X.shape[1])
        annotations = Annotations()
        for n, row in enumerate(annotation):
            for m, val in enumerate(row):
                var = scaled[n][m]
                annotations.append(
                    Annotation(
                        text=str(val),
                        x=x[m], y=y[n],
                        xref='x1', yref='y1',
                        font=dict(color='white' if np.abs(var) > 0.8 else 'black', size=10),
                        opacity=.5,
                        showarrow=False,
                        hoverinfo='x'))
        return imshow, annotations

    def annotation_heatmap_interact(self, ax, plot_function, extent, label=None, resolution=15, imshow_kwargs=None, **annotation_kwargs):
        """Delegate to the base-class implementation (no plotly-specific version)."""
        return super(PlotlyPlots, self).annotation_heatmap_interact(ax, plot_function, extent)

    def contour(self, ax, X, Y, C, levels=20, label=None, **kwargs):
        """Return a Contour trace of ``C`` on ``(X, Y)``.

        NOTE(review): ``levels`` is currently not applied.
        """
        return Contour(x=X, y=Y, z=C, name=label, **kwargs)

    def surface(self, ax, X, Y, Z, color=None, label=None, **kwargs):
        """Return a Surface trace; ``color`` is currently not applied."""
        return Surface(x=X, y=Y, z=Z, name=label, showlegend=label is not None, **kwargs)

    def fill_between(self, ax, X, lower, upper, color=Tango.colorsHex['mediumBlue'], label=None, line_kwargs=None, **kwargs):
        """Return ``(lower_trace, upper_trace)`` shading between two curves.

        The upper trace uses ``fill='tonextx'``, so the returned order
        matters when both are appended to a figure. Hex colours are turned
        into ``rgba`` using ``kwargs['opacity']`` (default 1.0); other colour
        strings are used as given. ``ax`` is a canvas ``(figure, row, col)``
        and the row/col pair makes the legend group unique per subplot.
        """
        if 'line' not in kwargs:
            kwargs['line'] = Line(**(line_kwargs or {}))
        else:
            kwargs['line'].update(line_kwargs or {})
        if color.startswith('#'):
            fcolor = 'rgba({c[0]}, {c[1]}, {c[2]}, {alpha})'.format(
                c=Tango.hex2rgb(color), alpha=kwargs.get('opacity', 1.0))
        else:
            fcolor = color
        group = '{}_fill_({},{})'.format(label, ax[1], ax[2])
        upper_trace = Scatter(x=X, y=upper, fillcolor=fcolor, showlegend=label is not None,
                              name=label, fill='tonextx', legendgroup=group, **kwargs)
        lower_trace = Scatter(x=X, y=lower, fillcolor=fcolor, showlegend=False,
                              name=label, legendgroup=group, **kwargs)
        return lower_trace, upper_trace

    def fill_gradient(self, canvas, X, percentiles, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
        """Return a list of filled Scatter traces forming a density gradient.

        ``color`` is either a hex string or an ``rgb(r, g, b)`` /
        ``rgba(r, g, b, a)`` string; the alpha channel (default .9) is the
        peak opacity, tapering linearly to 0 towards the outer percentiles.
        Only the middle band carries a legend entry.
        """
        opacity = .9
        if color.startswith('#'):
            colarray = Tango.hex2rgb(color)
        else:
            # Parse the numbers between the parentheses of 'rgb(...)'/'rgba(...)'.
            colarray = [float(c) for c in color.strip(')').split('(')[1].split(',')]
            if len(colarray) == 4:
                colarray, opacity = colarray[:3], colarray[3]

        alpha = opacity * (1. - np.abs(np.linspace(-1, 1, len(percentiles) - 1)))
        # Integer division: the band index that gets the legend entry.
        middle = len(percentiles) // 2

        polycol = []
        for i, (y1, a) in enumerate(zip(percentiles, alpha)):
            fcolor = 'rgba({}, {}, {}, {alpha})'.format(*colarray, alpha=a)
            if i == middle:
                polycol.append(Scatter(x=X, y=y1, fillcolor=fcolor, showlegend=True,
                                       name=label, line=Line(width=0, smoothing=0), mode='none', fill='tonextx',
                                       legendgroup='density', hoverinfo='none', **kwargs))
            else:
                polycol.append(Scatter(x=X, y=y1, fillcolor=fcolor, showlegend=False,
                                       name=None, line=Line(width=1, smoothing=0, color=fcolor), mode='none', fill='tonextx',
                                       legendgroup='density', hoverinfo='none', **kwargs))
        return polycol
|