[plotting] cleanup, first commit: cleans the plotting library and adds plotting tests

This commit is contained in:
mzwiessele 2015-10-02 18:26:17 +01:00
parent fee2f3f727
commit b9bfd0fc6d
10 changed files with 200 additions and 690 deletions

View file

@ -477,260 +477,6 @@ class GP(Model):
Ysim = self.likelihood.samples(fsim, Y_metadata=Y_metadata)
return Ysim
def plot_f(self, plot_limits=None, which_data_rows='all',
           which_data_ycols='all', fixed_inputs=None,
           levels=20, samples=0, fignum=None, ax=None, resolution=None,
           plot_raw=True,
           linecol=None, fillcol=None, Y_metadata=None, data_symbol='kx',
           apply_link=False):
    """
    Plot the GP's view of the world, where the data is normalized and
    before applying a likelihood.

    This is a call to plot with plot_raw=True. Data will not be plotted
    in this, as the GP's view of the world may live in another space, or
    other units than the data. Can plot only part of the data and part of
    the posterior functions using which_data_rows and which_data_ycols.

    :param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaults to data limits
    :type plot_limits: np.array
    :param which_data_rows: which of the training data to plot (default all)
    :type which_data_rows: 'all' or a slice object to slice model.X, model.Y
    :param which_data_ycols: when the data has several columns (independent outputs), only plot these
    :type which_data_ycols: 'all' or a list of integers
    :param fixed_inputs: a list of tuples [(i,v), (i,v)...], specifying that input index i should be set to value v. Defaults to no fixed inputs
    :type fixed_inputs: a list of tuples
    :param levels: for 2D plotting, the number of contour levels to use
    :type levels: int
    :param samples: the number of a posteriori samples to plot
    :type samples: int
    :param fignum: figure to plot on (used when ax is None)
    :type fignum: figure number
    :param ax: axes to plot on; a new figure is created if None
    :type ax: axes handle
    :param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
    :type resolution: int
    :param plot_raw: plot the latent (raw) GP rather than the likelihood's view of it
    :type plot_raw: boolean
    :param linecol: color of line to plot; if None the plotting routine's default is used
    :type linecol: color as is standard in matplotlib ('r' is red, 'g' is green)
    :param fillcol: color of fill; if None the plotting routine's default is used
    :type fillcol: color as is standard in matplotlib ('r' is red, 'g' is green)
    :param Y_metadata: additional data associated with Y which may be needed
    :type Y_metadata: dict
    :param data_symbol: symbol as used in matplotlib, by default this is a black cross ('kx')
    :type data_symbol: color and marker spec as is standard in matplotlib
    :param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*
    :type apply_link: boolean
    """
    # BUG FIX: the previous default fixed_inputs=[] was a mutable default
    # argument shared between calls; None is the sentinel now.
    if fixed_inputs is None:
        fixed_inputs = []
    assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
    from ..plotting.matplot_dep import models_plots
    # Forward color kwargs only when the caller set them, so plot_fit's
    # own defaults apply otherwise.
    kw = {}
    if linecol is not None:
        kw['linecol'] = linecol
    if fillcol is not None:
        kw['fillcol'] = fillcol
    return models_plots.plot_fit(self, plot_limits, which_data_rows,
                                 which_data_ycols, fixed_inputs,
                                 levels, samples, fignum, ax, resolution,
                                 plot_raw=plot_raw, Y_metadata=Y_metadata,
                                 data_symbol=data_symbol, apply_link=apply_link, **kw)
def plot(self, plot_limits=None, which_data_rows='all',
         which_data_ycols='all', fixed_inputs=None,
         levels=20, samples=0, fignum=None, ax=None, resolution=None,
         plot_raw=False, linecol=None, fillcol=None, Y_metadata=None,
         data_symbol='kx', predict_kw=None, plot_training_data=True, samples_y=0, apply_link=False):
    """
    Plot the posterior of the GP.

    - In one dimension, the function is plotted with a shaded region identifying two standard deviations.
    - In two dimensions, a contour-plot shows the mean predicted function.
    - In higher dimensions, use fixed_inputs to plot the GP with some of the inputs fixed.

    Can plot only part of the data and part of the posterior functions
    using which_data_rows and which_data_ycols.

    :param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaults to data limits
    :type plot_limits: np.array
    :param which_data_rows: which of the training data to plot (default all)
    :type which_data_rows: 'all' or a slice object to slice model.X, model.Y
    :param which_data_ycols: when the data has several columns (independent outputs), only plot these
    :type which_data_ycols: 'all' or a list of integers
    :param fixed_inputs: a list of tuples [(i,v), (i,v)...], specifying that input index i should be set to value v. Defaults to no fixed inputs
    :type fixed_inputs: a list of tuples
    :param levels: for 2D plotting, the number of contour levels to use
    :type levels: int
    :param samples: the number of a posteriori samples to plot, p(f*|y)
    :type samples: int
    :param fignum: figure to plot on (used when ax is None)
    :type fignum: figure number
    :param ax: axes to plot on; a new figure is created if None
    :type ax: axes handle
    :param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
    :type resolution: int
    :param plot_raw: plot the latent (raw) GP rather than the likelihood's view of it
    :type plot_raw: boolean
    :param linecol: color of line to plot; if None the plotting routine's default is used
    :type linecol: color as is standard in matplotlib ('r' is red, 'g' is green)
    :param fillcol: color of fill; if None the plotting routine's default is used
    :type fillcol: color as is standard in matplotlib ('r' is red, 'g' is green)
    :param Y_metadata: additional data associated with Y which may be needed
    :type Y_metadata: dict
    :param data_symbol: symbol as used in matplotlib, by default this is a black cross ('kx')
    :type data_symbol: color and marker spec as is standard in matplotlib
    :param predict_kw: extra keyword arguments forwarded to the prediction call
    :type predict_kw: dict
    :param plot_training_data: whether or not to plot the training points
    :type plot_training_data: boolean
    :param samples_y: the number of a posteriori samples to plot, p(y*|y)
    :type samples_y: int
    :param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
    :type apply_link: boolean
    """
    # BUG FIX: the previous default fixed_inputs=[] was a mutable default
    # argument shared between calls; None is the sentinel now.
    if fixed_inputs is None:
        fixed_inputs = []
    assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
    from ..plotting.matplot_dep import models_plots
    # Forward color kwargs only when the caller set them, so plot_fit's
    # own defaults apply otherwise.
    kw = {}
    if linecol is not None:
        kw['linecol'] = linecol
    if fillcol is not None:
        kw['fillcol'] = fillcol
    return models_plots.plot_fit(self, plot_limits, which_data_rows,
                                 which_data_ycols, fixed_inputs,
                                 levels, samples, fignum, ax, resolution,
                                 plot_raw=plot_raw, Y_metadata=Y_metadata,
                                 data_symbol=data_symbol, predict_kw=predict_kw,
                                 plot_training_data=plot_training_data, samples_y=samples_y, apply_link=apply_link, **kw)
def plot_density(self, levels=20, plot_limits=None, fignum=None, ax=None,
                 fixed_inputs=None, plot_raw=False, edgecolor='none', facecolor='#3465a4',
                 predict_kw=None, Y_metadata=None,
                 apply_link=False, resolution=200, **patch_kw):
    """
    Plot the posterior density of the GP.

    - In one dimension, the function is plotted with a shaded gradient, visualizing the density of the posterior.
    - Only implemented for one dimension, for higher dimensions use `plot`.

    :param levels: number of levels to plot in the density plot. This is a number between 1 and 100. 1 corresponds to the normal plot_fit
    :type levels: int
    :param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaults to data limits
    :type plot_limits: np.array
    :param fignum: figure to plot on (used when ax is None)
    :type fignum: figure number
    :param ax: axes to plot on; a new figure is created if None
    :type ax: axes handle
    :param fixed_inputs: a list of tuples [(i,v), (i,v)...], specifying that input index i should be set to value v. Defaults to no fixed inputs
    :type fixed_inputs: a list of tuples
    :param plot_raw: plot the latent (raw) GP rather than the likelihood's view of it
    :type plot_raw: boolean
    :param edgecolor: color of the density patch edges
    :type edgecolor: color as is standard in matplotlib ('r' is red, 'g' is green)
    :param facecolor: color of the density patch fill
    :type facecolor: color as is standard in matplotlib ('r' is red, 'g' is green)
    :param predict_kw: extra keyword arguments forwarded to the prediction call
    :type predict_kw: dict
    :param Y_metadata: additional data associated with Y which may be needed
    :type Y_metadata: dict
    :param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
    :type apply_link: boolean
    :param resolution: resolution of interpolation, i.e. how many points of the posterior to interpolate (default 200)
    :type resolution: int
    :param patch_kw: the keyword arguments for the patchcollection fill
    """
    # BUG FIX: the previous default fixed_inputs=[] was a mutable default
    # argument shared between calls; None is the sentinel now.
    if fixed_inputs is None:
        fixed_inputs = []
    assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
    from ..plotting.matplot_dep import models_plots
    return models_plots.plot_density(self, levels, plot_limits, fignum, ax,
                                     fixed_inputs, plot_raw=plot_raw,
                                     Y_metadata=Y_metadata,
                                     predict_kw=predict_kw,
                                     apply_link=apply_link,
                                     edgecolor=edgecolor, facecolor=facecolor,
                                     **patch_kw)
def plot_data(self, which_data_rows='all',
              which_data_ycols='all', visible_dims=None,
              fignum=None, ax=None, data_symbol='kx'):
    """
    Plot the training data.

    Can plot only part of the data using which_data_rows and
    which_data_ycols; use visible_dims to select which (at most two)
    input dimensions are shown.

    :param which_data_rows: which of the training data to plot (default all)
    :type which_data_rows: 'all' or a slice object to slice model.X, model.Y
    :param which_data_ycols: when the data has several columns (independent outputs), only plot these
    :type which_data_ycols: 'all' or a list of integers
    :param visible_dims: an array specifying the input dimensions to plot (maximum two)
    :type visible_dims: a numpy array
    :param fignum: figure to plot on (used when ax is None)
    :type fignum: figure number
    :param ax: axes to plot on; a new figure is created if None
    :type ax: axes handle
    :param data_symbol: symbol as used in matplotlib, by default this is a black cross ('kx')
    :type data_symbol: color and marker spec as is standard in matplotlib
    """
    # NOTE: the old docstring documented plot_limits/resolution/levels/
    # samples/linecol/fillcol, none of which this method accepts; the
    # docstring now matches the signature. The dead `kw = {}` local that
    # was splatted empty into the call has been dropped.
    assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
    from ..plotting.matplot_dep import models_plots
    return models_plots.plot_data(self, which_data_rows,
                                  which_data_ycols, visible_dims,
                                  fignum, ax, data_symbol)
def plot_errorbars_trainset(self, which_data_rows='all',
                            which_data_ycols='all', fixed_inputs=None, fignum=None, ax=None,
                            linecol=None, data_symbol='kx', predict_kw=None, plot_training_data=True, lw=None):
    """
    Plot the posterior error bars corresponding to the training data.

    - For higher dimensions than two, use fixed_inputs to plot the data
      points with some of the inputs fixed.

    Can plot only part of the data using which_data_rows and
    which_data_ycols.

    :param which_data_rows: which of the training data to plot (default all)
    :type which_data_rows: 'all' or a slice object to slice model.X, model.Y
    :param which_data_ycols: when the data has several columns (independent outputs), only plot these
    :type which_data_ycols: 'all' or a list of integers
    :param fixed_inputs: a list of tuples [(i,v), (i,v)...], specifying that input index i should be set to value v. Defaults to no fixed inputs
    :type fixed_inputs: a list of tuples
    :param fignum: figure to plot on (used when ax is None)
    :type fignum: figure number
    :param ax: axes to plot on; a new figure is created if None
    :type ax: axes handle
    :param linecol: color of the error bars; if None the plotting routine's default is used
    :type linecol: color as is standard in matplotlib ('r' is red, 'g' is green)
    :param data_symbol: symbol as used in matplotlib, by default this is a black cross ('kx')
    :type data_symbol: color and marker spec as is standard in matplotlib
    :param predict_kw: extra keyword arguments forwarded to the prediction call
    :type predict_kw: dict
    :param plot_training_data: whether or not to plot the training points
    :type plot_training_data: boolean
    :param lw: line width of the error bars; if None the plotting routine's default is used
    :type lw: float
    """
    # BUG FIX: the previous default fixed_inputs=[] was a mutable default
    # argument shared between calls; None is the sentinel now.
    if fixed_inputs is None:
        fixed_inputs = []
    assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
    from ..plotting.matplot_dep import models_plots
    # Forward lw only when the caller set it, so the plotting routine's
    # default applies otherwise.
    kw = {}
    if lw is not None:
        kw['lw'] = lw
    return models_plots.plot_errorbars_trainset(self, which_data_rows, which_data_ycols, fixed_inputs,
                                                fignum, ax, linecol, data_symbol,
                                                predict_kw, plot_training_data, **kw)
