mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-15 06:52:39 +02:00
merged last devel
This commit is contained in:
commit
45f692340a
170 changed files with 39094 additions and 4107 deletions
|
|
@ -104,4 +104,4 @@ cdict_Alu = {'red' :((0./5,colorsRGB['Aluminium1'][0]/256.,colorsRGB['Aluminium1
|
|||
(2./5,colorsRGB['Aluminium3'][2]/256.,colorsRGB['Aluminium3'][2]/256.),
|
||||
(3./5,colorsRGB['Aluminium4'][2]/256.,colorsRGB['Aluminium4'][2]/256.),
|
||||
(4./5,colorsRGB['Aluminium5'][2]/256.,colorsRGB['Aluminium5'][2]/256.),
|
||||
(5./5,colorsRGB['Aluminium6'][2]/256.,colorsRGB['Aluminium6'][2]/256.))}
|
||||
(5./5,colorsRGB['Aluminium6'][2]/256.,colorsRGB['Aluminium6'][2]/256.))}
|
||||
|
|
|
|||
|
|
@ -52,6 +52,17 @@ def inject_plotting():
|
|||
GP.plot_f = gpy_plot.gp_plots.plot_f
|
||||
GP.plot_magnification = gpy_plot.latent_plots.plot_magnification
|
||||
|
||||
from ..models import StateSpace
|
||||
StateSpace.plot_data = gpy_plot.data_plots.plot_data
|
||||
StateSpace.plot_data_error = gpy_plot.data_plots.plot_data_error
|
||||
StateSpace.plot_errorbars_trainset = gpy_plot.data_plots.plot_errorbars_trainset
|
||||
StateSpace.plot_mean = gpy_plot.gp_plots.plot_mean
|
||||
StateSpace.plot_confidence = gpy_plot.gp_plots.plot_confidence
|
||||
StateSpace.plot_density = gpy_plot.gp_plots.plot_density
|
||||
StateSpace.plot_samples = gpy_plot.gp_plots.plot_samples
|
||||
StateSpace.plot = gpy_plot.gp_plots.plot
|
||||
StateSpace.plot_f = gpy_plot.gp_plots.plot_f
|
||||
|
||||
from ..core import SparseGP
|
||||
SparseGP.plot_inducing = gpy_plot.data_plots.plot_inducing
|
||||
|
||||
|
|
@ -107,4 +118,4 @@ try:
|
|||
lib = config.get('plotting', 'library')
|
||||
change_plotting_library(lib)
|
||||
except NoOptionError:
|
||||
print("No plotting library was specified in config file. \n{}".format(error_suggestion))
|
||||
print("No plotting library was specified in config file. \n{}".format(error_suggestion))
|
||||
|
|
|
|||
|
|
@ -235,8 +235,6 @@ def plot_density(self, plot_limits=None, fixed_inputs=None,
|
|||
|
||||
Give the Y_metadata in the predict_kw if you need it.
|
||||
|
||||
|
||||
|
||||
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
||||
:type plot_limits: np.array
|
||||
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input dimension i should be set to value v.
|
||||
|
|
@ -420,4 +418,4 @@ def _plot(self, canvas, plots, helper_data, helper_prediction, levels, plot_indu
|
|||
|
||||
if helper_prediction[2] is not None:
|
||||
plots.update(_plot_samples(self, canvas, helper_data, helper_prediction, projection, "Samples"))
|
||||
return plots
|
||||
return plots
|
||||
|
|
|
|||
|
|
@ -140,4 +140,4 @@ def plot_covariance(kernel, x=None, label=None,
|
|||
return pl().add_to_canvas(canvas, plots)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Cannot plot a kernel with more than two input dimensions")
|
||||
raise NotImplementedError("Cannot plot a kernel with more than two input dimensions")
|
||||
|
|
|
|||
|
|
@ -131,7 +131,9 @@ def plot_latent_inducing(self,
|
|||
|
||||
Z = self.Z.values
|
||||
labels = np.array(['inducing'] * Z.shape[0])
|
||||
scatters = _plot_latent_scatter(canvas, Z, sig_dims, labels, marker, num_samples, projection=projection, **kwargs)
|
||||
kwargs['marker'] = marker
|
||||
update_not_existing_kwargs(kwargs, pl().defaults.inducing_2d) # @UndefinedVariable
|
||||
scatters = _plot_latent_scatter(canvas, Z, sig_dims, labels, num_samples=num_samples, projection=projection, **kwargs)
|
||||
return pl().add_to_canvas(canvas, dict(scatter=scatters), legend=legend)
|
||||
|
||||
|
||||
|
|
@ -147,6 +149,7 @@ def _plot_magnification(self, canvas, which_indices, Xgrid,
|
|||
def plot_function(x):
|
||||
Xtest_full = np.zeros((x.shape[0], Xgrid.shape[1]))
|
||||
Xtest_full[:, which_indices] = x
|
||||
|
||||
mf = self.predict_magnification(Xtest_full, kern=kern, mean=mean, covariance=covariance)
|
||||
return mf.reshape(resolution, resolution).T
|
||||
imshow_kwargs = update_not_existing_kwargs(imshow_kwargs, pl().defaults.magnification)
|
||||
|
|
@ -215,7 +218,12 @@ def _plot_latent(self, canvas, which_indices, Xgrid,
|
|||
def plot_function(x):
|
||||
Xtest_full = np.zeros((x.shape[0], Xgrid.shape[1]))
|
||||
Xtest_full[:, which_indices] = x
|
||||
mf = np.log(self.predict(Xtest_full, kern=kern)[1])
|
||||
mf = self.predict(Xtest_full, kern=kern)[1]
|
||||
if mf.shape[1]==self.output_dim:
|
||||
mf = mf.sum(-1)
|
||||
else:
|
||||
mf *= self.output_dim
|
||||
mf = np.log(mf)
|
||||
return mf.reshape(resolution, resolution).T
|
||||
|
||||
imshow_kwargs = update_not_existing_kwargs(imshow_kwargs, pl().defaults.latent)
|
||||
|
|
|
|||
|
|
@ -194,6 +194,7 @@ def scatter_label_generator(labels, X, visible_dims, marker=None):
|
|||
x = X[index, input_1]
|
||||
y = X[index, input_2]
|
||||
z = X[index, input_3]
|
||||
|
||||
yield x, y, z, this_label, index, m
|
||||
|
||||
def subsample_X(X, labels, num_samples=1000):
|
||||
|
|
@ -385,5 +386,5 @@ def x_frame2D(X,plot_limits=None,resolution=None):
|
|||
|
||||
resolution = resolution or 50
|
||||
xx, yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
|
||||
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
|
||||
Xnew = np.c_[xx.flat, yy.flat]
|
||||
return Xnew, xx, yy, xmin, xmax
|
||||
|
|
|
|||
|
|
@ -18,4 +18,4 @@
|
|||
|
||||
|
||||
from .util import align_subplot_array, align_subplots, fewerXticks, removeRightTicks, removeUpperTicks
|
||||
from . import controllers, base_plots
|
||||
from . import controllers, base_plots
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .imshow_controller import ImshowController, ImAnnotateController
|
||||
from .imshow_controller import ImshowController, ImAnnotateController
|
||||
|
|
|
|||
|
|
@ -72,4 +72,4 @@ class ImAnnotateController(ImshowController):
|
|||
text.set_x(x+xoffset)
|
||||
text.set_y(y+yoffset)
|
||||
text.set_text("{}".format(X[1][j, i]))
|
||||
return view
|
||||
return view
|
||||
|
|
|
|||
|
|
@ -1,21 +1,21 @@
|
|||
#===============================================================================
|
||||
# Copyright (c) 2015, Max Zwiessele
|
||||
# All rights reserved.
|
||||
#
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
#
|
||||
# * Neither the name of GPy nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
|
|
@ -34,20 +34,20 @@ from .. import Tango
|
|||
'''
|
||||
This file is for defaults for the gpy plot, specific to the plotting library.
|
||||
|
||||
Create a kwargs dictionary with the right name for the plotting function
|
||||
Create a kwargs dictionary with the right name for the plotting function
|
||||
you are implementing. If you do not provide defaults, the default behaviour of
|
||||
the plotting library will be used.
|
||||
the plotting library will be used.
|
||||
|
||||
In the code, always ise plotting.gpy_plots.defaults to get the defaults, as
|
||||
In the code, always ise plotting.gpy_plots.defaults to get the defaults, as
|
||||
it gives back an empty default, when defaults are not defined.
|
||||
'''
|
||||
|
||||
# Data plots:
|
||||
data_1d = dict(lw=1.5, marker='x', edgecolor='k')
|
||||
data_1d = dict(lw=1.5, marker='x', color='k')
|
||||
data_2d = dict(s=35, edgecolors='none', linewidth=0., cmap=cm.get_cmap('hot'), alpha=.5)
|
||||
inducing_1d = dict(lw=0, s=500, facecolors=Tango.colorsHex['darkRed'])
|
||||
inducing_2d = dict(s=14, edgecolors='k', linewidth=.4, facecolors='white', alpha=.5, marker='^')
|
||||
inducing_3d = dict(lw=.3, s=500, facecolors='white', edgecolors='k')
|
||||
inducing_2d = dict(s=17, edgecolor='k', linewidth=.4, color='white', alpha=.5, marker='^')
|
||||
inducing_3d = dict(lw=.3, s=500, color=Tango.colorsHex['darkRed'], edgecolor='k')
|
||||
xerrorbar = dict(color='k', fmt='none', elinewidth=.5, alpha=.5)
|
||||
yerrorbar = dict(color=Tango.colorsHex['darkRed'], fmt='none', elinewidth=.5, alpha=.5)
|
||||
|
||||
|
|
@ -71,5 +71,5 @@ ard = dict(edgecolor='k', linewidth=1.2)
|
|||
latent = dict(aspect='auto', cmap='Greys', interpolation='bicubic')
|
||||
gradient = dict(aspect='auto', cmap='RdBu', interpolation='nearest', alpha=.7)
|
||||
magnification = dict(aspect='auto', cmap='Greys', interpolation='bicubic')
|
||||
latent_scatter = dict(s=40, linewidth=.2, edgecolor='k', alpha=.9)
|
||||
annotation = dict(fontdict=dict(family='sans-serif', weight='light', fontsize=9), zorder=.3, alpha=.7)
|
||||
latent_scatter = dict(s=20, linewidth=.2, edgecolor='k', alpha=.9)
|
||||
annotation = dict(fontdict=dict(family='sans-serif', weight='light', fontsize=9), zorder=.3, alpha=.7)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
#===============================================================================
|
||||
# Copyright (c) 2015, Max Zwiessele
|
||||
# Copyright (c) 2016, Max Zwiessele, Alan saul
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
|
|
@ -116,4 +116,43 @@ def align_subplot_array(axes,xlim=None, ylim=None):
|
|||
if i<(M*(N-1)):
|
||||
ax.set_xticks([])
|
||||
else:
|
||||
removeUpperTicks(ax)
|
||||
removeUpperTicks(ax)
|
||||
|
||||
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False):
|
||||
"""
|
||||
Convenience function for returning back fixed_inputs where the other inputs
|
||||
are fixed using fix_routine
|
||||
:param model: model
|
||||
:type model: Model
|
||||
:param non_fixed_inputs: dimensions of non fixed inputs
|
||||
:type non_fixed_inputs: list
|
||||
:param fix_routine: fixing routine to use, 'mean', 'median', 'zero'
|
||||
:type fix_routine: string
|
||||
:param as_list: if true, will return a list of tuples with (dimension, fixed_val) otherwise it will create the corresponding X matrix
|
||||
:type as_list: boolean
|
||||
"""
|
||||
from ...inference.latent_function_inference.posterior import VariationalPosterior
|
||||
f_inputs = []
|
||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
||||
X = model.X.mean.values.copy()
|
||||
elif isinstance(model.X, VariationalPosterior):
|
||||
X = model.X.values.copy()
|
||||
else:
|
||||
if X_all:
|
||||
X = model.X_all.copy()
|
||||
else:
|
||||
X = model.X.copy()
|
||||
for i in range(X.shape[1]):
|
||||
if i not in non_fixed_inputs:
|
||||
if fix_routine == 'mean':
|
||||
f_inputs.append( (i, np.mean(X[:,i])) )
|
||||
if fix_routine == 'median':
|
||||
f_inputs.append( (i, np.median(X[:,i])) )
|
||||
else: # set to zero zero
|
||||
f_inputs.append( (i, 0) )
|
||||
if not as_list:
|
||||
X[:,i] = f_inputs[-1][1]
|
||||
if as_list:
|
||||
return f_inputs
|
||||
else:
|
||||
return X
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ def plot(parameterized, fignum=None, ax=None, colors=None, figsize=(12, 6)):
|
|||
if ax is None:
|
||||
fig = pb.figure(num=fignum, figsize=figsize)
|
||||
if colors is None:
|
||||
colors = pb.gca()._get_lines.color_cycle
|
||||
from ..Tango import mediumList
|
||||
from itertools import cycle
|
||||
colors = cycle(mediumList)
|
||||
pb.clf()
|
||||
else:
|
||||
colors = iter(colors)
|
||||
|
|
@ -64,7 +66,9 @@ def plot_SpikeSlab(parameterized, fignum=None, ax=None, colors=None, side_by_sid
|
|||
else:
|
||||
fig = pb.figure(num=fignum, figsize=(8, min(12, (2 * parameterized.mean.shape[1]))))
|
||||
if colors is None:
|
||||
colors = pb.gca()._get_lines.color_cycle
|
||||
from ..Tango import mediumList
|
||||
from itertools import cycle
|
||||
colors = cycle(mediumList)
|
||||
pb.clf()
|
||||
else:
|
||||
colors = iter(colors)
|
||||
|
|
|
|||
|
|
@ -73,4 +73,4 @@ latent = dict(colorscale='Greys', reversescale=True, zsmooth='best')
|
|||
gradient = dict(colorscale='RdBu', opacity=.7)
|
||||
magnification = dict(colorscale='Greys', zsmooth='best', reversescale=True)
|
||||
latent_scatter = dict(marker_kwargs=dict(size='5', opacity=.7))
|
||||
# annotation = dict(fontdict=dict(family='sans-serif', weight='light', fontsize=9), zorder=.3, alpha=.7)
|
||||
# annotation = dict(fontdict=dict(family='sans-serif', weight='light', fontsize=9), zorder=.3, alpha=.7)
|
||||
|
|
|
|||
|
|
@ -131,14 +131,15 @@ class PlotlyPlots(AbstractPlottingLibrary):
|
|||
#not matplotlib marker
|
||||
pass
|
||||
marker_kwargs = marker_kwargs or {}
|
||||
marker_kwargs.setdefault('symbol', marker)
|
||||
if 'symbol' not in marker_kwargs:
|
||||
marker_kwargs['symbol'] = marker
|
||||
if Z is not None:
|
||||
return Scatter3d(x=X, y=Y, z=Z, mode='markers',
|
||||
showlegend=label is not None,
|
||||
marker=Marker(color=color, colorscale=cmap, **marker_kwargs),
|
||||
name=label, **kwargs)
|
||||
return Scatter(x=X, y=Y, mode='markers', showlegend=label is not None,
|
||||
marker=Marker(color=color, colorscale=cmap, **marker_kwargs or {}),
|
||||
marker=Marker(color=color, colorscale=cmap, **marker_kwargs),
|
||||
name=label, **kwargs)
|
||||
|
||||
def plot(self, ax, X, Y, Z=None, color=None, label=None, line_kwargs=None, **kwargs):
|
||||
|
|
@ -254,7 +255,7 @@ class PlotlyPlots(AbstractPlottingLibrary):
|
|||
font=dict(color='white' if np.abs(var) > 0.8 else 'black', size=10),
|
||||
opacity=.5,
|
||||
showarrow=False,
|
||||
hoverinfo='x'))
|
||||
))
|
||||
return imshow, annotations
|
||||
|
||||
def annotation_heatmap_interact(self, ax, plot_function, extent, label=None, resolution=15, imshow_kwargs=None, **annotation_kwargs):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue