[plotting] cleanup first commit, this cleans the plotting library and adds plotting tests

This commit is contained in:
mzwiessele 2015-10-02 18:26:17 +01:00
parent fee2f3f727
commit b9bfd0fc6d
10 changed files with 200 additions and 690 deletions

View file

@ -477,260 +477,6 @@ class GP(Model):
Ysim = self.likelihood.samples(fsim, Y_metadata=Y_metadata) Ysim = self.likelihood.samples(fsim, Y_metadata=Y_metadata)
return Ysim return Ysim
def plot_f(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=True,
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx',
apply_link=False):
"""
Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
This is a call to plot with plot_raw=True.
Data will not be plotted in this, as the GP's view of the world
may live in another space, or units then the data.
Can plot only part of the data and part of the posterior functions
using which_data_rowsm which_data_ycols.
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_ycols: 'all' or a list of integers
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
:type resolution: int
:param levels: number of levels to plot in a contour plot.
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
:type levels: int
:param samples: the number of a posteriori samples to plot
:type samples: int
:param fignum: figure to plot on.
:type fignum: figure number
:param ax: axes to plot on.
:type ax: axes handle
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
:type linecol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
:type fillcol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param Y_metadata: additional data associated with Y which may be needed
:type Y_metadata: dict
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*
:type apply_link: boolean
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots
kw = {}
if linecol is not None:
kw['linecol'] = linecol
if fillcol is not None:
kw['fillcol'] = fillcol
return models_plots.plot_fit(self, plot_limits, which_data_rows,
which_data_ycols, fixed_inputs,
levels, samples, fignum, ax, resolution,
plot_raw=plot_raw, Y_metadata=Y_metadata,
data_symbol=data_symbol, apply_link=apply_link, **kw)
def plot(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False, linecol=None,fillcol=None, Y_metadata=None,
data_symbol='kx', predict_kw=None, plot_training_data=True, samples_y=0, apply_link=False):
"""
Plot the posterior of the GP.
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
- In two dimsensions, a contour-plot shows the mean predicted function
- In higher dimensions, use fixed_inputs to plot the GP with some of the inputs fixed.
Can plot only part of the data and part of the posterior functions
using which_data_rowsm which_data_ycols.
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_ycols: 'all' or a list of integers
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
:type resolution: int
:param levels: number of levels to plot in a contour plot.
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
:type levels: int
:param samples: the number of a posteriori samples to plot, p(f*|y)
:type samples: int
:param fignum: figure to plot on.
:type fignum: figure number
:param ax: axes to plot on.
:type ax: axes handle
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
:type linecol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
:type fillcol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param Y_metadata: additional data associated with Y which may be needed
:type Y_metadata: dict
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
:param plot_training_data: whether or not to plot the training points
:type plot_training_data: boolean
:param samples_y: the number of a posteriori samples to plot, p(y*|y)
:type samples_y: int
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
:type apply_link: boolean
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots
kw = {}
if linecol is not None:
kw['linecol'] = linecol
if fillcol is not None:
kw['fillcol'] = fillcol
return models_plots.plot_fit(self, plot_limits, which_data_rows,
which_data_ycols, fixed_inputs,
levels, samples, fignum, ax, resolution,
plot_raw=plot_raw, Y_metadata=Y_metadata,
data_symbol=data_symbol, predict_kw=predict_kw,
plot_training_data=plot_training_data, samples_y=samples_y, apply_link=apply_link, **kw)
def plot_density(self, levels=20, plot_limits=None, fignum=None, ax=None,
fixed_inputs=[], plot_raw=False, edgecolor='none', facecolor='#3465a4',
predict_kw=None,Y_metadata=None,
apply_link=False, resolution=200, **patch_kw):
"""
Plot the posterior density of the GP.
- In one dimension, the function is plotted with a shaded gradient, visualizing the density of the posterior.
- Only implemented for one dimension, for higher dimensions use `plot`.
:param levels: number of levels to plot in the density plot. This is a number between 1 and 100. 1 corresponds to the normal plot_fit.
:type levels: int
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
:type resolution: int
:param edgecolor: color of line to plot [Tango.colorsHex['darkBlue']]
:type edgecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param facecolor: color of fill [Tango.colorsHex['lightBlue']]
:type facecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param Y_metadata: additional data associated with Y which may be needed
:type Y_metadata: dict
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
:type apply_link: boolean
:param resolution: resolution of interpolation (how many points to interpolate of the posterior).
:type resolution: int
:param: patch_kw: the keyword arguments for the patchcollection fill.
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots
return models_plots.plot_density(self, levels, plot_limits, fignum, ax,
fixed_inputs, plot_raw=plot_raw,
Y_metadata=Y_metadata,
predict_kw=predict_kw,
apply_link=apply_link,
edgecolor=edgecolor, facecolor=facecolor,
**patch_kw)
def plot_data(self, which_data_rows='all',
which_data_ycols='all', visible_dims=None,
fignum=None, ax=None, data_symbol='kx'):
"""
Plot the training data
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
Can plot only part of the data
using which_data_rows and which_data_ycols.
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_ycols: 'all' or a list of integers
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
:type visible_dims: a numpy array
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
:type resolution: int
:param levels: number of levels to plot in a contour plot.
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
:type levels: int
:param samples: the number of a posteriori samples to plot, p(f*|y)
:type samples: int
:param fignum: figure to plot on.
:type fignum: figure number
:param ax: axes to plot on.
:type ax: axes handle
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
:type linecol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
:type fillcol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots
kw = {}
return models_plots.plot_data(self, which_data_rows,
which_data_ycols, visible_dims,
fignum, ax, data_symbol, **kw)
def plot_errorbars_trainset(self, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[], fignum=None, ax=None,
linecol=None, data_symbol='kx', predict_kw=None, plot_training_data=True,lw=None):
"""
Plot the posterior error bars corresponding to the training data
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
Can plot only part of the data
using which_data_rows and which_data_ycols.
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param fignum: figure to plot on.
:type fignum: figure number
:param ax: axes to plot on.
:type ax: axes handle
:param plot_training_data: whether or not to plot the training points
:type plot_training_data: boolean
"""
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import models_plots
kw = {}
if lw is not None:
kw['lw'] = lw
return models_plots.plot_errorbars_trainset(self, which_data_rows, which_data_ycols, fixed_inputs,
fignum, ax, linecol, data_symbol,
predict_kw, plot_training_data, **kw)
def plot_magnification(self, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40,
fignum=None, legend=True,
plot_limits=None,
aspect='auto', updates=False, plot_inducing=True, kern=None, **kwargs):
import sys
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
from ..plotting.matplot_dep import dim_reduction_plots
return dim_reduction_plots.plot_magnification(self, labels, which_indices,
resolution, ax, marker, s,
fignum, plot_inducing, legend,
plot_limits, aspect, updates, **kwargs)
def input_sensitivity(self, summarize=True): def input_sensitivity(self, summarize=True):
""" """
Returns the sensitivity for each dimension of this model Returns the sensitivity for each dimension of this model

View file

@ -28,3 +28,6 @@ working = True
[cython] [cython]
working = True working = True
[plotting]
library = matplotlib

View file

@ -13,7 +13,7 @@ from functools import wraps
def put_clean(dct, name, func): def put_clean(dct, name, func):
if name in dct: if name in dct:
dct['_clean_{}'.format(name)] = dct[name] #dct['_clean_{}'.format(name)] = dct[name]
dct[name] = func(dct[name]) dct[name] = func(dct[name])
class KernCallsViaSlicerMeta(ParametersChangedMeta): class KernCallsViaSlicerMeta(ParametersChangedMeta):

View file

@ -2,11 +2,30 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
try: try:
import matplotlib from ..util.config import config
from . import matplot_dep lib = config.get('plotting', 'library')
if lib == 'matplotlib':
import matplotlib
from . import matplot_dep as plotting_library
except (ImportError, NameError): except (ImportError, NameError):
# Matplotlib not available
import warnings import warnings
warnings.warn(ImportWarning("Matplotlib not available, install newest version of Matplotlib for plotting")) warnings.warn(ImportWarning("{} not available, install newest version of {} for plotting").format(lib, lib))
#sys.modules['matplotlib'] = config.set('plotting', 'library', 'none')
#sys.modules[__name__+'.matplot_dep'] = ImportWarning("Matplotlib not available, install newest version of Matplotlib for plotting")
if config.get('plotting', 'library') is not 'none':
# Inject the plots into classes here:
# Already converted to new style:
from . import gpy_plot
from ..core import GP
GP.plot_data = gpy_plot.data_plots.plot_data
# Still to convert to new style:
GP.plot = plotting_library.models_plots.plot_fit
GP.plot_f = plotting_library.models_plots.plot_fit_f
GP.plot_density = plotting_library.models_plots.plot_density
GP.plot_errorbars_trainset = plotting_library.models_plots.plot_errorbars_trainset
GP.plot_magnification = plotting_library.dim_reduction_plots.plot_magnification

View file

@ -1,6 +1,54 @@
# Copyright (c) 2014, GPy authors (see AUTHORS.txt). # Copyright (c) 2014, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
from matplotlib import pyplot as plt
from . import defaults
def get_new_canvas(kwargs):
"""
Return a canvas, kwargupdate for matplotlib. This just a
dictionary for the collection and we add the an axis to kwarg.
This method does two things, it creates an empty canvas
and updates the kwargs (deletes the unnecessary kwargs)
for further usage in normal plotting.
in matplotlib this means it deletes references to ax, as
plotting is done on the axis itself and is not a kwarg.
"""
if 'ax' in kwargs:
ax = kwargs.pop('ax')
elif 'num' in kwargs and 'figsize' in kwargs:
ax = plt.figure(num=kwargs.pop('num'), figsize=kwargs.pop('figsize')).add_subplot(111)
elif 'num' in kwargs:
ax = plt.figure(num=kwargs.pop('num')).add_subplot(111)
elif 'figsize' in kwargs:
ax = plt.figure(figsize=kwargs.pop('figsize')).add_subplot(111)
else:
ax = plt.figure().add_subplot(111)
# Add ax to kwargs to add all subsequent plots to this axis:
#kwargs['ax'] = ax
return ax, kwargs
def show_canvas(canvas):
try:
canvas.figure.canvas.draw()
canvas.figure.tight_layout()
except:
pass
return canvas
def scatter(ax, *args, **kwargs):
ax.scatter(*args, **kwargs)
def plot(ax, *args, **kwargs):
ax.plot(*args, **kwargs)
def imshow(ax, *args, **kwargs):
ax.imshow(*args, **kwargs)
from . import base_plots from . import base_plots
from . import models_plots from . import models_plots
from . import priors_plots from . import priors_plots
@ -11,8 +59,9 @@ from . import mapping_plots
from . import Tango from . import Tango
from . import visualize from . import visualize
from . import latent_space_visualizations from . import latent_space_visualizations
from . import netpbmfile
from . import inference_plots from . import inference_plots
from . import maps from . import maps
from . import img_plots from . import img_plots
from .ssgplvm import SSGPLVM_plot from .ssgplvm import SSGPLVM_plot

View file