def plot_magnification(self, labels=None, which_indices=None,
                       resolution=50, ax=None, marker='o', s=40,
                       fignum=None, legend=True,
                       plot_limits=None,
                       aspect='auto', updates=False, plot_inducing=True, kern=None, **kwargs):
    """
    Plot the magnification factor of this model's latent space by
    delegating to ``matplot_dep.dim_reduction_plots.plot_magnification``.

    All arguments are forwarded to the plotting routine; see that
    function for their exact semantics.
    """
    import sys
    assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
    from ..plotting.matplot_dep import dim_reduction_plots
    # NOTE(review): the positional order passed below is
    # (fignum, plot_inducing, legend), which swaps legend/plot_inducing
    # relative to this signature's keyword order — confirm against
    # dim_reduction_plots.plot_magnification's parameter list.
    # NOTE(review): `kern` is accepted here but never forwarded — confirm
    # whether it should be passed through.
    return dim_reduction_plots.plot_magnification(self, labels, which_indices,
                                                  resolution, ax, marker, s,
                                                  fignum, plot_inducing, legend,
                                                  plot_limits, aspect, updates, **kwargs)
def input_sensitivity(self, summarize=True):
"""
Returns the sensitivity for each dimension of this model

View file

@ -28,3 +28,6 @@ working = True
[cython]
working = True
[plotting]
library = matplotlib

View file

@ -13,7 +13,7 @@ from functools import wraps
def put_clean(dct, name, func):
if name in dct:
dct['_clean_{}'.format(name)] = dct[name]
#dct['_clean_{}'.format(name)] = dct[name]
dct[name] = func(dct[name])
class KernCallsViaSlicerMeta(ParametersChangedMeta):

View file

@ -2,11 +2,30 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt)
try:
import matplotlib
from . import matplot_dep
from ..util.config import config
lib = config.get('plotting', 'library')
if lib == 'matplotlib':
import matplotlib
from . import matplot_dep as plotting_library
except (ImportError, NameError):
# Matplotlib not available
import warnings
warnings.warn(ImportWarning("Matplotlib not available, install newest version of Matplotlib for plotting"))
#sys.modules['matplotlib'] =
#sys.modules[__name__+'.matplot_dep'] = ImportWarning("Matplotlib not available, install newest version of Matplotlib for plotting")
warnings.warn(ImportWarning("{} not available, install newest version of {} for plotting").format(lib, lib))
config.set('plotting', 'library', 'none')
if config.get('plotting', 'library') is not 'none':
# Inject the plots into classes here:
# Already converted to new style:
from . import gpy_plot
from ..core import GP
GP.plot_data = gpy_plot.data_plots.plot_data
# Still to convert to new style:
GP.plot = plotting_library.models_plots.plot_fit
GP.plot_f = plotting_library.models_plots.plot_fit_f
GP.plot_density = plotting_library.models_plots.plot_density
GP.plot_errorbars_trainset = plotting_library.models_plots.plot_errorbars_trainset
GP.plot_magnification = plotting_library.dim_reduction_plots.plot_magnification

View file

@ -1,6 +1,54 @@
# Copyright (c) 2014, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
from matplotlib import pyplot as plt
from . import defaults
def get_new_canvas(kwargs):
    """
    Return ``(axis, kwargs)`` ready for matplotlib plotting.

    Consumes the canvas-related entries of *kwargs* ('ax', 'num',
    'figsize') and hands back an axis to draw on plus the cleaned-up
    kwargs for further use in the actual plotting calls. In matplotlib
    plotting happens on the axis itself, so 'ax' is not a plotting kwarg
    and must be removed here; 'num'/'figsize' are honoured when a new
    figure has to be created.
    """
    if 'ax' in kwargs:
        # Caller supplied an axis: use it verbatim (even if it is None).
        axis = kwargs.pop('ax')
    else:
        # No axis given: open a fresh figure, forwarding whichever of
        # num/figsize the caller provided.
        figure_kwargs = {}
        for key in ('num', 'figsize'):
            if key in kwargs:
                figure_kwargs[key] = kwargs.pop(key)
        axis = plt.figure(**figure_kwargs).add_subplot(111)
    return axis, kwargs
def show_canvas(canvas):
    """
    Best-effort redraw of the figure behind *canvas*.

    Draws the figure's canvas and tightens its layout, ignoring any
    failure (e.g. objects without a ``figure`` attribute or backends
    that cannot draw), and returns *canvas* unchanged.
    """
    try:
        canvas.figure.canvas.draw()
        canvas.figure.tight_layout()
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; only ordinary errors are
        # ignored now.
        pass
    return canvas
def scatter(ax, *args, **kwargs):
    """Forward to ``ax.scatter`` and return the resulting artist.

    BUG FIX: the wrapper previously discarded the PathCollection that
    matplotlib returns, so callers could not restyle or remove it.
    """
    return ax.scatter(*args, **kwargs)
def plot(ax, *args, **kwargs):
    """Forward to ``ax.plot`` and return the resulting artist list.

    BUG FIX: the wrapper previously discarded the Line2D list that
    matplotlib returns, so callers could not restyle or remove it.
    """
    return ax.plot(*args, **kwargs)
def imshow(ax, *args, **kwargs):
    """Forward to ``ax.imshow`` and return the resulting image artist.

    BUG FIX: the wrapper previously discarded the AxesImage that
    matplotlib returns, so callers could not update or remove it.
    """
    return ax.imshow(*args, **kwargs)
from . import base_plots
from . import models_plots
from . import priors_plots
@ -11,8 +59,9 @@ from . import mapping_plots
from . import Tango
from . import visualize
from . import latent_space_visualizations
from . import netpbmfile
from . import inference_plots
from . import maps
from . import img_plots
from .ssgplvm import SSGPLVM_plot
from .ssgplvm import SSGPLVM_plot

View file

@ -1,11 +1,11 @@
# #Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
from matplotlib import pyplot as pb
from matplotlib import pyplot as plt
import numpy as np
def ax_default(fignum, ax):
if ax is None:
fig = pb.figure(fignum)
fig = plt.figure(fignum)
ax = fig.add_subplot(111)
else:
fig = ax.figure
@ -35,12 +35,14 @@ def gpplot(x, mu, lower, upper, edgecol='#3300FF', fillcol='#33CCFF', ax=None, f
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
#this is the edge:
plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,ax=axes))
plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,ax=axes))
plots.append(meanplot(x, upper,color=edgecol, linewidth=0.2, ax=axes))
plots.append(meanplot(x, lower,color=edgecol, linewidth=0.2, ax=axes))
return plots
def plot_gradient_fill(ax, x, percentiles, **kwargs):
def gradient_fill(x, percentiles, ax=None, fignum=None, **kwargs):
_, ax = ax_default(fignum, ax)
plots = []
#here's the box
@ -150,19 +152,19 @@ def gperrors(x, mu, lower, upper, edgecol=None, ax=None, fignum=None, **kwargs):
def removeRightTicks(ax=None):
ax = ax or pb.gca()
ax = ax or plt.gca()
for i, line in enumerate(ax.get_yticklines()):
if i%2 == 1: # odd indices
line.set_visible(False)
def removeUpperTicks(ax=None):
ax = ax or pb.gca()
ax = ax or plt.gca()
for i, line in enumerate(ax.get_xticklines()):
if i%2 == 1: # odd indices
line.set_visible(False)
def fewerXticks(ax=None,divideby=2):
ax = ax or pb.gca()
ax = ax or plt.gca()
ax.set_xticks(ax.get_xticks()[::divideby])
def align_subplots(N,M,xlim=None, ylim=None):
@ -171,33 +173,33 @@ def align_subplots(N,M,xlim=None, ylim=None):
if xlim is None:
xlim = [np.inf,-np.inf]
for i in range(N*M):
pb.subplot(N,M,i+1)
xlim[0] = min(xlim[0],pb.xlim()[0])
xlim[1] = max(xlim[1],pb.xlim()[1])
plt.subplot(N,M,i+1)
xlim[0] = min(xlim[0],plt.xlim()[0])
xlim[1] = max(xlim[1],plt.xlim()[1])
if ylim is None:
ylim = [np.inf,-np.inf]
for i in range(N*M):
pb.subplot(N,M,i+1)
ylim[0] = min(ylim[0],pb.ylim()[0])
ylim[1] = max(ylim[1],pb.ylim()[1])
plt.subplot(N,M,i+1)
ylim[0] = min(ylim[0],plt.ylim()[0])
ylim[1] = max(ylim[1],plt.ylim()[1])
for i in range(N*M):
pb.subplot(N,M,i+1)
pb.xlim(xlim)
pb.ylim(ylim)
plt.subplot(N,M,i+1)
plt.xlim(xlim)
plt.ylim(ylim)
if (i)%M:
pb.yticks([])
plt.yticks([])
else:
removeRightTicks()
if i<(M*(N-1)):
pb.xticks([])
plt.xticks([])
else:
removeUpperTicks()
def align_subplot_array(axes,xlim=None, ylim=None):
"""
Make all of the axes in the array hae the same limits, turn off unnecessary ticks
use pb.subplots() to get an array of axes
use plt.subplots() to get an array of axes
"""
#find sensible xlim,ylim
if xlim is None:

View file

@ -9,12 +9,13 @@ from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalized
from scipy import sparse
from ...core.parameterization.variational import VariationalPosterior
from matplotlib import pyplot as plt
from .base_plots import plot_gradient_fill
from .base_plots import gradient_fill
from functools import wraps
def plot_data(model, which_data_rows='all',
def plot_data(self, which_data_rows='all',
which_data_ycols='all', visible_dims=None,
fignum=None, ax=None, data_symbol='kx',mew=1.5):
fignum=None, ax=None, data_symbol='kx',mew=1.5,**kwargs):
"""
Plot the training data
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
@ -23,7 +24,7 @@ def plot_data(model, which_data_rows='all',
using which_data_rows and which_data_ycols.
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:type which_data_rows: 'all' or a slice object to slice self.X, self.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
@ -37,23 +38,23 @@ def plot_data(model, which_data_rows='all',
if which_data_rows == 'all':
which_data_rows = slice(None)
if which_data_ycols == 'all':
which_data_ycols = np.arange(model.output_dim)
which_data_ycols = np.arange(self.output_dim)
if ax is None:
fig = plt.figure(num=fignum)
ax = fig.add_subplot(111)
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean
X_variance = model.X.variance
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
X = self.X.mean
X_variance = self.X.variance
else:
X = model.X
X = self.X
X_variance = None
Y = model.Y
Y = self.Y
#work out what the inputs are for plotting (1D or 2D)
if visible_dims is None:
visible_dims = np.arange(model.input_dim)
visible_dims = np.arange(self.input_dim)
assert visible_dims.size <= 2, "Visible inputs cannot be larger than two"
free_dims = visible_dims
plots = {}
@ -80,7 +81,7 @@ def plot_data(model, which_data_rows='all',
return plots
def plot_fit(model, plot_limits=None, which_data_rows='all',
def plot_fit(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=False,
@ -98,7 +99,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param which_data_rows: which of the training data to plot (default all)
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
:type which_data_rows: 'all' or a slice object to slice self.X, self.Y
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
:type which_data_rows: 'all' or a list of integers
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
@ -134,98 +135,98 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
if which_data_rows == 'all':
which_data_rows = slice(None)
if which_data_ycols == 'all':
which_data_ycols = np.arange(model.output_dim)
which_data_ycols = np.arange(self.output_dim)
#if len(which_data_ycols)==0:
#raise ValueError('No data selected for plotting')
if ax is None:
fig = plt.figure(num=fignum)
ax = fig.add_subplot(111)
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean
X_variance = model.X.variance
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
X = self.X.mean
X_variance = self.X.variance
else:
X = model.X
Y = model.Y
X = self.X
Y = self.Y
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
if hasattr(model, 'Z'): Z = model.Z
if hasattr(self, 'Z'): Z = self.Z
if predict_kw is None:
predict_kw = {}
#work out what the inputs are for plotting (1D or 2D)
fixed_dims = np.array([i for i,v in fixed_inputs])
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
free_dims = np.setdiff1d(np.arange(self.input_dim),fixed_dims)
plots = {}
#one dimensional plotting
if len(free_dims) == 1:
#define the frame on which to plot
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution or 200)
Xgrid = np.empty((Xnew.shape[0],model.input_dim))
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs:
Xgrid[:,i] = v
#make a prediction on the frame and plot it
if plot_raw:
m, v = model._raw_predict(Xgrid, **predict_kw)
m, v = self._raw_predict(Xgrid, **predict_kw)
if apply_link:
lower = model.likelihood.gp_link.transf(m - 2*np.sqrt(v))
upper = model.likelihood.gp_link.transf(m + 2*np.sqrt(v))
lower = self.likelihood.gp_link.transf(m - 2*np.sqrt(v))
upper = self.likelihood.gp_link.transf(m + 2*np.sqrt(v))
#Once transformed this is now the median of the function
m = model.likelihood.gp_link.transf(m)
m = self.likelihood.gp_link.transf(m)
else:
lower = m - 2*np.sqrt(v)
upper = m + 2*np.sqrt(v)
else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
extra_data = Xgrid[:,-1:].astype(np.int)
if Y_metadata is None:
Y_metadata = {'output_index': extra_data}
else:
Y_metadata['output_index'] = extra_data
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
fmu, fv = model._raw_predict(Xgrid, full_cov=False, **predict_kw)
lower, upper = model.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=Y_metadata)
m, v = self.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
fmu, fv = self._raw_predict(Xgrid, full_cov=False, **predict_kw)
lower, upper = self.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=Y_metadata)
for d in which_data_ycols:
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
#if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
if not plot_raw and plot_training_data:
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows,
plots['dataplot'] = plot_data(self=self, which_data_rows=which_data_rows,
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
#optionally plot some samples
if samples: #NOTE not tested with fixed_inputs
Fsim = model.posterior_samples_f(Xgrid, samples)
Fsim = self.posterior_samples_f(Xgrid, samples)
if apply_link:
Fsim = model.likelihood.gp_link.transf(Fsim)
Fsim = self.likelihood.gp_link.transf(Fsim)
for fi in Fsim.T:
plots['posterior_samples'] = ax.plot(Xnew, fi[:,None], '#3300FF', linewidth=0.25)
#ax.plot(Xnew, fi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
if samples_y: #NOTE not tested with fixed_inputs
Ysim = model.posterior_samples(Xgrid, samples_y, Y_metadata=Y_metadata)
Ysim = self.posterior_samples(Xgrid, samples_y, Y_metadata=Y_metadata)
for yi in Ysim.T:
plots['posterior_samples_y'] = ax.scatter(Xnew, yi[:,None], s=5, c=Tango.colorsHex['darkBlue'], marker='o', alpha=0.5)
#ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
#add error bars for uncertain (if input uncertainty is being modelled)
if hasattr(model,"has_uncertain_inputs") and model.has_uncertain_inputs() and plot_uncertain_inputs:
if hasattr(self,"has_uncertain_inputs") and self.has_uncertain_inputs() and plot_uncertain_inputs:
if plot_raw:
#add error bars for uncertain (if input uncertainty is being modelled), for plot_f
#Hack to plot error bars on latent function, rather than on the data
vs = model.X.mean.values.copy()
vs = self.X.mean.values.copy()
for i,v in fixed_inputs:
vs[:,i] = v
m_X, _ = model._raw_predict(vs)
m_X, _ = self._raw_predict(vs)
if apply_link:
m_X = model.likelihood.gp_link.transf(m_X)
m_X = self.likelihood.gp_link.transf(m_X)
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(),
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
@ -243,9 +244,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
pass
#add inducing inputs (if a sparse model is used)
if hasattr(model,"Z"):
#Zu = model.Z[:,free_dims] * model._Xscale[:,free_dims] + model._Xoffset[:,free_dims]
if isinstance(model,SparseGPCoregionalizedRegression):
if hasattr(self,"Z"):
#Zu = self.Z[:,free_dims] * self._Xscale[:,free_dims] + self._Xoffset[:,free_dims]
if isinstance(self,SparseGPCoregionalizedRegression):
Z = Z[Z[:,-1] == Y_metadata['output_index'],:]
Zu = Z[:,free_dims]
z_height = ax.get_ylim()[0]
@ -259,7 +260,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#define the frame for plotting on
resolution = resolution or 50
Xnew, _, _, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)
Xgrid = np.empty((Xnew.shape[0],model.input_dim))
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs:
Xgrid[:,i] = v
@ -267,15 +268,15 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
#predict on the frame and plot
if plot_raw:
m, _ = model._raw_predict(Xgrid, **predict_kw)
m, _ = self._raw_predict(Xgrid, **predict_kw)
else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
extra_data = Xgrid[:,-1:].astype(np.int)
if Y_metadata is None:
Y_metadata = {'output_index': extra_data}
else:
Y_metadata['output_index'] = extra_data
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
m, v = self.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
for d in which_data_ycols:
m_d = m[:,d].reshape(resolution, resolution).T
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet)
@ -290,9 +291,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
if samples:
warnings.warn("Samples are rather difficult to plot for 2D inputs...")
#add inducing inputs (if a sparse model is used)
if hasattr(model,"Z"):
#Zu = model.Z[:,free_dims] * model._Xscale[:,free_dims] + model._Xoffset[:,free_dims]
#add inducing inputs (if a sparse self is used)
if hasattr(self,"Z"):
#Zu = self.Z[:,free_dims] * self._Xscale[:,free_dims] + self._Xoffset[:,free_dims]
Zu = Z[:,free_dims]
plots['inducing_inputs'] = ax.plot(Zu[:,0], Zu[:,1], 'wo')
@ -300,20 +301,41 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
raise NotImplementedError("Cannot define a frame with more than two input dimensions")
return plots
def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None,
def plot_density(self, levels=20, plot_limits=None,
fixed_inputs=[], plot_raw=False, edgecolor='none', facecolor='#3465a4',
predict_kw=None,Y_metadata=None,
apply_link=False, resolution=200, **patch_kwargs):
#deal with optional arguments
if ax is None:
fig = plt.figure(num=fignum)
ax = fig.add_subplot(111)
"""
Plot the posterior density of the GP.
- In one dimension, the function is plotted with a shaded gradient, visualizing the density of the posterior.
- Only implemented for one dimension, for higher dimensions use `plot`.
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
X = model.X.mean
:param levels: number of levels to plot in the density plot. This is a number between 1 and 100. 1 corresponds to the normal plot_fit.
:type levels: int
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
:type plot_limits: np.array
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
:type fixed_inputs: a list of tuples
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
:type resolution: int
:param edgecolor: color of line to plot [Tango.colorsHex['darkBlue']]
:type edgecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param facecolor: color of fill [Tango.colorsHex['lightBlue']]
:type facecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
:param Y_metadata: additional data associated with Y which may be needed
:type Y_metadata: dict
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
:type apply_link: boolean
:param resolution: resolution of interpolation (how many points to interpolate of the posterior).
:type resolution: int
:param: patch_kw: the keyword arguments for the patchcollection fill.
"""
#deal with optional arguments
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
X = self.X.mean
else:
X = model.X
Y = model.Y
X = self.X
Y = self.Y
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
if predict_kw is None:
@ -321,13 +343,13 @@ def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None,
#work out what the inputs are for plotting (1D or 2D)
fixed_dims = np.array([i for i,v in fixed_inputs])
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
free_dims = np.setdiff1d(np.arange(self.input_dim),fixed_dims)
plots = {}
#one dimensional plotting
if len(free_dims) == 1:
#define the frame on which to plot
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution)
Xgrid = np.empty((Xnew.shape[0],model.input_dim))
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
Xgrid[:,free_dims] = Xnew
for i,v in fixed_inputs:
Xgrid[:,i] = v
@ -340,32 +362,32 @@ def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None,
from ...likelihoods import Gaussian
lik = Gaussian(variance=0)
else:
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
extra_data = Xgrid[:,-1:].astype(np.int)
if Y_metadata is None:
Y_metadata = {'output_index': extra_data}
else:
Y_metadata['output_index'] = extra_data
lik = None
percentiles = [i[:, 0] for i in model.predict_quantiles(Xgrid, percs, Y_metadata=Y_metadata, likelihood=lik, **predict_kw)]
percentiles = [i[:, 0] for i in self.predict_quantiles(Xgrid, percs, Y_metadata=Y_metadata, likelihood=lik, **predict_kw)]
if apply_link:
percentiles = model.likelihood.gp_link.transf(percentiles)
percentiles = self.likelihood.gp_link.transf(percentiles)
patch_kwargs['facecolor'] = facecolor
patch_kwargs['edgecolor'] = edgecolor
plots['density'] = plot_gradient_fill(ax, Xgrid[:, 0], percentiles, **patch_kwargs)
plots['density'] = gradient_fill(Xgrid[:, 0], percentiles, **patch_kwargs)
else:
raise NotImplementedError('Only 1D density plottable.')
return plots
def plot_fit_f(model, *args, **kwargs):
"""
Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
All args and kwargs are passed on to models_plots.plot.
"""
kwargs['plot_raw'] = True
plot_fit(model,*args, **kwargs)
@wraps(plot_fit)
def plot_fit_f(self, plot_limits=None, which_data_rows='all',
which_data_ycols='all', fixed_inputs=[],
levels=20, samples=0, fignum=None, ax=None, resolution=None,
plot_raw=True,
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
apply_link=False, samples_y=0, plot_uncertain_inputs=True, predict_kw=None, plot_training_data=True):
return plot_fit(self, plot_limits, which_data_rows, which_data_ycols, fixed_inputs, levels, samples, fignum, ax, resolution, plot_raw, linecol, fillcol, Y_metadata, data_symbol, apply_link, samples_y, plot_uncertain_inputs, predict_kw, plot_training_data)
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False):
"""
@ -465,7 +487,7 @@ def plot_errorbars_trainset(model, which_data_rows='all',
for d in which_data_ycols:
plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum, **kwargs )
if plot_training_data:
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows,
plots['dataplot'] = plot_data(self=model, which_data_rows=which_data_rows,
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)

View file

@ -1,331 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# netpbmfile.py
# Copyright (c) 2011-2013, Christoph Gohlke
# Copyright (c) 2011-2013, The Regents of the University of California
# Produced at the Laboratory for Fluorescence Dynamics.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the copyright holders nor the names of any
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""Read and write image data from respectively to Netpbm files.
This implementation follows the Netpbm format specifications at
http://netpbm.sourceforge.net/doc/. No gamma correction is performed.
The following image formats are supported: PBM (bi-level), PGM (grayscale),
PPM (color), PAM (arbitrary), XV thumbnail (RGB332, read-only).
:Author:
`Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`_
:Organization:
Laboratory for Fluorescence Dynamics, University of California, Irvine
:Version: 2013.01.18
Requirements
------------
* `CPython 2.7, 3.2 or 3.3 <http://www.python.org>`_
* `Numpy 1.7 <http://www.numpy.org>`_
* `Matplotlib 1.2 <http://www.matplotlib.org>`_ (optional for plotting)
Examples
--------
>>> im1 = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
>>> imsave('_tmp.pgm', im1)
>>> im2 = imread('_tmp.pgm')
>>> assert numpy.all(im1 == im2)
"""
from __future__ import division, print_function
import sys
import re
import math
from copy import deepcopy
import numpy
__version__ = '2013.01.18'
__docformat__ = 'restructuredtext en'
__all__ = ['imread', 'imsave', 'NetpbmFile']
def imread(filename, *args, **kwargs):
    """Return image data from Netpbm file as numpy array.

    `args` and `kwargs` are arguments to NetpbmFile.asarray().

    Examples
    --------
    >>> image = imread('_tmp.pgm')

    """
    # Construct outside try/finally: if NetpbmFile() itself raises (bad
    # path, unrecognized format), `netpbm` would be unbound and the
    # original `finally: netpbm.close()` raised NameError, masking the
    # real error.
    netpbm = NetpbmFile(filename)
    try:
        # Forward *args/**kwargs to asarray() as the docstring promises;
        # the original dropped them.
        image = netpbm.asarray(*args, **kwargs)
    finally:
        netpbm.close()
    return image
