[plotly] todos: fill_gradient

This commit is contained in:
mzwiessele 2015-10-08 14:05:20 +01:00
parent 7f84bec6fb
commit b3154e43b4
10 changed files with 196 additions and 149 deletions

View file

@ -17,6 +17,8 @@ from numpy.testing import Tester
from . import kern from . import kern
from . import plotting from . import plotting
from .plotting import plotting_library
# Direct imports for convenience: # Direct imports for convenience:
from .core import Model from .core import Model
from .core.parameterization import Param, Parameterized, ObsAr from .core.parameterization import Param, Parameterized, ObsAr

View file

@ -1,5 +1,5 @@
# This is the local installation configuration file for GPy # This is the local installation configuration file for GPy
[plotting] [plotting]
#library = plotly library = plotly
library = matplotlib #library = matplotlib

View file

@ -1,3 +1,3 @@
from .. import plotting_library as pl from .. import plotting_library
pl = plotting_library
from . import data_plots, gp_plots, latent_plots, kernel_plots, plot_util, inference_plots from . import data_plots, gp_plots, latent_plots, kernel_plots, plot_util, inference_plots

View file

@ -123,7 +123,7 @@ def plot_data_error(self, which_data_rows='all',
def _plot_data_error(self, canvas, which_data_rows='all', def _plot_data_error(self, canvas, which_data_rows='all',
which_data_ycols='all', visible_dims=None, which_data_ycols='all', visible_dims=None,
projection='2d', **error_kwargs): projection='2d', label=None, **error_kwargs):
ycols = get_which_data_ycols(self, which_data_ycols) ycols = get_which_data_ycols(self, which_data_ycols)
rows = get_which_data_rows(self, which_data_rows) rows = get_which_data_rows(self, which_data_rows)
@ -139,17 +139,17 @@ def _plot_data_error(self, canvas, which_data_rows='all',
for d in ycols: for d in ycols:
update_not_existing_kwargs(error_kwargs, pl.defaults.xerrorbar) update_not_existing_kwargs(error_kwargs, pl.defaults.xerrorbar)
plots['xerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims].flatten(), Y[rows, d].flatten(), plots['xerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims].flatten(), Y[rows, d].flatten(),
2 * np.sqrt(X_variance[rows, free_dims].flatten()), 2 * np.sqrt(X_variance[rows, free_dims].flatten()), label=label,
**error_kwargs)) **error_kwargs))
#2D plotting #2D plotting
elif len(free_dims) == 2: elif len(free_dims) == 2:
update_not_existing_kwargs(error_kwargs, pl.defaults.xerrorbar) # @UndefinedVariable update_not_existing_kwargs(error_kwargs, pl.defaults.xerrorbar) # @UndefinedVariable
for d in ycols: for d in ycols:
plots['xerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims[0]].flatten(), Y[rows, d].flatten(), plots['xerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims[0]].flatten(), Y[rows, d].flatten(),
2 * np.sqrt(X_variance[rows, free_dims[0]].flatten()), 2 * np.sqrt(X_variance[rows, free_dims[0]].flatten()), label=label,
**error_kwargs)) **error_kwargs))
plots['yerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims[1]].flatten(), Y[rows, d].flatten(), plots['yerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims[1]].flatten(), Y[rows, d].flatten(),
2 * np.sqrt(X_variance[rows, free_dims[1]].flatten()), 2 * np.sqrt(X_variance[rows, free_dims[1]].flatten()), label=label,
**error_kwargs)) **error_kwargs))
elif len(free_dims) == 0: elif len(free_dims) == 0:
pass #Nothing to plot! pass #Nothing to plot!

View file