@ -1,11 +1,11 @@
# #Copyright (c) 2012, GPy authors (see AUTHORS.txt). # #Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
from matplotlib import pyplot as pb from matplotlib import pyplot as plt
import numpy as np import numpy as np
def ax_default(fignum, ax): def ax_default(fignum, ax):
if ax is None: if ax is None:
fig = pb.figure(fignum) fig = plt.figure(fignum)
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
else: else:
fig = ax.figure fig = ax.figure
@ -35,12 +35,14 @@ def gpplot(x, mu, lower, upper, edgecol='#3300FF', fillcol='#33CCFF', ax=None, f
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)) plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
#this is the edge: #this is the edge:
plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,ax=axes)) plots.append(meanplot(x, upper,color=edgecol, linewidth=0.2, ax=axes))
plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,ax=axes)) plots.append(meanplot(x, lower,color=edgecol, linewidth=0.2, ax=axes))
return plots return plots
def plot_gradient_fill(ax, x, percentiles, **kwargs): def gradient_fill(x, percentiles, ax=None, fignum=None, **kwargs):
_, ax = ax_default(fignum, ax)
plots = [] plots = []
#here's the box #here's the box
@ -150,19 +152,19 @@ def gperrors(x, mu, lower, upper, edgecol=None, ax=None, fignum=None, **kwargs):
def removeRightTicks(ax=None): def removeRightTicks(ax=None):
ax = ax or pb.gca() ax = ax or plt.gca()
for i, line in enumerate(ax.get_yticklines()): for i, line in enumerate(ax.get_yticklines()):
if i%2 == 1: # odd indices if i%2 == 1: # odd indices
line.set_visible(False) line.set_visible(False)
def removeUpperTicks(ax=None): def removeUpperTicks(ax=None):
ax = ax or pb.gca() ax = ax or plt.gca()
for i, line in enumerate(ax.get_xticklines()): for i, line in enumerate(ax.get_xticklines()):
if i%2 == 1: # odd indices if i%2 == 1: # odd indices
line.set_visible(False) line.set_visible(False)
def fewerXticks(ax=None,divideby=2): def fewerXticks(ax=None,divideby=2):
ax = ax or pb.gca() ax = ax or plt.gca()
ax.set_xticks(ax.get_xticks()[::divideby]) ax.set_xticks(ax.get_xticks()[::divideby])
def align_subplots(N,M,xlim=None, ylim=None): def align_subplots(N,M,xlim=None, ylim=None):
@ -171,33 +173,33 @@ def align_subplots(N,M,xlim=None, ylim=None):
if xlim is None: if xlim is None:
xlim = [np.inf,-np.inf] xlim = [np.inf,-np.inf]
for i in range(N*M): for i in range(N*M):
pb.subplot(N,M,i+1) plt.subplot(N,M,i+1)
xlim[0] = min(xlim[0],pb.xlim()[0]) xlim[0] = min(xlim[0],plt.xlim()[0])
xlim[1] = max(xlim[1],pb.xlim()[1]) xlim[1] = max(xlim[1],plt.xlim()[1])
if ylim is None: if ylim is None:
ylim = [np.inf,-np.inf] ylim = [np.inf,-np.inf]
for i in range(N*M): for i in range(N*M):
pb.subplot(N,M,i+1) plt.subplot(N,M,i+1)
ylim[0] = min(ylim[0],pb.ylim()[0]) ylim[0] = min(ylim[0],plt.ylim()[0])
ylim[1] = max(ylim[1],pb.ylim()[1]) ylim[1] = max(ylim[1],plt.ylim()[1])
for i in range(N*M): for i in range(N*M):
pb.subplot(N,M,i+1) plt.subplot(N,M,i+1)
pb.xlim(xlim) plt.xlim(xlim)
pb.ylim(ylim) plt.ylim(ylim)
if (i)%M: if (i)%M:
pb.yticks([]) plt.yticks([])
else: else:
removeRightTicks() removeRightTicks()
if i<(M*(N-1)): if i<(M*(N-1)):
pb.xticks([]) plt.xticks([])
else: else:
removeUpperTicks() removeUpperTicks()
def align_subplot_array(axes,xlim=None, ylim=None): def align_subplot_array(axes,xlim=None, ylim=None):
""" """
Make all of the axes in the array hae the same limits, turn off unnecessary ticks Make all of the axes in the array hae the same limits, turn off unnecessary ticks
use pb.subplots() to get an array of axes use plt.subplots() to get an array of axes
""" """
#find sensible xlim,ylim #find sensible xlim,ylim
if xlim is None: if xlim is None:

View file