def imsave(filename, data, maxval=None, pam=False):
    """Write image data to Netpbm file.

    Examples
    --------
    >>> image = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
    >>> imsave('_tmp.pgm', image)

    """
    # Construct outside try/finally: if NetpbmFile() raises (e.g. data of
    # non-integer dtype), `netpbm` would be unbound and the original
    # `finally: netpbm.close()` raised NameError, masking the real error.
    netpbm = NetpbmFile(data, maxval=maxval)
    try:
        netpbm.write(filename, pam=pam)
    finally:
        netpbm.close()
class NetpbmFile(object):
    """Read and write Netpbm PAM, PBM, PGM, PPM, files.

    An instance is initialized from a filename, an open binary file
    object, or a numpy array (see __init__). Header fields (magicnum,
    width, height, maxval, depth, tupltypes) are populated on read;
    pixel data is decoded lazily by asarray().
    """

    # Map of Netpbm magic numbers to PAM TUPLTYPE byte strings.
    # b'P7 332' is the XV thumbnail variant (RGB332, read-only).
    _types = {b'P1': b'BLACKANDWHITE', b'P2': b'GRAYSCALE', b'P3': b'RGB',
              b'P4': b'BLACKANDWHITE', b'P5': b'GRAYSCALE', b'P6': b'RGB',
              b'P7 332': b'RGB', b'P7': b'RGB_ALPHA'}

    def __init__(self, arg=None, **kwargs):
        """Initialize instance from filename, open file, or numpy array."""
        # Start with all attributes None so partially-initialized
        # instances are still safe to close()/repr.
        for attr in ('header', 'magicnum', 'width', 'height', 'maxval',
                     'depth', 'tupltypes', '_filename', '_fh', '_data'):
            setattr(self, attr, None)
        if arg is None:
            # No argument: empty image.
            self._fromdata([], **kwargs)
        elif isinstance(arg, basestring):
            # Filename: open and remember it so close() knows it owns
            # the file handle (see close()).
            self._fh = open(arg, 'rb')
            self._filename = arg
            self._fromfile(self._fh, **kwargs)
        elif hasattr(arg, 'seek'):
            # Already-open seekable file object; caller retains ownership.
            self._fromfile(arg, **kwargs)
            self._fh = arg
        else:
            # Anything else is treated as array-like pixel data.
            self._fromdata(arg, **kwargs)

    def asarray(self, copy=True, cache=False, **kwargs):
        """Return image data from file as numpy array.

        If `cache` is True the decoded array is kept on the instance
        so later calls do not re-read the file. `copy` controls whether
        a cached/preloaded array is deep-copied before returning.
        `kwargs` are passed to _read_data (e.g. byteorder).
        """
        data = self._data
        if data is None:
            data = self._read_data(self._fh, **kwargs)
            if cache:
                self._data = data
            else:
                # Freshly decoded and not cached: safe to hand out directly.
                return data
        return deepcopy(data) if copy else data

    def write(self, arg, **kwargs):
        """Write instance to file."""
        if hasattr(arg, 'seek'):
            # Open file object supplied by caller.
            self._tofile(arg, **kwargs)
        else:
            # Filename: open, write, close.
            with open(arg, 'wb') as fid:
                self._tofile(fid, **kwargs)

    def close(self):
        """Close open file. Future asarray calls might fail."""
        # Only close handles this instance opened itself (i.e. when
        # constructed from a filename); caller-supplied file objects
        # are left open.
        if self._filename and self._fh:
            self._fh.close()
            self._fh = None

    def __del__(self):
        self.close()

    def _fromfile(self, fh):
        """Initialize instance from open file."""
        fh.seek(0)
        # 4096 bytes is assumed to be enough to hold any header.
        data = fh.read(4096)
        # All Netpbm magic numbers are 'P1'..'P7'.
        if (len(data) < 7) or not (b'0' < data[1:2] < b'8'):
            raise ValueError("Not a Netpbm file:\n%s" % data[:32])
        try:
            # Try the PAM (P7) header format first, fall back to PNM.
            self._read_pam_header(data)
        except Exception:
            try:
                self._read_pnm_header(data)
            except Exception:
                raise ValueError("Not a Netpbm file:\n%s" % data[:32])

    def _read_pam_header(self, data):
        """Read PAM header and initialize instance."""
        # Group 0 is the whole header; remaining groups are the
        # "KEY value" fields.
        regroups = re.search(
            b"(^P7[\n\r]+(?:(?:[\n\r]+)|(?:#.*)|"
            b"(HEIGHT\s+\d+)|(WIDTH\s+\d+)|(DEPTH\s+\d+)|(MAXVAL\s+\d+)|"
            b"(?:TUPLTYPE\s+\w+))*ENDHDR\n)", data).groups()
        self.header = regroups[0]
        self.magicnum = b'P7'
        # NOTE(review): assumes all of HEIGHT/WIDTH/DEPTH/MAXVAL matched;
        # a missing field would leave its group None and group.split()
        # would raise AttributeError -- confirm against the PAM spec
        # (all four are mandatory).
        for group in regroups[1:]:
            key, value = group.split()
            setattr(self, unicode(key).lower(), int(value))
        matches = re.findall(b"(TUPLTYPE\s+\w+)", self.header)
        self.tupltypes = [s.split(None, 1)[1] for s in matches]

    def _read_pnm_header(self, data):
        """Read PNM header and initialize instance."""
        # P1/P4 (bitmaps) have no MAXVAL field in the header; append a
        # synthetic maxval of 1 to the match groups for them.
        bpm = data[1:2] in b"14"
        regroups = re.search(b"".join((
            b"(^(P[123456]|P7 332)\s+(?:#.*[\r\n])*",
            b"\s*(\d+)\s+(?:#.*[\r\n])*",
            b"\s*(\d+)\s+(?:#.*[\r\n])*" * (not bpm),
            b"\s*(\d+)\s(?:\s*#.*[\r\n]\s)*)")), data).groups() + (1, ) * bpm
        self.header = regroups[0]
        self.magicnum = regroups[1]
        self.width = int(regroups[2])
        self.height = int(regroups[3])
        self.maxval = int(regroups[4])
        self.depth = 3 if self.magicnum in b"P3P6P7 332" else 1
        self.tupltypes = [self._types[self.magicnum]]

    def _read_data(self, fh, byteorder='>'):
        """Return image data from open file as numpy array.

        `byteorder` selects the byte order of 16-bit samples
        (Netpbm stores them big-endian, hence the '>' default).
        """
        fh.seek(len(self.header))
        data = fh.read()
        # Samples are 1 byte up to maxval 255, otherwise 2 bytes.
        dtype = 'u1' if self.maxval < 256 else byteorder + 'u2'
        depth = 1 if self.magicnum == b"P7 332" else self.depth
        shape = [-1, self.height, self.width, depth]
        size = numpy.prod(shape[1:])
        if self.magicnum in b"P1P2P3":
            # Plain (ASCII) formats: whitespace-separated decimal values.
            data = numpy.array(data.split(None, size)[:size], dtype)
            data = data.reshape(shape)
        elif self.maxval == 1:
            # Raw bitmap (P4): rows are packed 8 pixels per byte,
            # padded to a whole number of bytes.
            shape[2] = int(math.ceil(self.width / 8))
            data = numpy.frombuffer(data, dtype).reshape(shape)
            data = numpy.unpackbits(data, axis=-2)[:, :, :self.width, :]
        else:
            # Raw grayscale/RGB: truncate any trailing partial frame.
            data = numpy.frombuffer(data, dtype)
            data = data[:size * (data.size // size)].reshape(shape)
        # Squeeze singleton frame and channel axes.
        if data.shape[0] < 2:
            data = data.reshape(data.shape[1:])
        if data.shape[-1] < 2:
            data = data.reshape(data.shape[:-1])
        if self.magicnum == b"P7 332":
            # Expand RGB332 palette indices to 8-bit RGB triples.
            rgb332 = numpy.array(list(numpy.ndindex(8, 8, 4)), numpy.uint8)
            rgb332 *= [36, 36, 85]
            data = numpy.take(rgb332, data, axis=0)
        return data

    def _fromdata(self, data, maxval=None):
        """Initialize instance from numpy array."""
        data = numpy.array(data, ndmin=2, copy=True)
        if data.dtype.kind not in "uib":
            raise ValueError("not an integer type: %s" % data.dtype)
        if data.dtype.kind == 'i' and numpy.min(data) < 0:
            raise ValueError("data out of range: %i" % numpy.min(data))
        if maxval is None:
            # Round the data maximum up to the nearest format limit.
            maxval = numpy.max(data)
            maxval = 255 if maxval < 256 else 65535
        if maxval < 0 or maxval > 65535:
            raise ValueError("data out of range: %i" % maxval)
        # 16-bit samples are stored big-endian per the Netpbm spec.
        data = data.astype('u1' if maxval < 256 else '>u2')
        self._data = data
        if data.ndim > 2 and data.shape[-1] in (3, 4):
            # Color image: last axis is the channel axis.
            self.depth = data.shape[-1]
            self.width = data.shape[-2]
            self.height = data.shape[-3]
            # 4 channels (RGB+alpha) requires PAM (P7); 3 is raw PPM (P6).
            self.magicnum = b'P7' if self.depth == 4 else b'P6'
        else:
            # Grayscale (P5) or bitmap (P4).
            self.depth = 1
            self.width = data.shape[-1]
            self.height = data.shape[-2]
            self.magicnum = b'P5' if maxval > 1 else b'P4'
        self.maxval = maxval
        self.tupltypes = [self._types[self.magicnum]]
        self.header = self._header()

    def _tofile(self, fh, pam=False):
        """Write Netbm file."""
        fh.seek(0)
        fh.write(self._header(pam))
        data = self.asarray(copy=False)
        if self.maxval == 1:
            # Bitmaps are written packed, 8 pixels per byte.
            data = numpy.packbits(data, axis=-1)
        data.tofile(fh)

    def _header(self, pam=False):
        """Return file header as byte string.

        If `pam` is true, emit a PAM (P7) header regardless of format.
        """
        if pam or self.magicnum == b'P7':
            header = "\n".join((
                "P7",
                "HEIGHT %i" % self.height,
                "WIDTH %i" % self.width,
                "DEPTH %i" % self.depth,
                "MAXVAL %i" % self.maxval,
                "\n".join("TUPLTYPE %s" % unicode(i) for i in self.tupltypes),
                "ENDHDR\n"))
        elif self.maxval == 1:
            # PBM headers carry no maxval field.
            header = "P4 %i %i\n" % (self.width, self.height)
        elif self.depth == 1:
            header = "P5 %i %i %i\n" % (self.width, self.height, self.maxval)
        else:
            header = "P6 %i %i %i\n" % (self.width, self.height, self.maxval)
        if sys.version_info[0] > 2:
            # Headers must be bytes for the binary file handle.
            header = bytes(header, 'ascii')
        return header

    def __str__(self):
        """Return information about instance."""
        return unicode(self.header)
# Python 3 compatibility shims: the module is written against the
# Python 2 names `basestring` and `unicode`. These run at import time,
# before any of the functions/methods above are called, so the late
# placement is safe.
if sys.version_info[0] > 2:
    basestring = str
    unicode = lambda x: str(x, 'ascii')  # decode ASCII bytes to str
if __name__ == "__main__":
    # Show images specified on command line or all images in current directory
    from glob import glob
    from matplotlib import pyplot
    files = sys.argv[1:] if len(sys.argv) > 1 else glob('*.p*m')
    for fname in files:
        try:
            pam = NetpbmFile(fname)
            img = pam.asarray(copy=False)
            if False:
                # Disabled round-trip self-test: write the image back out
                # and assert it re-reads identically.
                pam.write('_tmp.pgm.out', pam=True)
                img2 = imread('_tmp.pgm.out')
                assert numpy.all(img == img2)
                imsave('_tmp.pgm.out', img)
                img2 = imread('_tmp.pgm.out')
                assert numpy.all(img == img2)
            pam.close()
        except ValueError as e:
            # Skip files that are not valid Netpbm images.
            print(fname, e)
            continue
        _shape = img.shape
        # Multi-frame or non-RGB(A) stacks: display only the first frame.
        if img.ndim > 3 or (img.ndim > 2 and img.shape[-1] not in (3, 4)):
            img = img[0]
        # Bitmaps (maxval 1) render better with the binary colormap.
        cmap = 'gray' if pam.maxval > 1 else 'binary'
        pyplot.imshow(img, cmap, interpolation='nearest')
        pyplot.title("%s %s %s %s" % (fname, unicode(pam.magicnum),
                                      _shape, img.dtype))
        pyplot.show()

View file

@ -18,7 +18,7 @@
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of paramax nor the names of its
# * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#

View file

@ -13,7 +13,7 @@
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of paramax nor the names of its
# * Neither the name of GPy nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
@ -32,8 +32,8 @@
#!/usr/bin/env python
import matplotlib
matplotlib.use('svg')
matplotlib.use('agg')
import nose
nose.main('GPy', defaultTest='GPy/testing')
nose.main('GPy', defaultTest='GPy/testing/plotting_tests.py')