@ -40,7 +40,7 @@ def plot_mean(self, plot_limits=None, fixed_inputs=None,
apply_link=False, visible_dims=None, apply_link=False, visible_dims=None,
which_data_ycols='all', which_data_ycols='all',
levels=20, projection='2d', levels=20, projection='2d',
label=None, label='gp mean',
predict_kw=None, predict_kw=None,
**kwargs): **kwargs):
""" """
@ -70,8 +70,7 @@ def plot_mean(self, plot_limits=None, fixed_inputs=None,
predict_kw) predict_kw)
plots = _plot_mean(self, canvas, helper_data, helper_prediction, plots = _plot_mean(self, canvas, helper_data, helper_prediction,
levels, projection, label, **kwargs) levels, projection, label, **kwargs)
pl.add_to_canvas(canvas, plots) return pl.add_to_canvas(canvas, plots)
return pl.show_canvas(canvas)
def _plot_mean(self, canvas, helper_data, helper_prediction, def _plot_mean(self, canvas, helper_data, helper_prediction,
levels=20, projection='2d', label=None, levels=20, projection='2d', label=None,
@ -87,7 +86,7 @@ def _plot_mean(self, canvas, helper_data, helper_prediction,
else: else:
if projection == '2d': if projection == '2d':
update_not_existing_kwargs(kwargs, pl.defaults.meanplot_2d) # @UndefinedVariable update_not_existing_kwargs(kwargs, pl.defaults.meanplot_2d) # @UndefinedVariable
plots = dict(gpmean=[pl.contour(canvas, x, y, plots = dict(gpmean=[pl.contour(canvas, x[:,0], y[0,:],
mu.reshape(resolution, resolution), mu.reshape(resolution, resolution),
levels=levels, label=label, **kwargs)]) levels=levels, label=label, **kwargs)])
elif projection == '3d': elif projection == '3d':
@ -105,7 +104,7 @@ def _plot_mean(self, canvas, helper_data, helper_prediction,
def plot_confidence(self, lower=2.5, upper=97.5, plot_limits=None, fixed_inputs=None, def plot_confidence(self, lower=2.5, upper=97.5, plot_limits=None, fixed_inputs=None,
resolution=None, plot_raw=False, resolution=None, plot_raw=False,
apply_link=False, visible_dims=None, apply_link=False, visible_dims=None,
which_data_ycols='all', label=None, which_data_ycols='all', label='gp confidence',
predict_kw=None, predict_kw=None,
**kwargs): **kwargs):
""" """
@ -157,7 +156,7 @@ def plot_samples(self, plot_limits=None, fixed_inputs=None,
resolution=None, plot_raw=True, resolution=None, plot_raw=True,
apply_link=False, visible_dims=None, apply_link=False, visible_dims=None,
which_data_ycols='all', which_data_ycols='all',
samples=3, projection='2d', label=None, samples=3, projection='2d', label='gp_samples',
predict_kw=None, predict_kw=None,
**kwargs): **kwargs):
""" """
@ -214,7 +213,7 @@ def plot_density(self, plot_limits=None, fixed_inputs=None,
resolution=None, plot_raw=False, resolution=None, plot_raw=False,
apply_link=False, visible_dims=None, apply_link=False, visible_dims=None,
which_data_ycols='all', which_data_ycols='all',
levels=35, label=None, levels=35, label='gp density',
predict_kw=None, predict_kw=None,
**kwargs): **kwargs):
""" """
@ -270,7 +269,7 @@ def plot(self, plot_limits=None, fixed_inputs=None,
visible_dims=None, visible_dims=None,
levels=20, samples=0, samples_likelihood=0, lower=2.5, upper=97.5, levels=20, samples=0, samples_likelihood=0, lower=2.5, upper=97.5,
plot_data=True, plot_inducing=True, plot_density=False, plot_data=True, plot_inducing=True, plot_density=False,
predict_kw=None, projection='2d', **kwargs): predict_kw=None, projection='2d', legend=False, **kwargs):
""" """
Convinience function for plotting the fit of a GP. Convinience function for plotting the fit of a GP.
@ -311,18 +310,18 @@ def plot(self, plot_limits=None, fixed_inputs=None,
plot_data = False plot_data = False
plots = {} plots = {}
if plot_data: if plot_data:
plots.update(_plot_data(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection)) plots.update(_plot_data(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection, "Data"))
plots.update(_plot_data_error(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection)) plots.update(_plot_data_error(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection, "Data Error"))
plots.update(_plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_inducing, plot_density, projection)) plots.update(_plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_inducing, plot_density, projection))
if plot_raw and (samples_likelihood > 0): if plot_raw and (samples_likelihood > 0):
helper_prediction = helper_predict_with_model(self, helper_data[5], False, helper_prediction = helper_predict_with_model(self, helper_data[5], False,
apply_link, None, apply_link, None,
get_which_data_ycols(self, which_data_ycols), get_which_data_ycols(self, which_data_ycols),
predict_kw, samples_likelihood) predict_kw, samples_likelihood)
plots.update(_plot_samples(canvas, helper_data, helper_prediction, projection)) plots.update(_plot_samples(canvas, helper_data, helper_prediction, projection, "Lik Samples"))
if hasattr(self, 'Z') and plot_inducing: if hasattr(self, 'Z') and plot_inducing:
plots.update(_plot_inducing(self, canvas, visible_dims, projection, None)) plots.update(_plot_inducing(self, canvas, visible_dims, projection, 'Inducing'))
return pl.add_to_canvas(canvas, plots) return pl.add_to_canvas(canvas, plots, legend=legend)
def plot_f(self, plot_limits=None, fixed_inputs=None, def plot_f(self, plot_limits=None, fixed_inputs=None,
@ -333,7 +332,7 @@ def plot_f(self, plot_limits=None, fixed_inputs=None,
levels=20, samples=0, lower=2.5, upper=97.5, levels=20, samples=0, lower=2.5, upper=97.5,
plot_density=False, plot_density=False,
plot_data=True, plot_inducing=True, plot_data=True, plot_inducing=True,
projection='2d', projection='2d', legend=False,
predict_kw=None, predict_kw=None,
**kwargs): **kwargs):
""" """
@ -366,35 +365,27 @@ def plot_f(self, plot_limits=None, fixed_inputs=None,
:param dict error_kwargs: kwargs for the error plot for the plotting library you are using :param dict error_kwargs: kwargs for the error plot for the plotting library you are using
:param kwargs plot_kwargs: kwargs for the data plot for the plotting library you are using :param kwargs plot_kwargs: kwargs for the data plot for the plotting library you are using
""" """
canvas, _ = pl.new_canvas(projection=projection, **kwargs) plot(self, plot_limits, fixed_inputs, resolution, True,
helper_data = helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resolution) apply_link, which_data_ycols, which_data_rows,
helper_prediction = helper_predict_with_model(self, helper_data[5], True, visible_dims, levels, samples, 0,
apply_link, np.linspace(2.5, 97.5, levels*2) if plot_density else (lower,upper), lower, upper, plot_data, plot_inducing,
get_which_data_ycols(self, which_data_ycols), plot_density, predict_kw, projection, legend)
predict_kw, samples)
if not apply_link:
# It does not make sense to plot the data (which lives not in the latent function space) into latent function space.
plot_data = False
plots = {}
if plot_data:
plots.update(_plot_data(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection))
plots.update(_plot_data_error(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection))
plots.update(_plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_inducing, plot_density, projection))
if hasattr(self, 'Z') and plot_inducing:
plots.update(_plot_inducing(self, canvas, visible_dims, projection, None))
return pl.add_to_canvas(canvas, plots)
def _plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_inducing=True, plot_density=False, projection='2d'): def _plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_inducing=True, plot_density=False, projection='2d'):
plots.update(_plot_mean(self, canvas, helper_data, helper_prediction, levels, projection, None)) plots.update(_plot_mean(self, canvas, helper_data, helper_prediction, levels, projection, 'Mean'))
try:
if projection=='2d': if projection=='2d':
if not plot_density: if not plot_density:
plots.update(_plot_confidence(self, canvas, helper_data, helper_prediction, None)) plots.update(_plot_confidence(self, canvas, helper_data, helper_prediction, "Confidence"))
else: else:
plots.update(_plot_density(self, canvas, helper_data, helper_prediction, None)) plots.update(_plot_density(self, canvas, helper_data, helper_prediction, "Density"))
except RuntimeError:
#plotting in 2d
pass
if helper_prediction[2] is not None: if helper_prediction[2] is not None:
plots.update(_plot_samples(self, canvas, helper_data, helper_prediction, projection, None)) plots.update(_plot_samples(self, canvas, helper_data, helper_prediction, projection, "Samples"))
return plots return plots