@ -9,12 +9,13 @@ from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalized
from scipy import sparse from scipy import sparse
from ...core.parameterization.variational import VariationalPosterior from ...core.parameterization.variational import VariationalPosterior
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from .base_plots import plot_gradient_fill from .base_plots import gradient_fill
from functools import wraps
def plot_data(model, which_data_rows='all', def plot_data(self, which_data_rows='all',
which_data_ycols='all', visible_dims=None, which_data_ycols='all', visible_dims=None,
fignum=None, ax=None, data_symbol='kx',mew=1.5): fignum=None, ax=None, data_symbol='kx',mew=1.5,**kwargs):
""" """
Plot the training data Plot the training data
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed. - For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
@ -23,7 +24,7 @@ def plot_data(model, which_data_rows='all',
using which_data_rows and which_data_ycols. using which_data_rows and which_data_ycols.
:param which_data_rows: which of the training data to plot (default all) :param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y :type which_data_rows: 'all' or a slice object to slice self.X, self.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these :param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers :type which_data_rows: 'all' or a list of integers
:param visible_dims: an array specifying the input dimensions to plot (maximum two) :param visible_dims: an array specifying the input dimensions to plot (maximum two)
@ -37,23 +38,23 @@ def plot_data(model, which_data_rows='all',
if which_data_rows == 'all': if which_data_rows == 'all':
which_data_rows = slice(None) which_data_rows = slice(None)
if which_data_ycols == 'all': if which_data_ycols == 'all':
which_data_ycols = np.arange(model.output_dim) which_data_ycols = np.arange(self.output_dim)
if ax is None: if ax is None:
fig = plt.figure(num=fignum) fig = plt.figure(num=fignum)
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs(): if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
X = model.X.mean X = self.X.mean
X_variance = model.X.variance X_variance = self.X.variance
else: else:
X = model.X X = self.X
X_variance = None X_variance = None
Y = model.Y Y = self.Y
#work out what the inputs are for plotting (1D or 2D) #work out what the inputs are for plotting (1D or 2D)
if visible_dims is None: if visible_dims is None:
visible_dims = np.arange(model.input_dim) visible_dims = np.arange(self.input_dim)
assert visible_dims.size <= 2, "Visible inputs cannot be larger than two" assert visible_dims.size <= 2, "Visible inputs cannot be larger than two"
free_dims = visible_dims free_dims = visible_dims
plots = {} plots = {}
@ -80,7 +81,7 @@ def plot_data(model, which_data_rows='all',
return plots return plots
def plot_fit(model, plot_limits=None, which_data_rows='all', def plot_fit(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[], which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None, levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False, plot_raw=False,
@ -98,7 +99,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits :param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array :type plot_limits: np.array
:param which_data_rows: which of the training data to plot (default all) :param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y :type which_data_rows: 'all' or a slice object to slice self.X, self.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these :param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers :type which_data_rows: 'all' or a list of integers
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v. :param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
@ -134,98 +135,98 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
if which_data_rows == 'all': if which_data_rows == 'all':
which_data_rows = slice(None) which_data_rows = slice(None)
if which_data_ycols == 'all': if which_data_ycols == 'all':
which_data_ycols = np.arange(model.output_dim) which_data_ycols = np.arange(self.output_dim)
#if len(which_data_ycols)==0: #if len(which_data_ycols)==0:
#raise ValueError('No data selected for plotting') #raise ValueError('No data selected for plotting')
if ax is None: if ax is None:
fig = plt.figure(num=fignum) fig = plt.figure(num=fignum)
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs(): if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
X = model.X.mean X = self.X.mean
X_variance = model.X.variance X_variance = self.X.variance
else: else:
X = model.X X = self.X
Y = model.Y Y = self.Y
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray) if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
if hasattr(model, 'Z'): Z = model.Z if hasattr(self, 'Z'): Z = self.Z
if predict_kw is None: if predict_kw is None:
predict_kw = {} predict_kw = {}
#work out what the inputs are for plotting (1D or 2D) #work out what the inputs are for plotting (1D or 2D)
fixed_dims = np.array([i for i,v in fixed_inputs]) fixed_dims = np.array([i for i,v in fixed_inputs])
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims) free_dims = np.setdiff1d(np.arange(self.input_dim),fixed_dims)
plots = {} plots = {}
#one dimensional plotting #one dimensional plotting
if len(free_dims) == 1: if len(free_dims) == 1:
#define the frame on which to plot #define the frame on which to plot
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution or 200) Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution or 200)
Xgrid = np.empty((Xnew.shape[0],model.input_dim)) Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid[:,free_dims] = Xnew Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs: for i,v in fixed_inputs:
Xgrid[:,i] = v Xgrid[:,i] = v
#make a prediction on the frame and plot it #make a prediction on the frame and plot it
if plot_raw: if plot_raw:
m, v = model._raw_predict(Xgrid, **predict_kw) m, v = self._raw_predict(Xgrid, **predict_kw)
if apply_link: if apply_link:
lower = model.likelihood.gp_link.transf(m - 2*np.sqrt(v)) lower = self.likelihood.gp_link.transf(m - 2*np.sqrt(v))
upper = model.likelihood.gp_link.transf(m + 2*np.sqrt(v)) upper = self.likelihood.gp_link.transf(m + 2*np.sqrt(v))
#Once transformed this is now the median of the function #Once transformed this is now the median of the function
m = model.likelihood.gp_link.transf(m) m = self.likelihood.gp_link.transf(m)
else: else:
lower = m - 2*np.sqrt(v) lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v) upper = m + 2*np.sqrt(v)
else: else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression): if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
extra_data = Xgrid[:,-1:].astype(np.int) extra_data = Xgrid[:,-1:].astype(np.int)
if Y_metadata is None: if Y_metadata is None:
Y_metadata = {'output_index': extra_data} Y_metadata = {'output_index': extra_data}
else: else:
Y_metadata['output_index'] = extra_data Y_metadata['output_index'] = extra_data
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw) m, v = self.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
fmu, fv = model._raw_predict(Xgrid, full_cov=False, **predict_kw) fmu, fv = self._raw_predict(Xgrid, full_cov=False, **predict_kw)
lower, upper = model.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=Y_metadata) lower, upper = self.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=Y_metadata)
for d in which_data_ycols: for d in which_data_ycols:
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol) plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
#if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5) #if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
if not plot_raw and plot_training_data: if not plot_raw and plot_training_data:
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows, plots['dataplot'] = plot_data(self=self, which_data_rows=which_data_rows,
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum) visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
#optionally plot some samples #optionally plot some samples
if samples: #NOTE not tested with fixed_inputs if samples: #NOTE not tested with fixed_inputs
Fsim = model.posterior_samples_f(Xgrid, samples) Fsim = self.posterior_samples_f(Xgrid, samples)
if apply_link: if apply_link:
Fsim = model.likelihood.gp_link.transf(Fsim) Fsim = self.likelihood.gp_link.transf(Fsim)
for fi in Fsim.T: for fi in Fsim.T:
plots['posterior_samples'] = ax.plot(Xnew, fi[:,None], '#3300FF', linewidth=0.25) plots['posterior_samples'] = ax.plot(Xnew, fi[:,None], '#3300FF', linewidth=0.25)
#ax.plot(Xnew, fi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs. #ax.plot(Xnew, fi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
if samples_y: #NOTE not tested with fixed_inputs if samples_y: #NOTE not tested with fixed_inputs
Ysim = model.posterior_samples(Xgrid, samples_y, Y_metadata=Y_metadata) Ysim = self.posterior_samples(Xgrid, samples_y, Y_metadata=Y_metadata)
for yi in Ysim.T: for yi in Ysim.T:
plots['posterior_samples_y'] = ax.scatter(Xnew, yi[:,None], s=5, c=Tango.colorsHex['darkBlue'], marker='o', alpha=0.5) plots['posterior_samples_y'] = ax.scatter(Xnew, yi[:,None], s=5, c=Tango.colorsHex['darkBlue'], marker='o', alpha=0.5)
#ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs. #ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
#add error bars for uncertain (if input uncertainty is being modelled) #add error bars for uncertain (if input uncertainty is being modelled)
if hasattr(model,"has_uncertain_inputs") and model.has_uncertain_inputs() and plot_uncertain_inputs: if hasattr(self,"has_uncertain_inputs") and self.has_uncertain_inputs() and plot_uncertain_inputs:
if plot_raw: if plot_raw:
#add error bars for uncertain (if input uncertainty is being modelled), for plot_f #add error bars for uncertain (if input uncertainty is being modelled), for plot_f
#Hack to plot error bars on latent function, rather than on the data #Hack to plot error bars on latent function, rather than on the data
vs = model.X.mean.values.copy() vs = self.X.mean.values.copy()
for i,v in fixed_inputs: for i,v in fixed_inputs:
vs[:,i] = v vs[:,i] = v
m_X, _ = model._raw_predict(vs) m_X, _ = self._raw_predict(vs)
if apply_link: if apply_link:
m_X = model.likelihood.gp_link.transf(m_X) m_X = self.likelihood.gp_link.transf(m_X)
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(), plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()), xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5) ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
@ -243,9 +244,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
pass pass
#add inducing inputs (if a sparse model is used) #add inducing inputs (if a sparse model is used)
if hasattr(model,"Z"): if hasattr(self,"Z"):
#Zu = model.Z[:,free_dims] * model._Xscale[:,free_dims] + model._Xoffset[:,free_dims] #Zu = self.Z[:,free_dims] * self._Xscale[:,free_dims] + self._Xoffset[:,free_dims]
if isinstance(model,SparseGPCoregionalizedRegression): if isinstance(self,SparseGPCoregionalizedRegression):
Z = Z[Z[:,-1] == Y_metadata['output_index'],:] Z = Z[Z[:,-1] == Y_metadata['output_index'],:]
Zu = Z[:,free_dims] Zu = Z[:,free_dims]
z_height = ax.get_ylim()[0] z_height = ax.get_ylim()[0]
@ -259,7 +260,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#define the frame for plotting on #define the frame for plotting on
resolution = resolution or 50 resolution = resolution or 50
Xnew, _, _, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution) Xnew, _, _, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)
Xgrid = np.empty((Xnew.shape[0],model.input_dim)) Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid[:,free_dims] = Xnew Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs: for i,v in fixed_inputs:
Xgrid[:,i] = v Xgrid[:,i] = v
@ -267,15 +268,15 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#predict on the frame and plot #predict on the frame and plot
if plot_raw: if plot_raw:
m, _ = model._raw_predict(Xgrid, **predict_kw) m, _ = self._raw_predict(Xgrid, **predict_kw)
else: else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression): if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
extra_data = Xgrid[:,-1:].astype(np.int) extra_data = Xgrid[:,-1:].astype(np.int)
if Y_metadata is None: if Y_metadata is None:
Y_metadata = {'output_index': extra_data} Y_metadata = {'output_index': extra_data}
else: else:
Y_metadata['output_index'] = extra_data Y_metadata['output_index'] = extra_data
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw) m, v = self.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
for d in which_data_ycols: for d in which_data_ycols:
m_d = m[:,d].reshape(resolution, resolution).T m_d = m[:,d].reshape(resolution, resolution).T
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet) plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet)
@ -290,9 +291,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
if samples: if samples:
warnings.warn("Samples are rather difficult to plot for 2D inputs...") warnings.warn("Samples are rather difficult to plot for 2D inputs...")
#add inducing inputs (if a sparse model is used) #add inducing inputs (if a sparse self is used)
if hasattr(model,"Z"): if hasattr(self,"Z"):
#Zu = model.Z[:,free_dims] * model._Xscale[:,free_dims] + model._Xoffset[:,free_dims] #Zu = self.Z[:,free_dims] * self._Xscale[:,free_dims] + self._Xoffset[:,free_dims]
Zu = Z[:,free_dims] Zu = Z[:,free_dims]
plots['inducing_inputs'] = ax.plot(Zu[:,0], Zu[:,1], 'wo') plots['inducing_inputs'] = ax.plot(Zu[:,0], Zu[:,1], 'wo')
@ -300,20 +301,41 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
raise NotImplementedError("Cannot define a frame with more than two input dimensions") raise NotImplementedError("Cannot define a frame with more than two input dimensions")
return plots return plots
def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None, def plot_density(self, levels=20, plot_limits=None,
fixed_inputs=[], plot_raw=False, edgecolor='none', facecolor='#3465a4', fixed_inputs=[], plot_raw=False, edgecolor='none', facecolor='#3465a4',
predict_kw=None,Y_metadata=None, predict_kw=None,Y_metadata=None,
apply_link=False, resolution=200, **patch_kwargs): apply_link=False, resolution=200, **patch_kwargs):
#deal with optional arguments """
if ax is None: Plot the posterior density of the GP.
fig = plt.figure(num=fignum) - In one dimension, the function is plotted with a shaded gradient, visualizing the density of the posterior.
ax = fig.add_subplot(111) - Only implemented for one dimension, for higher dimensions use `plot`.
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs(): :param levels: number of levels to plot in the density plot. This is a number between 1 and 100. 1 corresponds to the normal plot_fit.
X = model.X.mean :type levels: int
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
:type resolution: int
:param edgecolor: color of line to plot [Tango.colorsHex['darkBlue']]
:type edgecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param facecolor: color of fill [Tango.colorsHex['lightBlue']]
:type facecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param Y_metadata: additional data associated with Y which may be needed
:type Y_metadata: dict
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
:type apply_link: boolean
:param resolution: resolution of interpolation (how many points to interpolate of the posterior).
:type resolution: int
:param: patch_kw: the keyword arguments for the patchcollection fill.
"""
#deal with optional arguments
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
X = self.X.mean
else: else:
X = model.X X = self.X
Y = model.Y Y = self.Y
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray) if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
if predict_kw is None: if predict_kw is None:
@ -321,13 +343,13 @@ def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None,
#work out what the inputs are for plotting (1D or 2D) #work out what the inputs are for plotting (1D or 2D)
fixed_dims = np.array([i for i,v in fixed_inputs]) fixed_dims = np.array([i for i,v in fixed_inputs])
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims) free_dims = np.setdiff1d(np.arange(self.input_dim),fixed_dims)
plots = {} plots = {}
#one dimensional plotting #one dimensional plotting
if len(free_dims) == 1: if len(free_dims) == 1:
#define the frame on which to plot #define the frame on which to plot
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution) Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution)
Xgrid = np.empty((Xnew.shape[0],model.input_dim)) Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid[:,free_dims] = Xnew Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs: for i,v in fixed_inputs:
Xgrid[:,i] = v Xgrid[:,i] = v
@ -340,32 +362,32 @@ def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None,
from ...likelihoods import Gaussian from ...likelihoods import Gaussian
lik = Gaussian(variance=0) lik = Gaussian(variance=0)
else: else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression): if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
extra_data = Xgrid[:,-1:].astype(np.int) extra_data = Xgrid[:,-1:].astype(np.int)
if Y_metadata is None: if Y_metadata is None:
Y_metadata = {'output_index': extra_data} Y_metadata = {'output_index': extra_data}
else: else:
Y_metadata['output_index'] = extra_data Y_metadata['output_index'] = extra_data
lik = None lik = None
percentiles = [i[:, 0] for i in model.predict_quantiles(Xgrid, percs, Y_metadata=Y_metadata, likelihood=lik, **predict_kw)] percentiles = [i[:, 0] for i in self.predict_quantiles(Xgrid, percs, Y_metadata=Y_metadata, likelihood=lik, **predict_kw)]
if apply_link: if apply_link:
percentiles = model.likelihood.gp_link.transf(percentiles) percentiles = self.likelihood.gp_link.transf(percentiles)
patch_kwargs['facecolor'] = facecolor patch_kwargs['facecolor'] = facecolor
patch_kwargs['edgecolor'] = edgecolor patch_kwargs['edgecolor'] = edgecolor
plots['density'] = plot_gradient_fill(ax, Xgrid[:, 0], percentiles, **patch_kwargs) plots['density'] = gradient_fill(Xgrid[:, 0], percentiles, **patch_kwargs)
else: else:
raise NotImplementedError('Only 1D density plottable.') raise NotImplementedError('Only 1D density plottable.')
return plots return plots
def plot_fit_f(model, *args, **kwargs): @wraps(plot_fit)
""" def plot_fit_f(self, plot_limits=None, which_data_rows='all',
Plot the GP's view of the world, where the data is normalized and before applying a likelihood. which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
All args and kwargs are passed on to models_plots.plot. plot_raw=True,
""" linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
kwargs['plot_raw'] = True apply_link=False, samples_y=0, plot_uncertain_inputs=True, predict_kw=None, plot_training_data=True):
plot_fit(model,*args, **kwargs) return plot_fit(self, plot_limits, which_data_rows, which_data_ycols, fixed_inputs, levels, samples, fignum, ax, resolution, plot_raw, linecol, fillcol, Y_metadata, data_symbol, apply_link, samples_y, plot_uncertain_inputs, predict_kw, plot_training_data)
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False): def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False):
""" """
@ -465,7 +487,7 @@ def plot_errorbars_trainset(model, which_data_rows='all',
for d in which_data_ycols: for d in which_data_ycols:
plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum, **kwargs ) plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum, **kwargs )
if plot_training_data: if plot_training_data:
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows, plots['dataplot'] = plot_data(self=model, which_data_rows=which_data_rows,
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum) visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)

View file

@ -1,331 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# netpbmfile.py
# Copyright (c) 2011-2013, Christoph Gohlke
# Copyright (c) 2011-2013, The Regents of the University of California
# Produced at the Laboratory for Fluorescence Dynamics.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the copyright holders nor the names of any
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""Read and write image data from respectively to Netpbm files.
This implementation follows the Netpbm format specifications at
http://netpbm.sourceforge.net/doc/. No gamma correction is performed.
The following image formats are supported: PBM (bi-level), PGM (grayscale),
PPM (color), PAM (arbitrary), XV thumbnail (RGB332, read-only).
:Author:
`Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`_
:Organization:
Laboratory for Fluorescence Dynamics, University of California, Irvine
:Version: 2013.01.18
Requirements
------------
* `CPython 2.7, 3.2 or 3.3 <http://www.python.org>`_
* `Numpy 1.7 <http://www.numpy.org>`_
* `Matplotlib 1.2 <http://www.matplotlib.org>`_ (optional for plotting)
Examples
--------
>>> im1 = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
>>> imsave('_tmp.pgm', im1)
>>> im2 = imread('_tmp.pgm')
>>> assert numpy.all(im1 == im2)
"""
from __future__ import division, print_function
import sys
import re
import math
from copy import deepcopy
import numpy
__version__ = '2013.01.18'
__docformat__ = 'restructuredtext en'
__all__ = ['imread', 'imsave', 'NetpbmFile']
def imread(filename, *args, **kwargs):
"""Return image data from Netpbm file as numpy array.
`args` and `kwargs` are arguments to NetpbmFile.asarray().
Examples
--------
>>> image = imread('_tmp.pgm')
"""
try:
netpbm = NetpbmFile(filename)
image = netpbm.asarray()
finally:
netpbm.close()
return image
def imsave(filename, data, maxval=None, pam=False):
"""Write image data to Netpbm file.
Examples
--------
>>> image = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
>>> imsave('_tmp.pgm', image)
"""
try:
netpbm = NetpbmFile(data, maxval=maxval)
netpbm.write(filename, pam=pam)
finally:
netpbm.close()
class NetpbmFile(object):
"""Read and write Netpbm PAM, PBM, PGM, PPM, files."""
_types = {b'P1': b'BLACKANDWHITE', b'P2': b'GRAYSCALE', b'P3': b'RGB',
b'P4': b'BLACKANDWHITE', b'P5': b'GRAYSCALE', b'P6': b'RGB',
b'P7 332': b'RGB', b'P7': b'RGB_ALPHA'}
def __init__(self, arg=None, **kwargs):
"""Initialize instance from filename, open file, or numpy array."""
for attr in ('header', 'magicnum', 'width', 'height', 'maxval',
'depth', 'tupltypes', '_filename', '_fh', '_data'):
setattr(self, attr, None)
if arg is None:
self._fromdata([], **kwargs)
elif isinstance(arg, basestring):
self._fh = open(arg, 'rb')
self._filename = arg
self._fromfile(self._fh, **kwargs)
elif hasattr(arg, 'seek'):
self._fromfile(arg, **kwargs)
self._fh = arg
else:
self._fromdata(arg, **kwargs)
def asarray(self, copy=True, cache=False, **kwargs):
"""Return image data from file as numpy array."""
data = self._data
if data is None:
data = self._read_data(self._fh, **kwargs)
if cache:
self._data = data
else:
return data
return deepcopy(data) if copy else data
def write(self, arg, **kwargs):
"""Write instance to file."""
if hasattr(arg, 'seek'):
self._tofile(arg, **kwargs)
else:
with open(arg, 'wb') as fid:
self._tofile(fid, **kwargs)
def close(self):
"""Close open file. Future asarray calls might fail."""
if self._filename and self._fh:
self._fh.close()
self._fh = None
def __del__(self):
self.close()
def _fromfile(self, fh):
"""Initialize instance from open file."""
fh.seek(0)
data = fh.read(4096)
if (len(data) < 7) or not (b'0' < data[1:2] < b'8'):
raise ValueError("Not a Netpbm file:\n%s" % data[:32])
try:
self._read_pam_header(data)
except Exception:
try:
self._read_pnm_header(data)
except Exception:
raise ValueError("Not a Netpbm file:\n%s" % data[:32])
def _read_pam_header(self, data):
"""Read PAM header and initialize instance."""
regroups = re.search(
b"(^P7[\n\r]+(?:(?:[\n\r]+)|(?:#.*)|"
b"(HEIGHT\s+\d+)|(WIDTH\s+\d+)|(DEPTH\s+\d+)|(MAXVAL\s+\d+)|"
b"(?:TUPLTYPE\s+\w+))*ENDHDR\n)", data).groups()
self.header = regroups[0]
self.magicnum = b'P7'
for group in regroups[1:]:
key, value = group.split()
setattr(self, unicode(key).lower(), int(value))
matches = re.findall(b"(TUPLTYPE\s+\w+)", self.header)
self.tupltypes = [s.split(None, 1)[1] for s in matches]
def _read_pnm_header(self, data):
"""Read PNM header and initialize instance."""
bpm = data[1:2] in b"14"
regroups = re.search(b"".join((
b"(^(P[123456]|P7 332)\s+(?:#.*[\r\n])*",
b"\s*(\d+)\s+(?:#.*[\r\n])*",
b"\s*(\d+)\s+(?:#.*[\r\n])*" * (not bpm),
b"\s*(\d+)\s(?:\s*#.*[\r\n]\s)*)")), data).groups() + (1, ) * bpm
self.header = regroups[0]
self.magicnum = regroups[1]
self.width = int(regroups[2])
self.height = int(regroups[3])
self.maxval = int(regroups[4])
self.depth = 3 if self.magicnum in b"P3P6P7 332" else 1
self.tupltypes = [self._types[self.magicnum]]
def _read_data(self, fh, byteorder='>'):
"""Return image data from open file as numpy array."""
fh.seek(len(self.header))
data = fh.read()
dtype = 'u1' if self.maxval < 256 else byteorder + 'u2'
depth = 1 if self.magicnum == b"P7 332" else self.depth
shape = [-1, self.height, self.width, depth]
size = numpy.prod(shape[1:])
if self.magicnum in b"P1P2P3":
data = numpy.array(data.split(None, size)[:size], dtype)
data = data.reshape(shape)
elif self.maxval == 1:
shape[2] = int(math.ceil(self.width / 8))
data = numpy.frombuffer(data, dtype).reshape(shape)
data = numpy.unpackbits(data, axis=-2)[:, :, :self.width, :]
else:
data = numpy.frombuffer(data, dtype)
data = data[:size * (data.size // size)].reshape(shape)
if data.shape[0] < 2:
data = data.reshape(data.shape[1:])
if data.shape[-1] < 2:
data = data.reshape(data.shape[:-1])
if self.magicnum == b"P7 332":
rgb332 = numpy.array(list(numpy.ndindex(8, 8, 4)), numpy.uint8)
rgb332 *= [36, 36, 85]
data = numpy.take(rgb332, data, axis=0)
return data
def _fromdata(self, data, maxval=None):
"""Initialize instance from numpy array."""
data = numpy.array(data, ndmin=2, copy=True)
if data.dtype.kind not in "uib":
raise ValueError("not an integer type: %s" % data.dtype)
if data.dtype.kind == 'i' and numpy.min(data) < 0:
raise ValueError("data out of range: %i" % numpy.min(data))
if maxval is None:
maxval = numpy.max(data)
maxval = 255 if maxval < 256 else 65535
if maxval < 0 or maxval > 65535:
raise ValueError("data out of range: %i" % maxval)
data = data.astype('u1' if maxval < 256 else '>u2')
self._data = data
if data.ndim > 2 and data.shape[-1] in (3, 4):
self.depth = data.shape[-1]
self.width = data.shape[-2]
self.height = data.shape[-3]
self.magicnum = b'P7' if self.depth == 4 else b'P6'
else:
self.depth = 1
self.width = data.shape[-1]
self.height = data.shape[-2]
self.magicnum = b'P5' if maxval > 1 else b'P4'
self.maxval = maxval
self.tupltypes = [self._types[self.magicnum]]
self.header = self._header()
def _tofile(self, fh, pam=False):
"""Write Netbm file."""
fh.seek(0)
fh.write(self._header(pam))
data = self.asarray(copy=False)
if self.maxval == 1:
data = numpy.packbits(data, axis=-1)
data.tofile(fh)
def _header(self, pam=False):
"""Return file header as byte string."""
if pam or self.magicnum == b'P7':
header = "\n".join((
"P7",
"HEIGHT %i" % self.height,
"WIDTH %i" % self.width,
"DEPTH %i" % self.depth,
"MAXVAL %i" % self.maxval,
"\n".join("TUPLTYPE %s" % unicode(i) for i in self.tupltypes),
"ENDHDR\n"))
elif self.maxval == 1:
header = "P4 %i %i\n" % (self.width, self.height)
elif self.depth == 1:
header = "P5 %i %i %i\n" % (self.width, self.height, self.maxval)
else:
header = "P6 %i %i %i\n" % (self.width, self.height, self.maxval)
if sys.version_info[0] > 2:
header = bytes(header, 'ascii')
return header
def __str__(self):
"""Return information about instance."""
return unicode(self.header)
if sys.version_info[0] > 2:
basestring = str
unicode = lambda x: str(x, 'ascii')
if __name__ == "__main__":
# Show images specified on command line or all images in current directory
from glob import glob
from matplotlib import pyplot
files = sys.argv[1:] if len(sys.argv) > 1 else glob('*.p*m')
for fname in files:
try:
pam = NetpbmFile(fname)
img = pam.asarray(copy=False)
if False:
pam.write('_tmp.pgm.out', pam=True)
img2 = imread('_tmp.pgm.out')
assert numpy.all(img == img2)
imsave('_tmp.pgm.out', img)
img2 = imread('_tmp.pgm.out')
assert numpy.all(img == img2)
pam.close()
except ValueError as e:
print(fname, e)
continue
_shape = img.shape
if img.ndim > 3 or (img.ndim > 2 and img.shape[-1] not in (3, 4)):
img = img[0]
cmap = 'gray' if pam.maxval > 1 else 'binary'
pyplot.imshow(img, cmap, interpolation='nearest')
pyplot.title("%s %s %s %s" % (fname, unicode(pam.magicnum),
_shape, img.dtype))
pyplot.show()

View file

@ -18,7 +18,7 @@
# this list of conditions and the following disclaimer in the documentation # this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution. # and/or other materials provided with the distribution.
# #
# * Neither the name of paramax nor the names of its # * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from # contributors may be used to endorse or promote products derived from
# this software without specific prior written permission. # this software without specific prior written permission.
# #

View file

@ -13,7 +13,7 @@
# this list of conditions and the following disclaimer in the documentation # this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution. # and/or other materials provided with the distribution.
# #
# * Neither the name of paramax nor the names of its # * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from # contributors may be used to endorse or promote products derived from
# this software without specific prior written permission. # this software without specific prior written permission.
# #
@ -32,8 +32,8 @@
#!/usr/bin/env python #!/usr/bin/env python
import matplotlib import matplotlib
matplotlib.use('svg') matplotlib.use('agg')
import nose import nose
nose.main('GPy', defaultTest='GPy/testing') nose.main('GPy', defaultTest='GPy/testing/plotting_tests.py')