[plotly] todos: fill_gradient

This commit is contained in:
mzwiessele 2015-10-08 14:05:20 +01:00
parent 7f84bec6fb
commit b3154e43b4
10 changed files with 196 additions and 149 deletions

View file

@ -17,6 +17,8 @@ from numpy.testing import Tester
from . import kern
from . import plotting
from .plotting import plotting_library
# Direct imports for convenience:
from .core import Model
from .core.parameterization import Param, Parameterized, ObsAr

View file

@ -1,5 +1,5 @@
# This is the local installation configuration file for GPy
[plotting]
#library = plotly
library = matplotlib
library = plotly
#library = matplotlib

View file

@ -1,3 +1,3 @@
from .. import plotting_library as pl
from .. import plotting_library
pl = plotting_library
from . import data_plots, gp_plots, latent_plots, kernel_plots, plot_util, inference_plots

View file

@ -123,7 +123,7 @@ def plot_data_error(self, which_data_rows='all',
def _plot_data_error(self, canvas, which_data_rows='all',
which_data_ycols='all', visible_dims=None,
projection='2d', **error_kwargs):
projection='2d', label=None, **error_kwargs):
ycols = get_which_data_ycols(self, which_data_ycols)
rows = get_which_data_rows(self, which_data_rows)
@ -139,17 +139,17 @@ def _plot_data_error(self, canvas, which_data_rows='all',
for d in ycols:
update_not_existing_kwargs(error_kwargs, pl.defaults.xerrorbar)
plots['xerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims].flatten(), Y[rows, d].flatten(),
2 * np.sqrt(X_variance[rows, free_dims].flatten()),
2 * np.sqrt(X_variance[rows, free_dims].flatten()), label=label,
**error_kwargs))
#2D plotting
elif len(free_dims) == 2:
update_not_existing_kwargs(error_kwargs, pl.defaults.xerrorbar) # @UndefinedVariable
for d in ycols:
plots['xerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims[0]].flatten(), Y[rows, d].flatten(),
2 * np.sqrt(X_variance[rows, free_dims[0]].flatten()),
2 * np.sqrt(X_variance[rows, free_dims[0]].flatten()), label=label,
**error_kwargs))
plots['yerrorplot'].append(pl.xerrorbar(canvas, X[rows, free_dims[1]].flatten(), Y[rows, d].flatten(),
2 * np.sqrt(X_variance[rows, free_dims[1]].flatten()),
2 * np.sqrt(X_variance[rows, free_dims[1]].flatten()), label=label,
**error_kwargs))
elif len(free_dims) == 0:
pass #Nothing to plot!

View file