View file

@ -35,6 +35,7 @@ from .plot_util import get_x_y_var,\
find_best_layout_for_subplots find_best_layout_for_subplots
def _wait_for_updates(view, updates): def _wait_for_updates(view, updates):
if view is not None:
try: try:
if updates: if updates:
clear = raw_input('yes or enter to deactivate updates - otherwise still do updates - use plots[imshow].deactivate() to clear') clear = raw_input('yes or enter to deactivate updates - otherwise still do updates - use plots[imshow].deactivate() to clear')
@ -45,6 +46,9 @@ def _wait_for_updates(view, updates):
except AttributeError: except AttributeError:
# No updateable view: # No updateable view:
pass pass
except TypeError:
# No updateable view:
pass
def _plot_latent_scatter(canvas, X, visible_dims, labels, marker, num_samples, projection='2d', **kwargs): def _plot_latent_scatter(canvas, X, visible_dims, labels, marker, num_samples, projection='2d', **kwargs):

View file

@ -282,12 +282,12 @@ def get_x_y_var(model):
:returns: (X, X_variance, Y) :returns: (X, X_variance, Y)
""" """
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs(): if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean X = model.X.mean.values
X_variance = model.X.variance X_variance = model.X.variance.values
else: else:
X = model.X X = model.X.values
X_variance = None X_variance = None
Y = model.Y Y = model.Y.values
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray) if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
return X, X_variance, Y return X, X_variance, Y

View file

@ -29,6 +29,7 @@
#=============================================================================== #===============================================================================
from .. import Tango from .. import Tango
from plotly.graph_objs import Line
''' '''
This file is for defaults for the gpy plot, specific to the plotting library. This file is for defaults for the gpy plot, specific to the plotting library.
@ -42,21 +43,22 @@ it gives back an empty default, when defaults are not defined.
''' '''
# Data plots: # Data plots:
data_1d = dict(marker_kwargs=dict(linewidth=.7, ), marker='x', color='black') data_1d = dict(marker_kwargs=dict(), marker='x', color='black')
data_2d = dict(marker='o', cmap='Hot', marker_kwargs=dict(opacity=.5)) data_2d = dict(marker='o', cmap='Hot', marker_kwargs=dict(opacity=1., size='10', line=Line(width=.5, color='black')))
# inducing_1d = dict(lw=0, s=500, facecolors=Tango.colorsHex['darkRed']) inducing_1d = dict(color=Tango.colorsHex['darkRed'])
# inducing_2d = dict(s=14, edgecolors='k', linewidth=.4, facecolors='white', alpha=.5) inducing_2d = dict(marker_kwargs=dict(size='8', opacity=.7, line=Line(width=.5, color='black')), opacity=.7, color='white', marker='star-triangle-up')
# inducing_3d = dict(lw=.3, s=500, facecolors='white', edgecolors='k') inducing_3d = dict(marker_kwargs=dict(size='8', opacity=.7, line=Line(width=.5, color='black')), opacity=.7, color='white', marker='star-triangle-up')
# xerrorbar = dict(color='k', fmt='none', elinewidth=.5, alpha=.5) # xerrorbar = dict(color='k', fmt='none', elinewidth=.5, alpha=.5)
yerrorbar = dict(color=Tango.colorsHex['darkRed'], error_kwargs=dict(thickness=.5), opacity=.5) yerrorbar = dict(color=Tango.colorsHex['darkRed'], error_kwargs=dict(thickness=.5), opacity=.5)
# #
# # GP plots: # # GP plots:
# meanplot_1d = dict(color=Tango.colorsHex['mediumBlue'], linewidth=2) meanplot_1d = dict(color=Tango.colorsHex['mediumBlue'], line_kwargs=dict(width=2))
# meanplot_2d = dict(cmap='hot', linewidth=.5) meanplot_2d = dict(colorscale='Hot')
# meanplot_3d = dict(linewidth=0, antialiased=True, cstride=1, rstride=1, cmap='hot', alpha=.3) meanplot_3d = dict(colorscale='Hot', opacity=.8)
# samples_1d = dict(color=Tango.colorsHex['mediumBlue'], linewidth=.3) samples_1d = dict(color=Tango.colorsHex['mediumBlue'], line_kwargs=dict(width=.3))
# samples_3d = dict(cmap='hot', alpha=.1, antialiased=True, cstride=1, rstride=1, linewidth=0) samples_3d = dict(cmap='Hot', opacity=.5)
# confidence_interval = dict(edgecolor=Tango.colorsHex['darkBlue'], linewidth=.5, color=Tango.colorsHex['lightBlue'],alpha=.2) confidence_interval = dict(mode='lines', line_kwargs=dict(color=Tango.colorsHex['darkBlue'], width=.4),
color=Tango.colorsHex['lightBlue'], opacity=.3)
# density = dict(alpha=.5, color=Tango.colorsHex['lightBlue']) # density = dict(alpha=.5, color=Tango.colorsHex['lightBlue'])
# #
# # GPLVM plots: # # GPLVM plots:
@ -67,8 +69,8 @@ yerrorbar = dict(color=Tango.colorsHex['darkRed'], error_kwargs=dict(thickness=.
# ard = dict(edgecolor='k', linewidth=1.2) # ard = dict(edgecolor='k', linewidth=1.2)
# #
# # Input plots: # # Input plots:
# latent = dict(aspect='auto', cmap='Greys', interpolation='bicubic') latent = dict(colorscale='Greys', reversescale=True)
# gradient = dict(aspect='auto', cmap='RdBu', interpolation='nearest', alpha=.7) gradient = dict(colorscale='RdBu', opacity=.7)
# magnification = dict(aspect='auto', cmap='Greys', interpolation='bicubic') magnification = dict(colorscale='Greys')
# latent_scatter = dict(s=40, linewidth=.2, edgecolor='k', alpha=.9) latent_scatter = dict(marker_kwargs=dict(size='15', opacity=.7))
# annotation = dict(fontdict=dict(family='sans-serif', weight='light', fontsize=9), zorder=.3, alpha=.7) # annotation = dict(fontdict=dict(family='sans-serif', weight='light', fontsize=9), zorder=.3, alpha=.7)

View file

@ -31,11 +31,12 @@ import numpy as np
from ..abstract_plotting_library import AbstractPlottingLibrary from ..abstract_plotting_library import AbstractPlottingLibrary
from .. import Tango from .. import Tango
from . import defaults from . import defaults
import itertools
from plotly import tools from plotly import tools
from plotly import plotly as py from plotly import plotly as py
from plotly import matplotlylib from plotly.graph_objs import Scatter, Scatter3d, Line,\
from plotly.graph_objs import Scatter, Scatter3d, Line, Marker, ErrorX, ErrorY, Bar Marker, ErrorX, ErrorY, Bar, Heatmap, Trace,\
Annotations, Annotation, Contour, Contours, Font, Surface
from plotly.exceptions import PlotlyDictKeyError
SYMBOL_MAP = { SYMBOL_MAP = {
'o': 'dot', 'o': 'dot',
@ -63,39 +64,61 @@ class PlotlyPlots(AbstractPlottingLibrary):
figure = tools.make_subplots(rows, cols, specs=specs) figure = tools.make_subplots(rows, cols, specs=specs)
return figure return figure
def new_canvas(self, figure=None, row=1, col=1, projection='2d', xlabel=None, ylabel=None, zlabel=None, title=None, xlim=None, ylim=None, zlim=None, **kwargs): def new_canvas(self, canvas=None, row=1, col=1, projection='2d', xlabel=None, ylabel=None, zlabel=None, title=None, xlim=None, ylim=None, zlim=None, **kwargs):
if 'filename' not in kwargs: #if 'filename' not in kwargs:
print('PlotlyWarning: filename was not given, this may clutter your plotly workspace') # print('PlotlyWarning: filename was not given, this may clutter your plotly workspace')
filename = None # filename = None
else: #else:
filename = kwargs.pop('filename') # filename = kwargs.pop('filename')
if figure is None: if canvas is None:
figure = self.figure(is_3d=projection=='3d') figure = self.figure(is_3d=projection=='3d')
self.current_states[hex(id(figure))] = dict(filename=filename) figure.layout.font = Font(family="Raleway, sans-serif")
else:
return canvas, kwargs
return (figure, row, col), kwargs return (figure, row, col), kwargs
def add_to_canvas(self, canvas, traces, legend=False, **kwargs): def add_to_canvas(self, canvas, traces, legend=False, **kwargs):
figure, row, col = canvas figure, row, col = canvas
def recursive_append(traces): def append_annotation(a, xref, yref):
for _, trace in traces.items(): if 'xref' not in a:
if isinstance(trace, (tuple, list)): a['xref'] = xref
for t in trace: if 'yref' not in a:
a['yref'] = yref
figure.layout.annotations.append(a)
def append_trace(t, row, col):
figure.append_trace(t, row, col) figure.append_trace(t, row, col)
elif isinstance(trace, dict): def recursive_append(traces):
recursive_append(trace) if isinstance(traces, Annotations):
else: xref, yref = figure._grid_ref[row-1][col-1]
figure.append_trace(trace, row, col) for a in traces:
append_annotation(a, xref, yref)
elif isinstance(traces, (Trace)):
try:
append_trace(traces, row, col)
except PlotlyDictKeyError:
# Its a dictionary of plots:
for t in traces:
recursive_append(traces[t])
elif isinstance(traces, (dict)):
for t in traces:
recursive_append(traces[t])
elif isinstance(traces, (tuple, list)):
for t in traces:
recursive_append(t)
recursive_append(traces) recursive_append(traces)
figure.layout['showlegend'] = legend figure.layout['showlegend'] = legend
return canvas return canvas
def show_canvas(self, canvas, **kwargs): def show_canvas(self, canvas, filename=None, **kwargs):
figure, _, _ = canvas figure, _, _ = canvas
if len(figure.data) == 0:
# add mock data
figure.append_trace(Scatter(x=[], y=[], name='', showlegend=False), 1, 1)
from ..gpy_plot.plot_util import in_ipynb from ..gpy_plot.plot_util import in_ipynb
if in_ipynb(): if in_ipynb():
py.iplot(figure, filename=self.current_states[hex(id(figure))]['filename']) py.iplot(figure, filename=filename)#self.current_states[hex(id(figure))]['filename'])
else: else:
py.plot(figure, filename=self.current_states[hex(id(figure))]['filename']) py.plot(figure, filename=filename)#self.current_states[hex(id(figure))]['filename'])
return figure return figure
def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], cmap=None, label=None, marker='o', marker_kwargs=None, **kwargs): def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], cmap=None, label=None, marker='o', marker_kwargs=None, **kwargs):
@ -105,26 +128,35 @@ class PlotlyPlots(AbstractPlottingLibrary):
#not matplotlib marker #not matplotlib marker
pass pass
if Z is not None: if Z is not None:
return Scatter3d(x=X, y=Y, z=Z, mode='markers', marker=Marker(color=color, symbol=marker, colorscale=cmap, **marker_kwargs or {}), name=label, **kwargs) return Scatter3d(x=X, y=Y, z=Z, mode='markers', showlegend=label is not None, marker=Marker(color=color, symbol=marker, colorscale=cmap, **marker_kwargs or {}), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='markers', marker=Marker(color=color, symbol=marker, colorscale=cmap, **marker_kwargs or {}), name=label, **kwargs) return Scatter(x=X, y=Y, mode='markers', showlegend=label is not None, marker=Marker(color=color, symbol=marker, colorscale=cmap, **marker_kwargs or {}), name=label, **kwargs)
def plot(self, ax, X, Y, Z=None, color=None, label=None, line_kwargs=None, **kwargs): def plot(self, ax, X, Y, Z=None, color=None, label=None, line_kwargs=None, **kwargs):
if 'mode' not in kwargs:
kwargs['mode'] = 'lines'
if Z is not None: if Z is not None:
return Scatter3d(x=X, y=Y, z=Z, mode='lines', line=Line(color=color, **line_kwargs or {}), name=label, **kwargs) return Scatter3d(x=X, y=Y, z=Z, showlegend=label is not None, line=Line(color=color, **line_kwargs or {}), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='lines', line=Line(color=color, **line_kwargs or {}), name=label, **kwargs) return Scatter(x=X, y=Y, showlegend=label is not None, line=Line(color=color, **line_kwargs or {}), name=label, **kwargs)
def plot_axis_lines(self, ax, X, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs): def plot_axis_lines(self, ax, X, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, marker_kwargs=None, **kwargs):
from matplotlib import transforms
from matplotlib.path import Path
if 'marker' not in kwargs:
kwargs['marker'] = Path([[-.2,0.], [-.2,.5], [0.,1.], [.2,.5], [.2,0.], [-.2,0.]],
[Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY])
if 'transform' not in kwargs:
if X.shape[1] == 1: if X.shape[1] == 1:
kwargs['transform'] = transforms.blended_transform_factory(ax.transData, ax.transAxes) annotations = Annotations()
if X.shape[1] == 2: for n, row in enumerate(X):
return ax.scatter(X[:,0], X[:,1], ax.get_zlim()[0], c=color, label=label, **kwargs) annotations.append(
return ax.scatter(X, np.zeros_like(X), c=color, label=label, **kwargs) Annotation(
text='',
x=row[0], y=0,
yref='paper',
ax=0, ay=20,
arrowhead=2,
arrowsize=1,
arrowwidth=2,
arrowcolor=color,
showarrow=True))
return annotations
#if Z is not None:
# return Scatter3d(x=X[:,0], y=X[:,1], z=0, zref='paper', showlegend=label is not None, mode='markers', marker=Marker(color=color, symbol='diamond-tall', **marker_kwargs or {}), name=label, **kwargs)
#return Scatter(x=X, y=0, mode='markers', showlegend=label is not None, marker=Marker(yref='paper', color=color, symbol='diamond-tall', **marker_kwargs or {}), name=label, **kwargs)
def barplot(self, canvas, x, height, width=0.8, bottom=0, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs): def barplot(self, canvas, x, height, width=0.8, bottom=0, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
figure, _, _ = canvas figure, _, _ = canvas
@ -139,8 +171,8 @@ class PlotlyPlots(AbstractPlottingLibrary):
else: else:
error_kwargs.update(dict(array=error, symmetric=True)) error_kwargs.update(dict(array=error, symmetric=True))
if Z is not None: if Z is not None:
return Scatter3d(x=X, y=Y, z=Z, mode='markers', error_x=ErrorX(color=color, **error_kwargs or {}), marker=Marker(size='0'), name=label, **kwargs) return Scatter3d(x=X, y=Y, z=Z, mode='markers', error_x=ErrorX(color=color, **error_kwargs or {}), showlegend=label is not None, marker=Marker(size='0'), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='markers', error_x=ErrorX(color=color, **error_kwargs or {}), marker=Marker(size='0'), name=label, **kwargs) return Scatter(x=X, y=Y, mode='markers', error_x=ErrorX(color=color, **error_kwargs or {}), showlegend=label is not None, marker=Marker(size='0'), name=label, **kwargs)
def yerrorbar(self, ax, X, Y, error, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, error_kwargs=None, **kwargs): def yerrorbar(self, ax, X, Y, error, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, error_kwargs=None, **kwargs):
error_kwargs = error_kwargs or {} error_kwargs = error_kwargs or {}
@ -149,55 +181,71 @@ class PlotlyPlots(AbstractPlottingLibrary):
else: else:
error_kwargs.update(dict(array=error, symmetric=True)) error_kwargs.update(dict(array=error, symmetric=True))
if Z is not None: if Z is not None:
return Scatter3d(x=X, y=Y, z=Z, mode='markers', error_y=ErrorY(color=color, **error_kwargs or {}), marker=Marker(size='0'), name=label, **kwargs) return Scatter3d(x=X, y=Y, z=Z, mode='markers',
return Scatter(x=X, y=Y, mode='markers', error_y=ErrorY(color=color, **error_kwargs or {}), marker=Marker(size='0'), name=label, **kwargs) error_y=ErrorY(color=color, **error_kwargs or {}),
marker=Marker(size='0'), name=label,
showlegend=label is not None, **kwargs)
return Scatter(x=X, y=Y, mode='markers',
error_y=ErrorY(color=color, **error_kwargs or {}),
marker=Marker(size='0'), name=label,
showlegend=label is not None,
**kwargs)
def imshow(self, ax, X, extent=None, label=None, vmin=None, vmax=None, **imshow_kwargs): def imshow(self, ax, X, extent=None, label=None, vmin=None, vmax=None, **imshow_kwargs):
if 'origin' not in imshow_kwargs: if not 'showscale' in imshow_kwargs:
imshow_kwargs['origin'] = 'lower' imshow_kwargs['showscale'] = False
#xmin, xmax, ymin, ymax = extent return Heatmap(z=X, name=label,
#xoffset, yoffset = (xmax - xmin) / (2. * X.shape[0]), (ymax - ymin) / (2. * X.shape[1]) x0=extent[0], dx=float(extent[1]-extent[0])/X.shape[0],
#xmin, xmax, ymin, ymax = extent = xmin-xoffset, xmax+xoffset, ymin-yoffset, ymax+yoffset y0=extent[2], dy=float(extent[3]-extent[2])/X.shape[1],
return ax.imshow(X, label=label, extent=extent, vmin=vmin, vmax=vmax, **imshow_kwargs) zmin=vmin, zmax=vmax,
showlegend=label is not None,
hoverinfo='z',
**imshow_kwargs)
def imshow_interact(self, ax, plot_function, extent=None, label=None, resolution=None, vmin=None, vmax=None, **imshow_kwargs): def imshow_interact(self, ax, plot_function, extent=None, label=None, resolution=None, vmin=None, vmax=None, **imshow_kwargs):
# TODO stream interaction? # TODO stream interaction?
super(PlotlyPlots, self).imshow_interact(ax, plot_function) super(PlotlyPlots, self).imshow_interact(ax, plot_function)
def annotation_heatmap(self, ax, X, annotation, extent=None, label=None, imshow_kwargs=None, **annotation_kwargs): def annotation_heatmap(self, ax, X, annotation, extent=None, label=None, imshow_kwargs=None, **annotation_kwargs):
imshow_kwargs = imshow_kwargs or {}
if 'origin' not in imshow_kwargs:
imshow_kwargs['origin'] = 'lower'
if ('ha' not in annotation_kwargs) and ('horizontalalignment' not in annotation_kwargs):
annotation_kwargs['ha'] = 'center'
if ('va' not in annotation_kwargs) and ('verticalalignment' not in annotation_kwargs):
annotation_kwargs['va'] = 'center'
imshow = self.imshow(ax, X, extent, label, **imshow_kwargs) imshow = self.imshow(ax, X, extent, label, **imshow_kwargs)
if extent is None: x = np.linspace(extent[0], extent[1], X.shape[0])
extent = (0, X.shape[0], 0, X.shape[1]) y = np.linspace(extent[0], extent[1], X.shape[0])
xmin, xmax, ymin, ymax = extent annotations = Annotations()
xoffset, yoffset = (xmax - xmin) / (2. * X.shape[0]), (ymax - ymin) / (2. * X.shape[1]) for n, row in enumerate(annotation):
xmin, xmax, ymin, ymax = extent = xmin+xoffset, xmax-xoffset, ymin+yoffset, ymax-yoffset for m, val in enumerate(row):
xlin = np.linspace(xmin, xmax, X.shape[0], endpoint=False) #var = z[n][m]
ylin = np.linspace(ymin, ymax, X.shape[1], endpoint=False) annotations.append(
annotations = [] Annotation(
for [i, x], [j, y] in itertools.product(enumerate(xlin), enumerate(ylin)): text=str(val),
annotations.append(ax.text(x, y, "{}".format(annotation[j, i]), **annotation_kwargs)) x=x[m], y=y[n],
xref='x1', yref='y1',
font=dict(color='white' if val > 0.5 else 'black'),
showarrow=False))
return imshow, annotations return imshow, annotations
def annotation_heatmap_interact(self, ax, plot_function, extent, label=None, resolution=15, imshow_kwargs=None, **annotation_kwargs): def annotation_heatmap_interact(self, ax, plot_function, extent, label=None, resolution=15, imshow_kwargs=None, **annotation_kwargs):
if 'origin' not in imshow_kwargs: super(PlotlyPlots, self).annotation_heatmap_interact(ax, plot_function, extent)
imshow_kwargs['origin'] = 'lower'
return ImAnnotateController(ax, plot_function, extent, resolution=resolution, imshow_kwargs=imshow_kwargs or {}, **annotation_kwargs)
def contour(self, ax, X, Y, C, levels=20, label=None, **kwargs): def contour(self, ax, X, Y, C, levels=20, label=None, **kwargs):
return ax.contour(X, Y, C, levels=np.linspace(C.min(), C.max(), levels), label=label, **kwargs) return Contour(x=X, y=Y, z=C,
ncontours=levels, contours=Contours(start=C.min(), end=C.max(), size=(C.max()-C.min())/levels),
name=label, **kwargs)
def surface(self, ax, X, Y, Z, color=None, label=None, **kwargs): def surface(self, ax, X, Y, Z, color=None, label=None, **kwargs):
return ax.plot_surface(X, Y, Z, label=label, **kwargs) return Surface(x=X, y=Y, z=Z, name=label, **kwargs)
def fill_between(self, ax, X, lower, upper, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs): def fill_between(self, ax, X, lower, upper, color=Tango.colorsHex['mediumBlue'], label=None, line_kwargs=None, **kwargs):
return ax.fill_between(X, lower, upper, facecolor=color, label=label, **kwargs) if not 'line' in kwargs:
kwargs['line'] = Line(**line_kwargs or {})
else:
kwargs['line'].update(line_kwargs or {})
if color.startswith('#'):
fcolor = 'rgba ({c[0]}, {c[1]}, {c[2]}, {alpha})'.format(c=Tango.hex2rgb(color), alpha=kwargs.get('opacity', 1.0))
else: fcolor = color
u = Scatter(x=X, y=upper, fillcolor=fcolor, showlegend=label is not None, name=label, fill='tonexty', **kwargs)
fcolor = '{}, {alpha})'.format(','.join(fcolor.split(',')[:-1]), alpha=0.0)
l = Scatter(x=X, y=lower, fillcolor=fcolor, showlegend=False, fill='tonexty', name=label, **kwargs)
return l, u
def fill_gradient(self, canvas, X, percentiles, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs): def fill_gradient(self, canvas, X, percentiles, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
ax = canvas ax = canvas

View file

@ -32,11 +32,11 @@
#!/usr/bin/env python #!/usr/bin/env python
import matplotlib import matplotlib
matplotlib.use('agg')
matplotlib.rcParams.update(matplotlib.rcParamsDefault) matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.use('agg')
matplotlib.rcParams[u'figure.figsize'] = (4,3) matplotlib.rcParams[u'figure.figsize'] = (4,3)
matplotlib.rcParams[u'text.usetex'] = False matplotlib.rcParams[u'text.usetex'] = False
import nose import nose
nose.main('GPy', defaultTest='GPy/testing/plotting_tests.py') nose.main('GPy', defaultTest='GPy/testing')