@ -40,7 +40,7 @@ def plot_mean(self, plot_limits=None, fixed_inputs=None,
apply_link=False, visible_dims=None,
which_data_ycols='all',
levels=20, projection='2d',
label=None,
label='gp mean',
predict_kw=None,
**kwargs):
"""
@ -70,8 +70,7 @@ def plot_mean(self, plot_limits=None, fixed_inputs=None,
predict_kw)
plots = _plot_mean(self, canvas, helper_data, helper_prediction,
levels, projection, label, **kwargs)
pl.add_to_canvas(canvas, plots)
return pl.show_canvas(canvas)
return pl.add_to_canvas(canvas, plots)
def _plot_mean(self, canvas, helper_data, helper_prediction,
levels=20, projection='2d', label=None,
@ -87,7 +86,7 @@ def _plot_mean(self, canvas, helper_data, helper_prediction,
else:
if projection == '2d':
update_not_existing_kwargs(kwargs, pl.defaults.meanplot_2d) # @UndefinedVariable
plots = dict(gpmean=[pl.contour(canvas, x, y,
plots = dict(gpmean=[pl.contour(canvas, x[:,0], y[0,:],
mu.reshape(resolution, resolution),
levels=levels, label=label, **kwargs)])
elif projection == '3d':
@ -105,7 +104,7 @@ def _plot_mean(self, canvas, helper_data, helper_prediction,
def plot_confidence(self, lower=2.5, upper=97.5, plot_limits=None, fixed_inputs=None,
resolution=None, plot_raw=False,
apply_link=False, visible_dims=None,
which_data_ycols='all', label=None,
which_data_ycols='all', label='gp confidence',
predict_kw=None,
**kwargs):
"""
@ -157,7 +156,7 @@ def plot_samples(self, plot_limits=None, fixed_inputs=None,
resolution=None, plot_raw=True,
apply_link=False, visible_dims=None,
which_data_ycols='all',
samples=3, projection='2d', label=None,
samples=3, projection='2d', label='gp_samples',
predict_kw=None,
**kwargs):
"""
@ -214,7 +213,7 @@ def plot_density(self, plot_limits=None, fixed_inputs=None,
resolution=None, plot_raw=False,
apply_link=False, visible_dims=None,
which_data_ycols='all',
levels=35, label=None,
levels=35, label='gp density',
predict_kw=None,
**kwargs):
"""
@ -270,7 +269,7 @@ def plot(self, plot_limits=None, fixed_inputs=None,
visible_dims=None,
levels=20, samples=0, samples_likelihood=0, lower=2.5, upper=97.5,
plot_data=True, plot_inducing=True, plot_density=False,
predict_kw=None, projection='2d', **kwargs):
predict_kw=None, projection='2d', legend=False, **kwargs):
"""
Convinience function for plotting the fit of a GP.
@ -311,18 +310,18 @@ def plot(self, plot_limits=None, fixed_inputs=None,
plot_data = False
plots = {}
if plot_data:
plots.update(_plot_data(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection))
plots.update(_plot_data_error(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection))
plots.update(_plot_data(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection, "Data"))
plots.update(_plot_data_error(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection, "Data Error"))
plots.update(_plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_inducing, plot_density, projection))
if plot_raw and (samples_likelihood > 0):
helper_prediction = helper_predict_with_model(self, helper_data[5], False,
apply_link, None,
get_which_data_ycols(self, which_data_ycols),
predict_kw, samples_likelihood)
plots.update(_plot_samples(canvas, helper_data, helper_prediction, projection))
plots.update(_plot_samples(canvas, helper_data, helper_prediction, projection, "Lik Samples"))
if hasattr(self, 'Z') and plot_inducing:
plots.update(_plot_inducing(self, canvas, visible_dims, projection, None))
return pl.add_to_canvas(canvas, plots)
plots.update(_plot_inducing(self, canvas, visible_dims, projection, 'Inducing'))
return pl.add_to_canvas(canvas, plots, legend=legend)
def plot_f(self, plot_limits=None, fixed_inputs=None,
@ -333,7 +332,7 @@ def plot_f(self, plot_limits=None, fixed_inputs=None,
levels=20, samples=0, lower=2.5, upper=97.5,
plot_density=False,
plot_data=True, plot_inducing=True,
projection='2d',
projection='2d', legend=False,
predict_kw=None,
**kwargs):
"""
@ -366,35 +365,27 @@ def plot_f(self, plot_limits=None, fixed_inputs=None,
:param dict error_kwargs: kwargs for the error plot for the plotting library you are using
:param kwargs plot_kwargs: kwargs for the data plot for the plotting library you are using
"""
canvas, _ = pl.new_canvas(projection=projection, **kwargs)
helper_data = helper_for_plot_data(self, plot_limits, visible_dims, fixed_inputs, resolution)
helper_prediction = helper_predict_with_model(self, helper_data[5], True,
apply_link, np.linspace(2.5, 97.5, levels*2) if plot_density else (lower,upper),
get_which_data_ycols(self, which_data_ycols),
predict_kw, samples)
if not apply_link:
# It does not make sense to plot the data (which lives not in the latent function space) into latent function space.
plot_data = False
plots = {}
if plot_data:
plots.update(_plot_data(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection))
plots.update(_plot_data_error(self, canvas, which_data_rows, which_data_ycols, visible_dims, projection))
plots.update(_plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_inducing, plot_density, projection))
if hasattr(self, 'Z') and plot_inducing:
plots.update(_plot_inducing(self, canvas, visible_dims, projection, None))
return pl.add_to_canvas(canvas, plots)
plot(self, plot_limits, fixed_inputs, resolution, True,
apply_link, which_data_ycols, which_data_rows,
visible_dims, levels, samples, 0,
lower, upper, plot_data, plot_inducing,
plot_density, predict_kw, projection, legend)
def _plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_inducing=True, plot_density=False, projection='2d'):
plots.update(_plot_mean(self, canvas, helper_data, helper_prediction, levels, projection, None))
plots.update(_plot_mean(self, canvas, helper_data, helper_prediction, levels, projection, 'Mean'))
if projection=='2d':
if not plot_density:
plots.update(_plot_confidence(self, canvas, helper_data, helper_prediction, None))
else:
plots.update(_plot_density(self, canvas, helper_data, helper_prediction, None))
try:
if projection=='2d':
if not plot_density:
plots.update(_plot_confidence(self, canvas, helper_data, helper_prediction, "Confidence"))
else:
plots.update(_plot_density(self, canvas, helper_data, helper_prediction, "Density"))
except RuntimeError:
#plotting in 2d
pass
if helper_prediction[2] is not None:
plots.update(_plot_samples(self, canvas, helper_data, helper_prediction, projection, None))
plots.update(_plot_samples(self, canvas, helper_data, helper_prediction, projection, "Samples"))
return plots

View file

@ -35,16 +35,20 @@ from .plot_util import get_x_y_var,\
find_best_layout_for_subplots
def _wait_for_updates(view, updates):
try:
if updates:
clear = raw_input('yes or enter to deactivate updates - otherwise still do updates - use plots[imshow].deactivate() to clear')
if clear.lower() in 'yes' or clear == '':
if view is not None:
try:
if updates:
clear = raw_input('yes or enter to deactivate updates - otherwise still do updates - use plots[imshow].deactivate() to clear')
if clear.lower() in 'yes' or clear == '':
view.deactivate()
else:
view.deactivate()
else:
view.deactivate()
except AttributeError:
# No updateable view:
pass
except AttributeError:
# No updateable view:
pass
except TypeError:
# No updateable view:
pass
def _plot_latent_scatter(canvas, X, visible_dims, labels, marker, num_samples, projection='2d', **kwargs):

View file

@ -282,12 +282,12 @@ def get_x_y_var(model):
:returns: (X, X_variance, Y)
"""
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean
X_variance = model.X.variance
X = model.X.mean.values
X_variance = model.X.variance.values
else:
X = model.X
X = model.X.values
X_variance = None
Y = model.Y
Y = model.Y.values
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
return X, X_variance, Y

View file

@ -29,6 +29,7 @@
#===============================================================================
from .. import Tango
from plotly.graph_objs import Line
'''
This file is for defaults for the gpy plot, specific to the plotting library.
@ -42,21 +43,22 @@ it gives back an empty default, when defaults are not defined.
'''
# Data plots:
data_1d = dict(marker_kwargs=dict(linewidth=.7, ), marker='x', color='black')
data_2d = dict(marker='o', cmap='Hot', marker_kwargs=dict(opacity=.5))
# inducing_1d = dict(lw=0, s=500, facecolors=Tango.colorsHex['darkRed'])
# inducing_2d = dict(s=14, edgecolors='k', linewidth=.4, facecolors='white', alpha=.5)
# inducing_3d = dict(lw=.3, s=500, facecolors='white', edgecolors='k')
data_1d = dict(marker_kwargs=dict(), marker='x', color='black')
data_2d = dict(marker='o', cmap='Hot', marker_kwargs=dict(opacity=1., size='10', line=Line(width=.5, color='black')))
inducing_1d = dict(color=Tango.colorsHex['darkRed'])
inducing_2d = dict(marker_kwargs=dict(size='8', opacity=.7, line=Line(width=.5, color='black')), opacity=.7, color='white', marker='star-triangle-up')
inducing_3d = dict(marker_kwargs=dict(size='8', opacity=.7, line=Line(width=.5, color='black')), opacity=.7, color='white', marker='star-triangle-up')
# xerrorbar = dict(color='k', fmt='none', elinewidth=.5, alpha=.5)
yerrorbar = dict(color=Tango.colorsHex['darkRed'], error_kwargs=dict(thickness=.5), opacity=.5)
#
# # GP plots:
# meanplot_1d = dict(color=Tango.colorsHex['mediumBlue'], linewidth=2)
# meanplot_2d = dict(cmap='hot', linewidth=.5)
# meanplot_3d = dict(linewidth=0, antialiased=True, cstride=1, rstride=1, cmap='hot', alpha=.3)
# samples_1d = dict(color=Tango.colorsHex['mediumBlue'], linewidth=.3)
# samples_3d = dict(cmap='hot', alpha=.1, antialiased=True, cstride=1, rstride=1, linewidth=0)
# confidence_interval = dict(edgecolor=Tango.colorsHex['darkBlue'], linewidth=.5, color=Tango.colorsHex['lightBlue'],alpha=.2)
meanplot_1d = dict(color=Tango.colorsHex['mediumBlue'], line_kwargs=dict(width=2))
meanplot_2d = dict(colorscale='Hot')
meanplot_3d = dict(colorscale='Hot', opacity=.8)
samples_1d = dict(color=Tango.colorsHex['mediumBlue'], line_kwargs=dict(width=.3))
samples_3d = dict(cmap='Hot', opacity=.5)
confidence_interval = dict(mode='lines', line_kwargs=dict(color=Tango.colorsHex['darkBlue'], width=.4),
color=Tango.colorsHex['lightBlue'], opacity=.3)
# density = dict(alpha=.5, color=Tango.colorsHex['lightBlue'])
#
# # GPLVM plots:
@ -67,8 +69,8 @@ yerrorbar = dict(color=Tango.colorsHex['darkRed'], error_kwargs=dict(thickness=.
# ard = dict(edgecolor='k', linewidth=1.2)
#
# # Input plots:
# latent = dict(aspect='auto', cmap='Greys', interpolation='bicubic')
# gradient = dict(aspect='auto', cmap='RdBu', interpolation='nearest', alpha=.7)
# magnification = dict(aspect='auto', cmap='Greys', interpolation='bicubic')
# latent_scatter = dict(s=40, linewidth=.2, edgecolor='k', alpha=.9)
latent = dict(colorscale='Greys', reversescale=True)
gradient = dict(colorscale='RdBu', opacity=.7)
magnification = dict(colorscale='Greys')
latent_scatter = dict(marker_kwargs=dict(size='15', opacity=.7))
# annotation = dict(fontdict=dict(family='sans-serif', weight='light', fontsize=9), zorder=.3, alpha=.7)

View file

@ -31,11 +31,12 @@ import numpy as np
from ..abstract_plotting_library import AbstractPlottingLibrary
from .. import Tango
from . import defaults
import itertools
from plotly import tools
from plotly import plotly as py
from plotly import matplotlylib
from plotly.graph_objs import Scatter, Scatter3d, Line, Marker, ErrorX, ErrorY, Bar
from plotly.graph_objs import Scatter, Scatter3d, Line,\
Marker, ErrorX, ErrorY, Bar, Heatmap, Trace,\
Annotations, Annotation, Contour, Contours, Font, Surface
from plotly.exceptions import PlotlyDictKeyError
SYMBOL_MAP = {
'o': 'dot',
@ -63,39 +64,61 @@ class PlotlyPlots(AbstractPlottingLibrary):
figure = tools.make_subplots(rows, cols, specs=specs)
return figure
def new_canvas(self, figure=None, row=1, col=1, projection='2d', xlabel=None, ylabel=None, zlabel=None, title=None, xlim=None, ylim=None, zlim=None, **kwargs):
if 'filename' not in kwargs:
print('PlotlyWarning: filename was not given, this may clutter your plotly workspace')
filename = None
else:
filename = kwargs.pop('filename')
if figure is None:
def new_canvas(self, canvas=None, row=1, col=1, projection='2d', xlabel=None, ylabel=None, zlabel=None, title=None, xlim=None, ylim=None, zlim=None, **kwargs):
#if 'filename' not in kwargs:
# print('PlotlyWarning: filename was not given, this may clutter your plotly workspace')
# filename = None
#else:
# filename = kwargs.pop('filename')
if canvas is None:
figure = self.figure(is_3d=projection=='3d')
self.current_states[hex(id(figure))] = dict(filename=filename)
figure.layout.font = Font(family="Raleway, sans-serif")
else:
return canvas, kwargs
return (figure, row, col), kwargs
def add_to_canvas(self, canvas, traces, legend=False, **kwargs):
figure, row, col = canvas
def append_annotation(a, xref, yref):
if 'xref' not in a:
a['xref'] = xref
if 'yref' not in a:
a['yref'] = yref
figure.layout.annotations.append(a)
def append_trace(t, row, col):
figure.append_trace(t, row, col)
def recursive_append(traces):
for _, trace in traces.items():
if isinstance(trace, (tuple, list)):
for t in trace:
figure.append_trace(t, row, col)
elif isinstance(trace, dict):
recursive_append(trace)
else:
figure.append_trace(trace, row, col)
if isinstance(traces, Annotations):
xref, yref = figure._grid_ref[row-1][col-1]
for a in traces:
append_annotation(a, xref, yref)
elif isinstance(traces, (Trace)):
try:
append_trace(traces, row, col)
except PlotlyDictKeyError:
# Its a dictionary of plots:
for t in traces:
recursive_append(traces[t])
elif isinstance(traces, (dict)):
for t in traces:
recursive_append(traces[t])
elif isinstance(traces, (tuple, list)):
for t in traces:
recursive_append(t)
recursive_append(traces)
figure.layout['showlegend'] = legend
return canvas
def show_canvas(self, canvas, **kwargs):
def show_canvas(self, canvas, filename=None, **kwargs):
figure, _, _ = canvas
if len(figure.data) == 0:
# add mock data
figure.append_trace(Scatter(x=[], y=[], name='', showlegend=False), 1, 1)
from ..gpy_plot.plot_util import in_ipynb
if in_ipynb():
py.iplot(figure, filename=self.current_states[hex(id(figure))]['filename'])
py.iplot(figure, filename=filename)#self.current_states[hex(id(figure))]['filename'])
else:
py.plot(figure, filename=self.current_states[hex(id(figure))]['filename'])
py.plot(figure, filename=filename)#self.current_states[hex(id(figure))]['filename'])
return figure
def scatter(self, ax, X, Y, Z=None, color=Tango.colorsHex['mediumBlue'], cmap=None, label=None, marker='o', marker_kwargs=None, **kwargs):
@ -105,27 +128,36 @@ class PlotlyPlots(AbstractPlottingLibrary):
#not matplotlib marker
pass
if Z is not None:
return Scatter3d(x=X, y=Y, z=Z, mode='markers', marker=Marker(color=color, symbol=marker, colorscale=cmap, **marker_kwargs or {}), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='markers', marker=Marker(color=color, symbol=marker, colorscale=cmap, **marker_kwargs or {}), name=label, **kwargs)
return Scatter3d(x=X, y=Y, z=Z, mode='markers', showlegend=label is not None, marker=Marker(color=color, symbol=marker, colorscale=cmap, **marker_kwargs or {}), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='markers', showlegend=label is not None, marker=Marker(color=color, symbol=marker, colorscale=cmap, **marker_kwargs or {}), name=label, **kwargs)
def plot(self, ax, X, Y, Z=None, color=None, label=None, line_kwargs=None, **kwargs):
if 'mode' not in kwargs:
kwargs['mode'] = 'lines'
if Z is not None:
return Scatter3d(x=X, y=Y, z=Z, mode='lines', line=Line(color=color, **line_kwargs or {}), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='lines', line=Line(color=color, **line_kwargs or {}), name=label, **kwargs)
def plot_axis_lines(self, ax, X, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
from matplotlib import transforms
from matplotlib.path import Path
if 'marker' not in kwargs:
kwargs['marker'] = Path([[-.2,0.], [-.2,.5], [0.,1.], [.2,.5], [.2,0.], [-.2,0.]],
[Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY])
if 'transform' not in kwargs:
if X.shape[1] == 1:
kwargs['transform'] = transforms.blended_transform_factory(ax.transData, ax.transAxes)
if X.shape[1] == 2:
return ax.scatter(X[:,0], X[:,1], ax.get_zlim()[0], c=color, label=label, **kwargs)
return ax.scatter(X, np.zeros_like(X), c=color, label=label, **kwargs)
return Scatter3d(x=X, y=Y, z=Z, showlegend=label is not None, line=Line(color=color, **line_kwargs or {}), name=label, **kwargs)
return Scatter(x=X, y=Y, showlegend=label is not None, line=Line(color=color, **line_kwargs or {}), name=label, **kwargs)
def plot_axis_lines(self, ax, X, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, marker_kwargs=None, **kwargs):
if X.shape[1] == 1:
annotations = Annotations()
for n, row in enumerate(X):
annotations.append(
Annotation(
text='',
x=row[0], y=0,
yref='paper',
ax=0, ay=20,
arrowhead=2,
arrowsize=1,
arrowwidth=2,
arrowcolor=color,
showarrow=True))
return annotations
#if Z is not None:
# return Scatter3d(x=X[:,0], y=X[:,1], z=0, zref='paper', showlegend=label is not None, mode='markers', marker=Marker(color=color, symbol='diamond-tall', **marker_kwargs or {}), name=label, **kwargs)
#return Scatter(x=X, y=0, mode='markers', showlegend=label is not None, marker=Marker(yref='paper', color=color, symbol='diamond-tall', **marker_kwargs or {}), name=label, **kwargs)
def barplot(self, canvas, x, height, width=0.8, bottom=0, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
figure, _, _ = canvas
if 'barmode' in kwargs:
@ -139,8 +171,8 @@ class PlotlyPlots(AbstractPlottingLibrary):
else:
error_kwargs.update(dict(array=error, symmetric=True))
if Z is not None:
return Scatter3d(x=X, y=Y, z=Z, mode='markers', error_x=ErrorX(color=color, **error_kwargs or {}), marker=Marker(size='0'), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='markers', error_x=ErrorX(color=color, **error_kwargs or {}), marker=Marker(size='0'), name=label, **kwargs)
return Scatter3d(x=X, y=Y, z=Z, mode='markers', error_x=ErrorX(color=color, **error_kwargs or {}), showlegend=label is not None, marker=Marker(size='0'), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='markers', error_x=ErrorX(color=color, **error_kwargs or {}), showlegend=label is not None, marker=Marker(size='0'), name=label, **kwargs)
def yerrorbar(self, ax, X, Y, error, Z=None, color=Tango.colorsHex['mediumBlue'], label=None, error_kwargs=None, **kwargs):
error_kwargs = error_kwargs or {}
@ -149,55 +181,71 @@ class PlotlyPlots(AbstractPlottingLibrary):
else:
error_kwargs.update(dict(array=error, symmetric=True))
if Z is not None:
return Scatter3d(x=X, y=Y, z=Z, mode='markers', error_y=ErrorY(color=color, **error_kwargs or {}), marker=Marker(size='0'), name=label, **kwargs)
return Scatter(x=X, y=Y, mode='markers', error_y=ErrorY(color=color, **error_kwargs or {}), marker=Marker(size='0'), name=label, **kwargs)
return Scatter3d(x=X, y=Y, z=Z, mode='markers',
error_y=ErrorY(color=color, **error_kwargs or {}),
marker=Marker(size='0'), name=label,
showlegend=label is not None, **kwargs)
return Scatter(x=X, y=Y, mode='markers',
error_y=ErrorY(color=color, **error_kwargs or {}),
marker=Marker(size='0'), name=label,
showlegend=label is not None,
**kwargs)
def imshow(self, ax, X, extent=None, label=None, vmin=None, vmax=None, **imshow_kwargs):
if 'origin' not in imshow_kwargs:
imshow_kwargs['origin'] = 'lower'
#xmin, xmax, ymin, ymax = extent
#xoffset, yoffset = (xmax - xmin) / (2. * X.shape[0]), (ymax - ymin) / (2. * X.shape[1])
#xmin, xmax, ymin, ymax = extent = xmin-xoffset, xmax+xoffset, ymin-yoffset, ymax+yoffset
return ax.imshow(X, label=label, extent=extent, vmin=vmin, vmax=vmax, **imshow_kwargs)
if not 'showscale' in imshow_kwargs:
imshow_kwargs['showscale'] = False
return Heatmap(z=X, name=label,
x0=extent[0], dx=float(extent[1]-extent[0])/X.shape[0],
y0=extent[2], dy=float(extent[3]-extent[2])/X.shape[1],
zmin=vmin, zmax=vmax,
showlegend=label is not None,
hoverinfo='z',
**imshow_kwargs)
def imshow_interact(self, ax, plot_function, extent=None, label=None, resolution=None, vmin=None, vmax=None, **imshow_kwargs):
# TODO stream interaction?
super(PlotlyPlots, self).imshow_interact(ax, plot_function)
def annotation_heatmap(self, ax, X, annotation, extent=None, label=None, imshow_kwargs=None, **annotation_kwargs):
imshow_kwargs = imshow_kwargs or {}
if 'origin' not in imshow_kwargs:
imshow_kwargs['origin'] = 'lower'
if ('ha' not in annotation_kwargs) and ('horizontalalignment' not in annotation_kwargs):
annotation_kwargs['ha'] = 'center'
if ('va' not in annotation_kwargs) and ('verticalalignment' not in annotation_kwargs):
annotation_kwargs['va'] = 'center'
imshow = self.imshow(ax, X, extent, label, **imshow_kwargs)
if extent is None:
extent = (0, X.shape[0], 0, X.shape[1])
xmin, xmax, ymin, ymax = extent
xoffset, yoffset = (xmax - xmin) / (2. * X.shape[0]), (ymax - ymin) / (2. * X.shape[1])
xmin, xmax, ymin, ymax = extent = xmin+xoffset, xmax-xoffset, ymin+yoffset, ymax-yoffset
xlin = np.linspace(xmin, xmax, X.shape[0], endpoint=False)
ylin = np.linspace(ymin, ymax, X.shape[1], endpoint=False)
annotations = []
for [i, x], [j, y] in itertools.product(enumerate(xlin), enumerate(ylin)):
annotations.append(ax.text(x, y, "{}".format(annotation[j, i]), **annotation_kwargs))
x = np.linspace(extent[0], extent[1], X.shape[0])
y = np.linspace(extent[0], extent[1], X.shape[0])
annotations = Annotations()
for n, row in enumerate(annotation):
for m, val in enumerate(row):
#var = z[n][m]
annotations.append(
Annotation(
text=str(val),
x=x[m], y=y[n],
xref='x1', yref='y1',
font=dict(color='white' if val > 0.5 else 'black'),
showarrow=False))
return imshow, annotations
def annotation_heatmap_interact(self, ax, plot_function, extent, label=None, resolution=15, imshow_kwargs=None, **annotation_kwargs):
if 'origin' not in imshow_kwargs:
imshow_kwargs['origin'] = 'lower'
return ImAnnotateController(ax, plot_function, extent, resolution=resolution, imshow_kwargs=imshow_kwargs or {}, **annotation_kwargs)
super(PlotlyPlots, self).annotation_heatmap_interact(ax, plot_function, extent)
def contour(self, ax, X, Y, C, levels=20, label=None, **kwargs):
return ax.contour(X, Y, C, levels=np.linspace(C.min(), C.max(), levels), label=label, **kwargs)
return Contour(x=X, y=Y, z=C,
ncontours=levels, contours=Contours(start=C.min(), end=C.max(), size=(C.max()-C.min())/levels),
name=label, **kwargs)
def surface(self, ax, X, Y, Z, color=None, label=None, **kwargs):
return ax.plot_surface(X, Y, Z, label=label, **kwargs)
return Surface(x=X, y=Y, z=Z, name=label, **kwargs)
def fill_between(self, ax, X, lower, upper, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
return ax.fill_between(X, lower, upper, facecolor=color, label=label, **kwargs)
def fill_between(self, ax, X, lower, upper, color=Tango.colorsHex['mediumBlue'], label=None, line_kwargs=None, **kwargs):
if not 'line' in kwargs:
kwargs['line'] = Line(**line_kwargs or {})
else:
kwargs['line'].update(line_kwargs or {})
if color.startswith('#'):
fcolor = 'rgba ({c[0]}, {c[1]}, {c[2]}, {alpha})'.format(c=Tango.hex2rgb(color), alpha=kwargs.get('opacity', 1.0))
else: fcolor = color
u = Scatter(x=X, y=upper, fillcolor=fcolor, showlegend=label is not None, name=label, fill='tonexty', **kwargs)
fcolor = '{}, {alpha})'.format(','.join(fcolor.split(',')[:-1]), alpha=0.0)
l = Scatter(x=X, y=lower, fillcolor=fcolor, showlegend=False, fill='tonexty', name=label, **kwargs)
return l, u
def fill_gradient(self, canvas, X, percentiles, color=Tango.colorsHex['mediumBlue'], label=None, **kwargs):
ax = canvas

View file

@ -32,11 +32,11 @@
#!/usr/bin/env python
import matplotlib
matplotlib.use('agg')
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.use('agg')
matplotlib.rcParams[u'figure.figsize'] = (4,3)
matplotlib.rcParams[u'text.usetex'] = False
import nose
nose.main('GPy', defaultTest='GPy/testing/plotting_tests.py')
nose.main('GPy', defaultTest='GPy/testing')