mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-04 01:02:39 +02:00
[plotting] cleanup first commit, this cleans the plotting library and adds plotting tests
This commit is contained in:
parent
fee2f3f727
commit
b9bfd0fc6d
10 changed files with 200 additions and 690 deletions
254
GPy/core/gp.py
254
GPy/core/gp.py
|
|
@ -477,260 +477,6 @@ class GP(Model):
|
||||||
Ysim = self.likelihood.samples(fsim, Y_metadata=Y_metadata)
|
Ysim = self.likelihood.samples(fsim, Y_metadata=Y_metadata)
|
||||||
return Ysim
|
return Ysim
|
||||||
|
|
||||||
def plot_f(self, plot_limits=None, which_data_rows='all',
|
|
||||||
which_data_ycols='all', fixed_inputs=[],
|
|
||||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
|
||||||
plot_raw=True,
|
|
||||||
linecol=None,fillcol=None, Y_metadata=None, data_symbol='kx',
|
|
||||||
apply_link=False):
|
|
||||||
"""
|
|
||||||
Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
|
|
||||||
This is a call to plot with plot_raw=True.
|
|
||||||
Data will not be plotted in this, as the GP's view of the world
|
|
||||||
may live in another space, or units then the data.
|
|
||||||
|
|
||||||
Can plot only part of the data and part of the posterior functions
|
|
||||||
using which_data_rowsm which_data_ycols.
|
|
||||||
|
|
||||||
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
|
||||||
:type plot_limits: np.array
|
|
||||||
:param which_data_rows: which of the training data to plot (default all)
|
|
||||||
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
|
||||||
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
|
||||||
:type which_data_ycols: 'all' or a list of integers
|
|
||||||
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
|
||||||
:type fixed_inputs: a list of tuples
|
|
||||||
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
|
|
||||||
:type resolution: int
|
|
||||||
:param levels: number of levels to plot in a contour plot.
|
|
||||||
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
|
|
||||||
:type levels: int
|
|
||||||
:param samples: the number of a posteriori samples to plot
|
|
||||||
:type samples: int
|
|
||||||
:param fignum: figure to plot on.
|
|
||||||
:type fignum: figure number
|
|
||||||
:param ax: axes to plot on.
|
|
||||||
:type ax: axes handle
|
|
||||||
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
|
|
||||||
:type linecol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
|
||||||
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
|
|
||||||
:type fillcol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
|
||||||
:param Y_metadata: additional data associated with Y which may be needed
|
|
||||||
:type Y_metadata: dict
|
|
||||||
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
|
|
||||||
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
|
|
||||||
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*
|
|
||||||
:type apply_link: boolean
|
|
||||||
"""
|
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
|
||||||
from ..plotting.matplot_dep import models_plots
|
|
||||||
kw = {}
|
|
||||||
if linecol is not None:
|
|
||||||
kw['linecol'] = linecol
|
|
||||||
if fillcol is not None:
|
|
||||||
kw['fillcol'] = fillcol
|
|
||||||
return models_plots.plot_fit(self, plot_limits, which_data_rows,
|
|
||||||
which_data_ycols, fixed_inputs,
|
|
||||||
levels, samples, fignum, ax, resolution,
|
|
||||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
|
||||||
data_symbol=data_symbol, apply_link=apply_link, **kw)
|
|
||||||
|
|
||||||
def plot(self, plot_limits=None, which_data_rows='all',
|
|
||||||
which_data_ycols='all', fixed_inputs=[],
|
|
||||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
|
||||||
plot_raw=False, linecol=None,fillcol=None, Y_metadata=None,
|
|
||||||
data_symbol='kx', predict_kw=None, plot_training_data=True, samples_y=0, apply_link=False):
|
|
||||||
"""
|
|
||||||
Plot the posterior of the GP.
|
|
||||||
- In one dimension, the function is plotted with a shaded region identifying two standard deviations.
|
|
||||||
- In two dimsensions, a contour-plot shows the mean predicted function
|
|
||||||
- In higher dimensions, use fixed_inputs to plot the GP with some of the inputs fixed.
|
|
||||||
|
|
||||||
Can plot only part of the data and part of the posterior functions
|
|
||||||
using which_data_rowsm which_data_ycols.
|
|
||||||
|
|
||||||
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
|
||||||
:type plot_limits: np.array
|
|
||||||
:param which_data_rows: which of the training data to plot (default all)
|
|
||||||
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
|
||||||
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
|
||||||
:type which_data_ycols: 'all' or a list of integers
|
|
||||||
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
|
||||||
:type fixed_inputs: a list of tuples
|
|
||||||
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
|
|
||||||
:type resolution: int
|
|
||||||
:param levels: number of levels to plot in a contour plot.
|
|
||||||
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
|
|
||||||
:type levels: int
|
|
||||||
:param samples: the number of a posteriori samples to plot, p(f*|y)
|
|
||||||
:type samples: int
|
|
||||||
:param fignum: figure to plot on.
|
|
||||||
:type fignum: figure number
|
|
||||||
:param ax: axes to plot on.
|
|
||||||
:type ax: axes handle
|
|
||||||
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
|
|
||||||
:type linecol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
|
||||||
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
|
|
||||||
:type fillcol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
|
||||||
:param Y_metadata: additional data associated with Y which may be needed
|
|
||||||
:type Y_metadata: dict
|
|
||||||
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
|
|
||||||
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
|
|
||||||
:param plot_training_data: whether or not to plot the training points
|
|
||||||
:type plot_training_data: boolean
|
|
||||||
:param samples_y: the number of a posteriori samples to plot, p(y*|y)
|
|
||||||
:type samples_y: int
|
|
||||||
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
|
|
||||||
:type apply_link: boolean
|
|
||||||
"""
|
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
|
||||||
from ..plotting.matplot_dep import models_plots
|
|
||||||
kw = {}
|
|
||||||
if linecol is not None:
|
|
||||||
kw['linecol'] = linecol
|
|
||||||
if fillcol is not None:
|
|
||||||
kw['fillcol'] = fillcol
|
|
||||||
return models_plots.plot_fit(self, plot_limits, which_data_rows,
|
|
||||||
which_data_ycols, fixed_inputs,
|
|
||||||
levels, samples, fignum, ax, resolution,
|
|
||||||
plot_raw=plot_raw, Y_metadata=Y_metadata,
|
|
||||||
data_symbol=data_symbol, predict_kw=predict_kw,
|
|
||||||
plot_training_data=plot_training_data, samples_y=samples_y, apply_link=apply_link, **kw)
|
|
||||||
|
|
||||||
def plot_density(self, levels=20, plot_limits=None, fignum=None, ax=None,
|
|
||||||
fixed_inputs=[], plot_raw=False, edgecolor='none', facecolor='#3465a4',
|
|
||||||
predict_kw=None,Y_metadata=None,
|
|
||||||
apply_link=False, resolution=200, **patch_kw):
|
|
||||||
"""
|
|
||||||
Plot the posterior density of the GP.
|
|
||||||
- In one dimension, the function is plotted with a shaded gradient, visualizing the density of the posterior.
|
|
||||||
- Only implemented for one dimension, for higher dimensions use `plot`.
|
|
||||||
|
|
||||||
:param levels: number of levels to plot in the density plot. This is a number between 1 and 100. 1 corresponds to the normal plot_fit.
|
|
||||||
:type levels: int
|
|
||||||
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
|
||||||
:type plot_limits: np.array
|
|
||||||
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
|
||||||
:type fixed_inputs: a list of tuples
|
|
||||||
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
|
|
||||||
:type resolution: int
|
|
||||||
:param edgecolor: color of line to plot [Tango.colorsHex['darkBlue']]
|
|
||||||
:type edgecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
|
||||||
:param facecolor: color of fill [Tango.colorsHex['lightBlue']]
|
|
||||||
:type facecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
|
||||||
:param Y_metadata: additional data associated with Y which may be needed
|
|
||||||
:type Y_metadata: dict
|
|
||||||
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
|
|
||||||
:type apply_link: boolean
|
|
||||||
:param resolution: resolution of interpolation (how many points to interpolate of the posterior).
|
|
||||||
:type resolution: int
|
|
||||||
:param: patch_kw: the keyword arguments for the patchcollection fill.
|
|
||||||
"""
|
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
|
||||||
from ..plotting.matplot_dep import models_plots
|
|
||||||
return models_plots.plot_density(self, levels, plot_limits, fignum, ax,
|
|
||||||
fixed_inputs, plot_raw=plot_raw,
|
|
||||||
Y_metadata=Y_metadata,
|
|
||||||
predict_kw=predict_kw,
|
|
||||||
apply_link=apply_link,
|
|
||||||
edgecolor=edgecolor, facecolor=facecolor,
|
|
||||||
**patch_kw)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_data(self, which_data_rows='all',
|
|
||||||
which_data_ycols='all', visible_dims=None,
|
|
||||||
fignum=None, ax=None, data_symbol='kx'):
|
|
||||||
"""
|
|
||||||
Plot the training data
|
|
||||||
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
|
|
||||||
|
|
||||||
Can plot only part of the data
|
|
||||||
using which_data_rows and which_data_ycols.
|
|
||||||
|
|
||||||
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
|
||||||
:type plot_limits: np.array
|
|
||||||
:param which_data_rows: which of the training data to plot (default all)
|
|
||||||
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
|
||||||
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
|
||||||
:type which_data_ycols: 'all' or a list of integers
|
|
||||||
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
|
|
||||||
:type visible_dims: a numpy array
|
|
||||||
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
|
|
||||||
:type resolution: int
|
|
||||||
:param levels: number of levels to plot in a contour plot.
|
|
||||||
:param levels: for 2D plotting, the number of contour levels to use is ax is None, create a new figure
|
|
||||||
:type levels: int
|
|
||||||
:param samples: the number of a posteriori samples to plot, p(f*|y)
|
|
||||||
:type samples: int
|
|
||||||
:param fignum: figure to plot on.
|
|
||||||
:type fignum: figure number
|
|
||||||
:param ax: axes to plot on.
|
|
||||||
:type ax: axes handle
|
|
||||||
:param linecol: color of line to plot [Tango.colorsHex['darkBlue']]
|
|
||||||
:type linecol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
|
||||||
:param fillcol: color of fill [Tango.colorsHex['lightBlue']]
|
|
||||||
:type fillcol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
|
||||||
:param data_symbol: symbol as used matplotlib, by default this is a black cross ('kx')
|
|
||||||
:type data_symbol: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) alongside marker type, as is standard in matplotlib.
|
|
||||||
"""
|
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
|
||||||
from ..plotting.matplot_dep import models_plots
|
|
||||||
kw = {}
|
|
||||||
return models_plots.plot_data(self, which_data_rows,
|
|
||||||
which_data_ycols, visible_dims,
|
|
||||||
fignum, ax, data_symbol, **kw)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_errorbars_trainset(self, which_data_rows='all',
|
|
||||||
which_data_ycols='all', fixed_inputs=[], fignum=None, ax=None,
|
|
||||||
linecol=None, data_symbol='kx', predict_kw=None, plot_training_data=True,lw=None):
|
|
||||||
|
|
||||||
"""
|
|
||||||
Plot the posterior error bars corresponding to the training data
|
|
||||||
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
|
|
||||||
|
|
||||||
Can plot only part of the data
|
|
||||||
using which_data_rows and which_data_ycols.
|
|
||||||
|
|
||||||
:param which_data_rows: which of the training data to plot (default all)
|
|
||||||
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
|
||||||
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
|
||||||
:type which_data_rows: 'all' or a list of integers
|
|
||||||
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
|
||||||
:type fixed_inputs: a list of tuples
|
|
||||||
:param fignum: figure to plot on.
|
|
||||||
:type fignum: figure number
|
|
||||||
:param ax: axes to plot on.
|
|
||||||
:type ax: axes handle
|
|
||||||
:param plot_training_data: whether or not to plot the training points
|
|
||||||
:type plot_training_data: boolean
|
|
||||||
"""
|
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
|
||||||
from ..plotting.matplot_dep import models_plots
|
|
||||||
kw = {}
|
|
||||||
if lw is not None:
|
|
||||||
kw['lw'] = lw
|
|
||||||
return models_plots.plot_errorbars_trainset(self, which_data_rows, which_data_ycols, fixed_inputs,
|
|
||||||
fignum, ax, linecol, data_symbol,
|
|
||||||
predict_kw, plot_training_data, **kw)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_magnification(self, labels=None, which_indices=None,
|
|
||||||
resolution=50, ax=None, marker='o', s=40,
|
|
||||||
fignum=None, legend=True,
|
|
||||||
plot_limits=None,
|
|
||||||
aspect='auto', updates=False, plot_inducing=True, kern=None, **kwargs):
|
|
||||||
|
|
||||||
import sys
|
|
||||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
|
||||||
from ..plotting.matplot_dep import dim_reduction_plots
|
|
||||||
|
|
||||||
return dim_reduction_plots.plot_magnification(self, labels, which_indices,
|
|
||||||
resolution, ax, marker, s,
|
|
||||||
fignum, plot_inducing, legend,
|
|
||||||
plot_limits, aspect, updates, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def input_sensitivity(self, summarize=True):
|
def input_sensitivity(self, summarize=True):
|
||||||
"""
|
"""
|
||||||
Returns the sensitivity for each dimension of this model
|
Returns the sensitivity for each dimension of this model
|
||||||
|
|
|
||||||
|
|
@ -28,3 +28,6 @@ working = True
|
||||||
|
|
||||||
[cython]
|
[cython]
|
||||||
working = True
|
working = True
|
||||||
|
|
||||||
|
[plotting]
|
||||||
|
library = matplotlib
|
||||||
|
|
@ -13,7 +13,7 @@ from functools import wraps
|
||||||
|
|
||||||
def put_clean(dct, name, func):
|
def put_clean(dct, name, func):
|
||||||
if name in dct:
|
if name in dct:
|
||||||
dct['_clean_{}'.format(name)] = dct[name]
|
#dct['_clean_{}'.format(name)] = dct[name]
|
||||||
dct[name] = func(dct[name])
|
dct[name] = func(dct[name])
|
||||||
|
|
||||||
class KernCallsViaSlicerMeta(ParametersChangedMeta):
|
class KernCallsViaSlicerMeta(ParametersChangedMeta):
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,30 @@
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import matplotlib
|
from ..util.config import config
|
||||||
from . import matplot_dep
|
lib = config.get('plotting', 'library')
|
||||||
|
if lib == 'matplotlib':
|
||||||
|
import matplotlib
|
||||||
|
from . import matplot_dep as plotting_library
|
||||||
except (ImportError, NameError):
|
except (ImportError, NameError):
|
||||||
# Matplotlib not available
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.warn(ImportWarning("Matplotlib not available, install newest version of Matplotlib for plotting"))
|
warnings.warn(ImportWarning("{} not available, install newest version of {} for plotting").format(lib, lib))
|
||||||
#sys.modules['matplotlib'] =
|
config.set('plotting', 'library', 'none')
|
||||||
#sys.modules[__name__+'.matplot_dep'] = ImportWarning("Matplotlib not available, install newest version of Matplotlib for plotting")
|
|
||||||
|
if config.get('plotting', 'library') is not 'none':
|
||||||
|
# Inject the plots into classes here:
|
||||||
|
|
||||||
|
# Already converted to new style:
|
||||||
|
from . import gpy_plot
|
||||||
|
|
||||||
|
from ..core import GP
|
||||||
|
GP.plot_data = gpy_plot.data_plots.plot_data
|
||||||
|
|
||||||
|
# Still to convert to new style:
|
||||||
|
GP.plot = plotting_library.models_plots.plot_fit
|
||||||
|
GP.plot_f = plotting_library.models_plots.plot_fit_f
|
||||||
|
GP.plot_density = plotting_library.models_plots.plot_density
|
||||||
|
|
||||||
|
GP.plot_errorbars_trainset = plotting_library.models_plots.plot_errorbars_trainset
|
||||||
|
GP.plot_magnification = plotting_library.dim_reduction_plots.plot_magnification
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,54 @@
|
||||||
# Copyright (c) 2014, GPy authors (see AUTHORS.txt).
|
# Copyright (c) 2014, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from . import defaults
|
||||||
|
|
||||||
|
def get_new_canvas(kwargs):
|
||||||
|
"""
|
||||||
|
Return a canvas, kwargupdate for matplotlib. This just a
|
||||||
|
dictionary for the collection and we add the an axis to kwarg.
|
||||||
|
|
||||||
|
This method does two things, it creates an empty canvas
|
||||||
|
and updates the kwargs (deletes the unnecessary kwargs)
|
||||||
|
for further usage in normal plotting.
|
||||||
|
|
||||||
|
in matplotlib this means it deletes references to ax, as
|
||||||
|
plotting is done on the axis itself and is not a kwarg.
|
||||||
|
"""
|
||||||
|
if 'ax' in kwargs:
|
||||||
|
ax = kwargs.pop('ax')
|
||||||
|
elif 'num' in kwargs and 'figsize' in kwargs:
|
||||||
|
ax = plt.figure(num=kwargs.pop('num'), figsize=kwargs.pop('figsize')).add_subplot(111)
|
||||||
|
elif 'num' in kwargs:
|
||||||
|
ax = plt.figure(num=kwargs.pop('num')).add_subplot(111)
|
||||||
|
elif 'figsize' in kwargs:
|
||||||
|
ax = plt.figure(figsize=kwargs.pop('figsize')).add_subplot(111)
|
||||||
|
else:
|
||||||
|
ax = plt.figure().add_subplot(111)
|
||||||
|
# Add ax to kwargs to add all subsequent plots to this axis:
|
||||||
|
#kwargs['ax'] = ax
|
||||||
|
return ax, kwargs
|
||||||
|
|
||||||
|
def show_canvas(canvas):
|
||||||
|
try:
|
||||||
|
canvas.figure.canvas.draw()
|
||||||
|
canvas.figure.tight_layout()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return canvas
|
||||||
|
|
||||||
|
|
||||||
|
def scatter(ax, *args, **kwargs):
|
||||||
|
ax.scatter(*args, **kwargs)
|
||||||
|
|
||||||
|
def plot(ax, *args, **kwargs):
|
||||||
|
ax.plot(*args, **kwargs)
|
||||||
|
|
||||||
|
def imshow(ax, *args, **kwargs):
|
||||||
|
ax.imshow(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
from . import base_plots
|
from . import base_plots
|
||||||
from . import models_plots
|
from . import models_plots
|
||||||
from . import priors_plots
|
from . import priors_plots
|
||||||
|
|
@ -11,8 +59,9 @@ from . import mapping_plots
|
||||||
from . import Tango
|
from . import Tango
|
||||||
from . import visualize
|
from . import visualize
|
||||||
from . import latent_space_visualizations
|
from . import latent_space_visualizations
|
||||||
from . import netpbmfile
|
|
||||||
from . import inference_plots
|
from . import inference_plots
|
||||||
from . import maps
|
from . import maps
|
||||||
from . import img_plots
|
from . import img_plots
|
||||||
from .ssgplvm import SSGPLVM_plot
|
from .ssgplvm import SSGPLVM_plot
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
# #Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
# #Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
from matplotlib import pyplot as pb
|
from matplotlib import pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
def ax_default(fignum, ax):
|
def ax_default(fignum, ax):
|
||||||
if ax is None:
|
if ax is None:
|
||||||
fig = pb.figure(fignum)
|
fig = plt.figure(fignum)
|
||||||
ax = fig.add_subplot(111)
|
ax = fig.add_subplot(111)
|
||||||
else:
|
else:
|
||||||
fig = ax.figure
|
fig = ax.figure
|
||||||
|
|
@ -35,12 +35,14 @@ def gpplot(x, mu, lower, upper, edgecol='#3300FF', fillcol='#33CCFF', ax=None, f
|
||||||
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
|
plots.append(axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs))
|
||||||
|
|
||||||
#this is the edge:
|
#this is the edge:
|
||||||
plots.append(meanplot(x, upper,color=edgecol,linewidth=0.2,ax=axes))
|
plots.append(meanplot(x, upper,color=edgecol, linewidth=0.2, ax=axes))
|
||||||
plots.append(meanplot(x, lower,color=edgecol,linewidth=0.2,ax=axes))
|
plots.append(meanplot(x, lower,color=edgecol, linewidth=0.2, ax=axes))
|
||||||
|
|
||||||
return plots
|
return plots
|
||||||
|
|
||||||
def plot_gradient_fill(ax, x, percentiles, **kwargs):
|
def gradient_fill(x, percentiles, ax=None, fignum=None, **kwargs):
|
||||||
|
_, ax = ax_default(fignum, ax)
|
||||||
|
|
||||||
plots = []
|
plots = []
|
||||||
|
|
||||||
#here's the box
|
#here's the box
|
||||||
|
|
@ -150,19 +152,19 @@ def gperrors(x, mu, lower, upper, edgecol=None, ax=None, fignum=None, **kwargs):
|
||||||
|
|
||||||
|
|
||||||
def removeRightTicks(ax=None):
|
def removeRightTicks(ax=None):
|
||||||
ax = ax or pb.gca()
|
ax = ax or plt.gca()
|
||||||
for i, line in enumerate(ax.get_yticklines()):
|
for i, line in enumerate(ax.get_yticklines()):
|
||||||
if i%2 == 1: # odd indices
|
if i%2 == 1: # odd indices
|
||||||
line.set_visible(False)
|
line.set_visible(False)
|
||||||
|
|
||||||
def removeUpperTicks(ax=None):
|
def removeUpperTicks(ax=None):
|
||||||
ax = ax or pb.gca()
|
ax = ax or plt.gca()
|
||||||
for i, line in enumerate(ax.get_xticklines()):
|
for i, line in enumerate(ax.get_xticklines()):
|
||||||
if i%2 == 1: # odd indices
|
if i%2 == 1: # odd indices
|
||||||
line.set_visible(False)
|
line.set_visible(False)
|
||||||
|
|
||||||
def fewerXticks(ax=None,divideby=2):
|
def fewerXticks(ax=None,divideby=2):
|
||||||
ax = ax or pb.gca()
|
ax = ax or plt.gca()
|
||||||
ax.set_xticks(ax.get_xticks()[::divideby])
|
ax.set_xticks(ax.get_xticks()[::divideby])
|
||||||
|
|
||||||
def align_subplots(N,M,xlim=None, ylim=None):
|
def align_subplots(N,M,xlim=None, ylim=None):
|
||||||
|
|
@ -171,33 +173,33 @@ def align_subplots(N,M,xlim=None, ylim=None):
|
||||||
if xlim is None:
|
if xlim is None:
|
||||||
xlim = [np.inf,-np.inf]
|
xlim = [np.inf,-np.inf]
|
||||||
for i in range(N*M):
|
for i in range(N*M):
|
||||||
pb.subplot(N,M,i+1)
|
plt.subplot(N,M,i+1)
|
||||||
xlim[0] = min(xlim[0],pb.xlim()[0])
|
xlim[0] = min(xlim[0],plt.xlim()[0])
|
||||||
xlim[1] = max(xlim[1],pb.xlim()[1])
|
xlim[1] = max(xlim[1],plt.xlim()[1])
|
||||||
if ylim is None:
|
if ylim is None:
|
||||||
ylim = [np.inf,-np.inf]
|
ylim = [np.inf,-np.inf]
|
||||||
for i in range(N*M):
|
for i in range(N*M):
|
||||||
pb.subplot(N,M,i+1)
|
plt.subplot(N,M,i+1)
|
||||||
ylim[0] = min(ylim[0],pb.ylim()[0])
|
ylim[0] = min(ylim[0],plt.ylim()[0])
|
||||||
ylim[1] = max(ylim[1],pb.ylim()[1])
|
ylim[1] = max(ylim[1],plt.ylim()[1])
|
||||||
|
|
||||||
for i in range(N*M):
|
for i in range(N*M):
|
||||||
pb.subplot(N,M,i+1)
|
plt.subplot(N,M,i+1)
|
||||||
pb.xlim(xlim)
|
plt.xlim(xlim)
|
||||||
pb.ylim(ylim)
|
plt.ylim(ylim)
|
||||||
if (i)%M:
|
if (i)%M:
|
||||||
pb.yticks([])
|
plt.yticks([])
|
||||||
else:
|
else:
|
||||||
removeRightTicks()
|
removeRightTicks()
|
||||||
if i<(M*(N-1)):
|
if i<(M*(N-1)):
|
||||||
pb.xticks([])
|
plt.xticks([])
|
||||||
else:
|
else:
|
||||||
removeUpperTicks()
|
removeUpperTicks()
|
||||||
|
|
||||||
def align_subplot_array(axes,xlim=None, ylim=None):
|
def align_subplot_array(axes,xlim=None, ylim=None):
|
||||||
"""
|
"""
|
||||||
Make all of the axes in the array hae the same limits, turn off unnecessary ticks
|
Make all of the axes in the array hae the same limits, turn off unnecessary ticks
|
||||||
use pb.subplots() to get an array of axes
|
use plt.subplots() to get an array of axes
|
||||||
"""
|
"""
|
||||||
#find sensible xlim,ylim
|
#find sensible xlim,ylim
|
||||||
if xlim is None:
|
if xlim is None:
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,13 @@ from ...models.sparse_gp_coregionalized_regression import SparseGPCoregionalized
|
||||||
from scipy import sparse
|
from scipy import sparse
|
||||||
from ...core.parameterization.variational import VariationalPosterior
|
from ...core.parameterization.variational import VariationalPosterior
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from .base_plots import plot_gradient_fill
|
from .base_plots import gradient_fill
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
def plot_data(model, which_data_rows='all',
|
def plot_data(self, which_data_rows='all',
|
||||||
which_data_ycols='all', visible_dims=None,
|
which_data_ycols='all', visible_dims=None,
|
||||||
fignum=None, ax=None, data_symbol='kx',mew=1.5):
|
fignum=None, ax=None, data_symbol='kx',mew=1.5,**kwargs):
|
||||||
"""
|
"""
|
||||||
Plot the training data
|
Plot the training data
|
||||||
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
|
- For higher dimensions than two, use fixed_inputs to plot the data points with some of the inputs fixed.
|
||||||
|
|
@ -23,7 +24,7 @@ def plot_data(model, which_data_rows='all',
|
||||||
using which_data_rows and which_data_ycols.
|
using which_data_rows and which_data_ycols.
|
||||||
|
|
||||||
:param which_data_rows: which of the training data to plot (default all)
|
:param which_data_rows: which of the training data to plot (default all)
|
||||||
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
:type which_data_rows: 'all' or a slice object to slice self.X, self.Y
|
||||||
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
||||||
:type which_data_rows: 'all' or a list of integers
|
:type which_data_rows: 'all' or a list of integers
|
||||||
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
|
:param visible_dims: an array specifying the input dimensions to plot (maximum two)
|
||||||
|
|
@ -37,23 +38,23 @@ def plot_data(model, which_data_rows='all',
|
||||||
if which_data_rows == 'all':
|
if which_data_rows == 'all':
|
||||||
which_data_rows = slice(None)
|
which_data_rows = slice(None)
|
||||||
if which_data_ycols == 'all':
|
if which_data_ycols == 'all':
|
||||||
which_data_ycols = np.arange(model.output_dim)
|
which_data_ycols = np.arange(self.output_dim)
|
||||||
|
|
||||||
if ax is None:
|
if ax is None:
|
||||||
fig = plt.figure(num=fignum)
|
fig = plt.figure(num=fignum)
|
||||||
ax = fig.add_subplot(111)
|
ax = fig.add_subplot(111)
|
||||||
|
|
||||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
|
||||||
X = model.X.mean
|
X = self.X.mean
|
||||||
X_variance = model.X.variance
|
X_variance = self.X.variance
|
||||||
else:
|
else:
|
||||||
X = model.X
|
X = self.X
|
||||||
X_variance = None
|
X_variance = None
|
||||||
Y = model.Y
|
Y = self.Y
|
||||||
|
|
||||||
#work out what the inputs are for plotting (1D or 2D)
|
#work out what the inputs are for plotting (1D or 2D)
|
||||||
if visible_dims is None:
|
if visible_dims is None:
|
||||||
visible_dims = np.arange(model.input_dim)
|
visible_dims = np.arange(self.input_dim)
|
||||||
assert visible_dims.size <= 2, "Visible inputs cannot be larger than two"
|
assert visible_dims.size <= 2, "Visible inputs cannot be larger than two"
|
||||||
free_dims = visible_dims
|
free_dims = visible_dims
|
||||||
plots = {}
|
plots = {}
|
||||||
|
|
@ -80,7 +81,7 @@ def plot_data(model, which_data_rows='all',
|
||||||
return plots
|
return plots
|
||||||
|
|
||||||
|
|
||||||
def plot_fit(model, plot_limits=None, which_data_rows='all',
|
def plot_fit(self, plot_limits=None, which_data_rows='all',
|
||||||
which_data_ycols='all', fixed_inputs=[],
|
which_data_ycols='all', fixed_inputs=[],
|
||||||
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
||||||
plot_raw=False,
|
plot_raw=False,
|
||||||
|
|
@ -98,7 +99,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
||||||
:type plot_limits: np.array
|
:type plot_limits: np.array
|
||||||
:param which_data_rows: which of the training data to plot (default all)
|
:param which_data_rows: which of the training data to plot (default all)
|
||||||
:type which_data_rows: 'all' or a slice object to slice model.X, model.Y
|
:type which_data_rows: 'all' or a slice object to slice self.X, self.Y
|
||||||
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
:param which_data_ycols: when the data has several columns (independant outputs), only plot these
|
||||||
:type which_data_rows: 'all' or a list of integers
|
:type which_data_rows: 'all' or a list of integers
|
||||||
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
||||||
|
|
@ -134,98 +135,98 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
if which_data_rows == 'all':
|
if which_data_rows == 'all':
|
||||||
which_data_rows = slice(None)
|
which_data_rows = slice(None)
|
||||||
if which_data_ycols == 'all':
|
if which_data_ycols == 'all':
|
||||||
which_data_ycols = np.arange(model.output_dim)
|
which_data_ycols = np.arange(self.output_dim)
|
||||||
#if len(which_data_ycols)==0:
|
#if len(which_data_ycols)==0:
|
||||||
#raise ValueError('No data selected for plotting')
|
#raise ValueError('No data selected for plotting')
|
||||||
if ax is None:
|
if ax is None:
|
||||||
fig = plt.figure(num=fignum)
|
fig = plt.figure(num=fignum)
|
||||||
ax = fig.add_subplot(111)
|
ax = fig.add_subplot(111)
|
||||||
|
|
||||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
|
||||||
X = model.X.mean
|
X = self.X.mean
|
||||||
X_variance = model.X.variance
|
X_variance = self.X.variance
|
||||||
else:
|
else:
|
||||||
X = model.X
|
X = self.X
|
||||||
Y = model.Y
|
Y = self.Y
|
||||||
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
|
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
|
||||||
|
|
||||||
if hasattr(model, 'Z'): Z = model.Z
|
if hasattr(self, 'Z'): Z = self.Z
|
||||||
|
|
||||||
if predict_kw is None:
|
if predict_kw is None:
|
||||||
predict_kw = {}
|
predict_kw = {}
|
||||||
|
|
||||||
#work out what the inputs are for plotting (1D or 2D)
|
#work out what the inputs are for plotting (1D or 2D)
|
||||||
fixed_dims = np.array([i for i,v in fixed_inputs])
|
fixed_dims = np.array([i for i,v in fixed_inputs])
|
||||||
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
|
free_dims = np.setdiff1d(np.arange(self.input_dim),fixed_dims)
|
||||||
plots = {}
|
plots = {}
|
||||||
#one dimensional plotting
|
#one dimensional plotting
|
||||||
if len(free_dims) == 1:
|
if len(free_dims) == 1:
|
||||||
|
|
||||||
#define the frame on which to plot
|
#define the frame on which to plot
|
||||||
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution or 200)
|
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution or 200)
|
||||||
Xgrid = np.empty((Xnew.shape[0],model.input_dim))
|
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
|
||||||
Xgrid[:,free_dims] = Xnew
|
Xgrid[:,free_dims] = Xnew
|
||||||
for i,v in fixed_inputs:
|
for i,v in fixed_inputs:
|
||||||
Xgrid[:,i] = v
|
Xgrid[:,i] = v
|
||||||
|
|
||||||
#make a prediction on the frame and plot it
|
#make a prediction on the frame and plot it
|
||||||
if plot_raw:
|
if plot_raw:
|
||||||
m, v = model._raw_predict(Xgrid, **predict_kw)
|
m, v = self._raw_predict(Xgrid, **predict_kw)
|
||||||
if apply_link:
|
if apply_link:
|
||||||
lower = model.likelihood.gp_link.transf(m - 2*np.sqrt(v))
|
lower = self.likelihood.gp_link.transf(m - 2*np.sqrt(v))
|
||||||
upper = model.likelihood.gp_link.transf(m + 2*np.sqrt(v))
|
upper = self.likelihood.gp_link.transf(m + 2*np.sqrt(v))
|
||||||
#Once transformed this is now the median of the function
|
#Once transformed this is now the median of the function
|
||||||
m = model.likelihood.gp_link.transf(m)
|
m = self.likelihood.gp_link.transf(m)
|
||||||
else:
|
else:
|
||||||
lower = m - 2*np.sqrt(v)
|
lower = m - 2*np.sqrt(v)
|
||||||
upper = m + 2*np.sqrt(v)
|
upper = m + 2*np.sqrt(v)
|
||||||
else:
|
else:
|
||||||
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
|
if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
|
||||||
extra_data = Xgrid[:,-1:].astype(np.int)
|
extra_data = Xgrid[:,-1:].astype(np.int)
|
||||||
if Y_metadata is None:
|
if Y_metadata is None:
|
||||||
Y_metadata = {'output_index': extra_data}
|
Y_metadata = {'output_index': extra_data}
|
||||||
else:
|
else:
|
||||||
Y_metadata['output_index'] = extra_data
|
Y_metadata['output_index'] = extra_data
|
||||||
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
|
m, v = self.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
|
||||||
fmu, fv = model._raw_predict(Xgrid, full_cov=False, **predict_kw)
|
fmu, fv = self._raw_predict(Xgrid, full_cov=False, **predict_kw)
|
||||||
lower, upper = model.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=Y_metadata)
|
lower, upper = self.likelihood.predictive_quantiles(fmu, fv, (2.5, 97.5), Y_metadata=Y_metadata)
|
||||||
|
|
||||||
|
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
|
plots['gpplot'] = gpplot(Xnew, m[:, d], lower[:, d], upper[:, d], ax=ax, edgecol=linecol, fillcol=fillcol)
|
||||||
#if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
|
#if not plot_raw: plots['dataplot'] = ax.plot(X[which_data_rows,free_dims], Y[which_data_rows, d], data_symbol, mew=1.5)
|
||||||
if not plot_raw and plot_training_data:
|
if not plot_raw and plot_training_data:
|
||||||
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows,
|
plots['dataplot'] = plot_data(self=self, which_data_rows=which_data_rows,
|
||||||
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
|
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
|
||||||
|
|
||||||
|
|
||||||
#optionally plot some samples
|
#optionally plot some samples
|
||||||
if samples: #NOTE not tested with fixed_inputs
|
if samples: #NOTE not tested with fixed_inputs
|
||||||
Fsim = model.posterior_samples_f(Xgrid, samples)
|
Fsim = self.posterior_samples_f(Xgrid, samples)
|
||||||
if apply_link:
|
if apply_link:
|
||||||
Fsim = model.likelihood.gp_link.transf(Fsim)
|
Fsim = self.likelihood.gp_link.transf(Fsim)
|
||||||
for fi in Fsim.T:
|
for fi in Fsim.T:
|
||||||
plots['posterior_samples'] = ax.plot(Xnew, fi[:,None], '#3300FF', linewidth=0.25)
|
plots['posterior_samples'] = ax.plot(Xnew, fi[:,None], '#3300FF', linewidth=0.25)
|
||||||
#ax.plot(Xnew, fi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
|
#ax.plot(Xnew, fi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
|
||||||
|
|
||||||
if samples_y: #NOTE not tested with fixed_inputs
|
if samples_y: #NOTE not tested with fixed_inputs
|
||||||
Ysim = model.posterior_samples(Xgrid, samples_y, Y_metadata=Y_metadata)
|
Ysim = self.posterior_samples(Xgrid, samples_y, Y_metadata=Y_metadata)
|
||||||
for yi in Ysim.T:
|
for yi in Ysim.T:
|
||||||
plots['posterior_samples_y'] = ax.scatter(Xnew, yi[:,None], s=5, c=Tango.colorsHex['darkBlue'], marker='o', alpha=0.5)
|
plots['posterior_samples_y'] = ax.scatter(Xnew, yi[:,None], s=5, c=Tango.colorsHex['darkBlue'], marker='o', alpha=0.5)
|
||||||
#ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
|
#ax.plot(Xnew, yi[:,None], marker='x', linestyle='--',color=Tango.colorsHex['darkBlue']) #TODO apply this line for discrete outputs.
|
||||||
|
|
||||||
|
|
||||||
#add error bars for uncertain (if input uncertainty is being modelled)
|
#add error bars for uncertain (if input uncertainty is being modelled)
|
||||||
if hasattr(model,"has_uncertain_inputs") and model.has_uncertain_inputs() and plot_uncertain_inputs:
|
if hasattr(self,"has_uncertain_inputs") and self.has_uncertain_inputs() and plot_uncertain_inputs:
|
||||||
if plot_raw:
|
if plot_raw:
|
||||||
#add error bars for uncertain (if input uncertainty is being modelled), for plot_f
|
#add error bars for uncertain (if input uncertainty is being modelled), for plot_f
|
||||||
#Hack to plot error bars on latent function, rather than on the data
|
#Hack to plot error bars on latent function, rather than on the data
|
||||||
vs = model.X.mean.values.copy()
|
vs = self.X.mean.values.copy()
|
||||||
for i,v in fixed_inputs:
|
for i,v in fixed_inputs:
|
||||||
vs[:,i] = v
|
vs[:,i] = v
|
||||||
m_X, _ = model._raw_predict(vs)
|
m_X, _ = self._raw_predict(vs)
|
||||||
if apply_link:
|
if apply_link:
|
||||||
m_X = model.likelihood.gp_link.transf(m_X)
|
m_X = self.likelihood.gp_link.transf(m_X)
|
||||||
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(),
|
plots['xerrorbar'] = ax.errorbar(X[which_data_rows, free_dims].flatten(), m_X[which_data_rows, which_data_ycols].flatten(),
|
||||||
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
|
xerr=2 * np.sqrt(X_variance[which_data_rows, free_dims].flatten()),
|
||||||
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
|
ecolor='k', fmt=None, elinewidth=.5, alpha=.5)
|
||||||
|
|
@ -243,9 +244,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#add inducing inputs (if a sparse model is used)
|
#add inducing inputs (if a sparse model is used)
|
||||||
if hasattr(model,"Z"):
|
if hasattr(self,"Z"):
|
||||||
#Zu = model.Z[:,free_dims] * model._Xscale[:,free_dims] + model._Xoffset[:,free_dims]
|
#Zu = self.Z[:,free_dims] * self._Xscale[:,free_dims] + self._Xoffset[:,free_dims]
|
||||||
if isinstance(model,SparseGPCoregionalizedRegression):
|
if isinstance(self,SparseGPCoregionalizedRegression):
|
||||||
Z = Z[Z[:,-1] == Y_metadata['output_index'],:]
|
Z = Z[Z[:,-1] == Y_metadata['output_index'],:]
|
||||||
Zu = Z[:,free_dims]
|
Zu = Z[:,free_dims]
|
||||||
z_height = ax.get_ylim()[0]
|
z_height = ax.get_ylim()[0]
|
||||||
|
|
@ -259,7 +260,7 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
#define the frame for plotting on
|
#define the frame for plotting on
|
||||||
resolution = resolution or 50
|
resolution = resolution or 50
|
||||||
Xnew, _, _, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)
|
Xnew, _, _, xmin, xmax = x_frame2D(X[:,free_dims], plot_limits, resolution)
|
||||||
Xgrid = np.empty((Xnew.shape[0],model.input_dim))
|
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
|
||||||
Xgrid[:,free_dims] = Xnew
|
Xgrid[:,free_dims] = Xnew
|
||||||
for i,v in fixed_inputs:
|
for i,v in fixed_inputs:
|
||||||
Xgrid[:,i] = v
|
Xgrid[:,i] = v
|
||||||
|
|
@ -267,15 +268,15 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
|
|
||||||
#predict on the frame and plot
|
#predict on the frame and plot
|
||||||
if plot_raw:
|
if plot_raw:
|
||||||
m, _ = model._raw_predict(Xgrid, **predict_kw)
|
m, _ = self._raw_predict(Xgrid, **predict_kw)
|
||||||
else:
|
else:
|
||||||
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
|
if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
|
||||||
extra_data = Xgrid[:,-1:].astype(np.int)
|
extra_data = Xgrid[:,-1:].astype(np.int)
|
||||||
if Y_metadata is None:
|
if Y_metadata is None:
|
||||||
Y_metadata = {'output_index': extra_data}
|
Y_metadata = {'output_index': extra_data}
|
||||||
else:
|
else:
|
||||||
Y_metadata['output_index'] = extra_data
|
Y_metadata['output_index'] = extra_data
|
||||||
m, v = model.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
|
m, v = self.predict(Xgrid, full_cov=False, Y_metadata=Y_metadata, **predict_kw)
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
m_d = m[:,d].reshape(resolution, resolution).T
|
m_d = m[:,d].reshape(resolution, resolution).T
|
||||||
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet)
|
plots['contour'] = ax.contour(x, y, m_d, levels, vmin=m.min(), vmax=m.max(), cmap=plt.cm.jet)
|
||||||
|
|
@ -290,9 +291,9 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
if samples:
|
if samples:
|
||||||
warnings.warn("Samples are rather difficult to plot for 2D inputs...")
|
warnings.warn("Samples are rather difficult to plot for 2D inputs...")
|
||||||
|
|
||||||
#add inducing inputs (if a sparse model is used)
|
#add inducing inputs (if a sparse self is used)
|
||||||
if hasattr(model,"Z"):
|
if hasattr(self,"Z"):
|
||||||
#Zu = model.Z[:,free_dims] * model._Xscale[:,free_dims] + model._Xoffset[:,free_dims]
|
#Zu = self.Z[:,free_dims] * self._Xscale[:,free_dims] + self._Xoffset[:,free_dims]
|
||||||
Zu = Z[:,free_dims]
|
Zu = Z[:,free_dims]
|
||||||
plots['inducing_inputs'] = ax.plot(Zu[:,0], Zu[:,1], 'wo')
|
plots['inducing_inputs'] = ax.plot(Zu[:,0], Zu[:,1], 'wo')
|
||||||
|
|
||||||
|
|
@ -300,20 +301,41 @@ def plot_fit(model, plot_limits=None, which_data_rows='all',
|
||||||
raise NotImplementedError("Cannot define a frame with more than two input dimensions")
|
raise NotImplementedError("Cannot define a frame with more than two input dimensions")
|
||||||
return plots
|
return plots
|
||||||
|
|
||||||
def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None,
|
def plot_density(self, levels=20, plot_limits=None,
|
||||||
fixed_inputs=[], plot_raw=False, edgecolor='none', facecolor='#3465a4',
|
fixed_inputs=[], plot_raw=False, edgecolor='none', facecolor='#3465a4',
|
||||||
predict_kw=None,Y_metadata=None,
|
predict_kw=None,Y_metadata=None,
|
||||||
apply_link=False, resolution=200, **patch_kwargs):
|
apply_link=False, resolution=200, **patch_kwargs):
|
||||||
#deal with optional arguments
|
"""
|
||||||
if ax is None:
|
Plot the posterior density of the GP.
|
||||||
fig = plt.figure(num=fignum)
|
- In one dimension, the function is plotted with a shaded gradient, visualizing the density of the posterior.
|
||||||
ax = fig.add_subplot(111)
|
- Only implemented for one dimension, for higher dimensions use `plot`.
|
||||||
|
|
||||||
if hasattr(model, 'has_uncertain_inputs') and model.has_uncertain_inputs():
|
:param levels: number of levels to plot in the density plot. This is a number between 1 and 100. 1 corresponds to the normal plot_fit.
|
||||||
X = model.X.mean
|
:type levels: int
|
||||||
|
:param plot_limits: The limits of the plot. If 1D [xmin,xmax], if 2D [[xmin,ymin],[xmax,ymax]]. Defaluts to data limits
|
||||||
|
:type plot_limits: np.array
|
||||||
|
:param fixed_inputs: a list of tuple [(i,v), (i,v)...], specifying that input index i should be set to value v.
|
||||||
|
:type fixed_inputs: a list of tuples
|
||||||
|
:param resolution: the number of intervals to sample the GP on. Defaults to 200 in 1D and 50 (a 50x50 grid) in 2D
|
||||||
|
:type resolution: int
|
||||||
|
:param edgecolor: color of line to plot [Tango.colorsHex['darkBlue']]
|
||||||
|
:type edgecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
||||||
|
:param facecolor: color of fill [Tango.colorsHex['lightBlue']]
|
||||||
|
:type facecolor: color either as Tango.colorsHex object or character ('r' is red, 'g' is green) as is standard in matplotlib
|
||||||
|
:param Y_metadata: additional data associated with Y which may be needed
|
||||||
|
:type Y_metadata: dict
|
||||||
|
:param apply_link: if there is a link function of the likelihood, plot the link(f*) rather than f*, when plotting posterior samples f
|
||||||
|
:type apply_link: boolean
|
||||||
|
:param resolution: resolution of interpolation (how many points to interpolate of the posterior).
|
||||||
|
:type resolution: int
|
||||||
|
:param: patch_kw: the keyword arguments for the patchcollection fill.
|
||||||
|
"""
|
||||||
|
#deal with optional arguments
|
||||||
|
if hasattr(self, 'has_uncertain_inputs') and self.has_uncertain_inputs():
|
||||||
|
X = self.X.mean
|
||||||
else:
|
else:
|
||||||
X = model.X
|
X = self.X
|
||||||
Y = model.Y
|
Y = self.Y
|
||||||
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
|
if sparse.issparse(Y): Y = Y.todense().view(np.ndarray)
|
||||||
|
|
||||||
if predict_kw is None:
|
if predict_kw is None:
|
||||||
|
|
@ -321,13 +343,13 @@ def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None,
|
||||||
|
|
||||||
#work out what the inputs are for plotting (1D or 2D)
|
#work out what the inputs are for plotting (1D or 2D)
|
||||||
fixed_dims = np.array([i for i,v in fixed_inputs])
|
fixed_dims = np.array([i for i,v in fixed_inputs])
|
||||||
free_dims = np.setdiff1d(np.arange(model.input_dim),fixed_dims)
|
free_dims = np.setdiff1d(np.arange(self.input_dim),fixed_dims)
|
||||||
plots = {}
|
plots = {}
|
||||||
#one dimensional plotting
|
#one dimensional plotting
|
||||||
if len(free_dims) == 1:
|
if len(free_dims) == 1:
|
||||||
#define the frame on which to plot
|
#define the frame on which to plot
|
||||||
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution)
|
Xnew, xmin, xmax = x_frame1D(X[:,free_dims], plot_limits=plot_limits, resolution=resolution)
|
||||||
Xgrid = np.empty((Xnew.shape[0],model.input_dim))
|
Xgrid = np.empty((Xnew.shape[0],self.input_dim))
|
||||||
Xgrid[:,free_dims] = Xnew
|
Xgrid[:,free_dims] = Xnew
|
||||||
for i,v in fixed_inputs:
|
for i,v in fixed_inputs:
|
||||||
Xgrid[:,i] = v
|
Xgrid[:,i] = v
|
||||||
|
|
@ -340,32 +362,32 @@ def plot_density(model, levels=20, plot_limits=None, fignum=None, ax=None,
|
||||||
from ...likelihoods import Gaussian
|
from ...likelihoods import Gaussian
|
||||||
lik = Gaussian(variance=0)
|
lik = Gaussian(variance=0)
|
||||||
else:
|
else:
|
||||||
if isinstance(model,GPCoregionalizedRegression) or isinstance(model,SparseGPCoregionalizedRegression):
|
if isinstance(self,GPCoregionalizedRegression) or isinstance(self,SparseGPCoregionalizedRegression):
|
||||||
extra_data = Xgrid[:,-1:].astype(np.int)
|
extra_data = Xgrid[:,-1:].astype(np.int)
|
||||||
if Y_metadata is None:
|
if Y_metadata is None:
|
||||||
Y_metadata = {'output_index': extra_data}
|
Y_metadata = {'output_index': extra_data}
|
||||||
else:
|
else:
|
||||||
Y_metadata['output_index'] = extra_data
|
Y_metadata['output_index'] = extra_data
|
||||||
lik = None
|
lik = None
|
||||||
percentiles = [i[:, 0] for i in model.predict_quantiles(Xgrid, percs, Y_metadata=Y_metadata, likelihood=lik, **predict_kw)]
|
percentiles = [i[:, 0] for i in self.predict_quantiles(Xgrid, percs, Y_metadata=Y_metadata, likelihood=lik, **predict_kw)]
|
||||||
if apply_link:
|
if apply_link:
|
||||||
percentiles = model.likelihood.gp_link.transf(percentiles)
|
percentiles = self.likelihood.gp_link.transf(percentiles)
|
||||||
|
|
||||||
patch_kwargs['facecolor'] = facecolor
|
patch_kwargs['facecolor'] = facecolor
|
||||||
patch_kwargs['edgecolor'] = edgecolor
|
patch_kwargs['edgecolor'] = edgecolor
|
||||||
plots['density'] = plot_gradient_fill(ax, Xgrid[:, 0], percentiles, **patch_kwargs)
|
plots['density'] = gradient_fill(Xgrid[:, 0], percentiles, **patch_kwargs)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Only 1D density plottable.')
|
raise NotImplementedError('Only 1D density plottable.')
|
||||||
return plots
|
return plots
|
||||||
|
|
||||||
def plot_fit_f(model, *args, **kwargs):
|
@wraps(plot_fit)
|
||||||
"""
|
def plot_fit_f(self, plot_limits=None, which_data_rows='all',
|
||||||
Plot the GP's view of the world, where the data is normalized and before applying a likelihood.
|
which_data_ycols='all', fixed_inputs=[],
|
||||||
|
levels=20, samples=0, fignum=None, ax=None, resolution=None,
|
||||||
All args and kwargs are passed on to models_plots.plot.
|
plot_raw=True,
|
||||||
"""
|
linecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'], Y_metadata=None, data_symbol='kx',
|
||||||
kwargs['plot_raw'] = True
|
apply_link=False, samples_y=0, plot_uncertain_inputs=True, predict_kw=None, plot_training_data=True):
|
||||||
plot_fit(model,*args, **kwargs)
|
return plot_fit(self, plot_limits, which_data_rows, which_data_ycols, fixed_inputs, levels, samples, fignum, ax, resolution, plot_raw, linecol, fillcol, Y_metadata, data_symbol, apply_link, samples_y, plot_uncertain_inputs, predict_kw, plot_training_data)
|
||||||
|
|
||||||
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False):
|
def fixed_inputs(model, non_fixed_inputs, fix_routine='median', as_list=True, X_all=False):
|
||||||
"""
|
"""
|
||||||
|
|
@ -465,7 +487,7 @@ def plot_errorbars_trainset(model, which_data_rows='all',
|
||||||
for d in which_data_ycols:
|
for d in which_data_ycols:
|
||||||
plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum, **kwargs )
|
plots['gperrors'] = gperrors(X, m[:, d], lower[:, d], upper[:, d], edgecol=linecol, ax=ax, fignum=fignum, **kwargs )
|
||||||
if plot_training_data:
|
if plot_training_data:
|
||||||
plots['dataplot'] = plot_data(model=model, which_data_rows=which_data_rows,
|
plots['dataplot'] = plot_data(self=model, which_data_rows=which_data_rows,
|
||||||
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
|
visible_dims=free_dims, data_symbol=data_symbol, mew=1.5, ax=ax, fignum=fignum)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,331 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# netpbmfile.py
|
|
||||||
|
|
||||||
# Copyright (c) 2011-2013, Christoph Gohlke
|
|
||||||
# Copyright (c) 2011-2013, The Regents of the University of California
|
|
||||||
# Produced at the Laboratory for Fluorescence Dynamics.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# Redistribution and use in source and binary forms, with or without
|
|
||||||
# modification, are permitted provided that the following conditions are met:
|
|
||||||
#
|
|
||||||
# * Redistributions of source code must retain the above copyright
|
|
||||||
# notice, this list of conditions and the following disclaimer.
|
|
||||||
# * Redistributions in binary form must reproduce the above copyright
|
|
||||||
# notice, this list of conditions and the following disclaimer in the
|
|
||||||
# documentation and/or other materials provided with the distribution.
|
|
||||||
# * Neither the name of the copyright holders nor the names of any
|
|
||||||
# contributors may be used to endorse or promote products derived
|
|
||||||
# from this software without specific prior written permission.
|
|
||||||
#
|
|
||||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
||||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
|
||||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
|
||||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
|
||||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
|
||||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
|
||||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
|
||||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
|
||||||
# POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
|
|
||||||
"""Read and write image data from respectively to Netpbm files.
|
|
||||||
|
|
||||||
This implementation follows the Netpbm format specifications at
|
|
||||||
http://netpbm.sourceforge.net/doc/. No gamma correction is performed.
|
|
||||||
|
|
||||||
The following image formats are supported: PBM (bi-level), PGM (grayscale),
|
|
||||||
PPM (color), PAM (arbitrary), XV thumbnail (RGB332, read-only).
|
|
||||||
|
|
||||||
:Author:
|
|
||||||
`Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`_
|
|
||||||
|
|
||||||
:Organization:
|
|
||||||
Laboratory for Fluorescence Dynamics, University of California, Irvine
|
|
||||||
|
|
||||||
:Version: 2013.01.18
|
|
||||||
|
|
||||||
Requirements
|
|
||||||
------------
|
|
||||||
* `CPython 2.7, 3.2 or 3.3 <http://www.python.org>`_
|
|
||||||
* `Numpy 1.7 <http://www.numpy.org>`_
|
|
||||||
* `Matplotlib 1.2 <http://www.matplotlib.org>`_ (optional for plotting)
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> im1 = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
|
|
||||||
>>> imsave('_tmp.pgm', im1)
|
|
||||||
>>> im2 = imread('_tmp.pgm')
|
|
||||||
>>> assert numpy.all(im1 == im2)
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import division, print_function
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import re
|
|
||||||
import math
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import numpy
|
|
||||||
|
|
||||||
__version__ = '2013.01.18'
|
|
||||||
__docformat__ = 'restructuredtext en'
|
|
||||||
__all__ = ['imread', 'imsave', 'NetpbmFile']
|
|
||||||
|
|
||||||
|
|
||||||
def imread(filename, *args, **kwargs):
|
|
||||||
"""Return image data from Netpbm file as numpy array.
|
|
||||||
|
|
||||||
`args` and `kwargs` are arguments to NetpbmFile.asarray().
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> image = imread('_tmp.pgm')
|
|
||||||
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
netpbm = NetpbmFile(filename)
|
|
||||||
image = netpbm.asarray()
|
|
||||||
finally:
|
|
||||||
netpbm.close()
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
def imsave(filename, data, maxval=None, pam=False):
|
|
||||||
"""Write image data to Netpbm file.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> image = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
|
|
||||||
>>> imsave('_tmp.pgm', image)
|
|
||||||
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
netpbm = NetpbmFile(data, maxval=maxval)
|
|
||||||
netpbm.write(filename, pam=pam)
|
|
||||||
finally:
|
|
||||||
netpbm.close()
|
|
||||||
|
|
||||||
|
|
||||||
class NetpbmFile(object):
|
|
||||||
"""Read and write Netpbm PAM, PBM, PGM, PPM, files."""
|
|
||||||
|
|
||||||
_types = {b'P1': b'BLACKANDWHITE', b'P2': b'GRAYSCALE', b'P3': b'RGB',
|
|
||||||
b'P4': b'BLACKANDWHITE', b'P5': b'GRAYSCALE', b'P6': b'RGB',
|
|
||||||
b'P7 332': b'RGB', b'P7': b'RGB_ALPHA'}
|
|
||||||
|
|
||||||
def __init__(self, arg=None, **kwargs):
|
|
||||||
"""Initialize instance from filename, open file, or numpy array."""
|
|
||||||
for attr in ('header', 'magicnum', 'width', 'height', 'maxval',
|
|
||||||
'depth', 'tupltypes', '_filename', '_fh', '_data'):
|
|
||||||
setattr(self, attr, None)
|
|
||||||
if arg is None:
|
|
||||||
self._fromdata([], **kwargs)
|
|
||||||
elif isinstance(arg, basestring):
|
|
||||||
self._fh = open(arg, 'rb')
|
|
||||||
self._filename = arg
|
|
||||||
self._fromfile(self._fh, **kwargs)
|
|
||||||
elif hasattr(arg, 'seek'):
|
|
||||||
self._fromfile(arg, **kwargs)
|
|
||||||
self._fh = arg
|
|
||||||
else:
|
|
||||||
self._fromdata(arg, **kwargs)
|
|
||||||
|
|
||||||
def asarray(self, copy=True, cache=False, **kwargs):
|
|
||||||
"""Return image data from file as numpy array."""
|
|
||||||
data = self._data
|
|
||||||
if data is None:
|
|
||||||
data = self._read_data(self._fh, **kwargs)
|
|
||||||
if cache:
|
|
||||||
self._data = data
|
|
||||||
else:
|
|
||||||
return data
|
|
||||||
return deepcopy(data) if copy else data
|
|
||||||
|
|
||||||
def write(self, arg, **kwargs):
|
|
||||||
"""Write instance to file."""
|
|
||||||
if hasattr(arg, 'seek'):
|
|
||||||
self._tofile(arg, **kwargs)
|
|
||||||
else:
|
|
||||||
with open(arg, 'wb') as fid:
|
|
||||||
self._tofile(fid, **kwargs)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""Close open file. Future asarray calls might fail."""
|
|
||||||
if self._filename and self._fh:
|
|
||||||
self._fh.close()
|
|
||||||
self._fh = None
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def _fromfile(self, fh):
|
|
||||||
"""Initialize instance from open file."""
|
|
||||||
fh.seek(0)
|
|
||||||
data = fh.read(4096)
|
|
||||||
if (len(data) < 7) or not (b'0' < data[1:2] < b'8'):
|
|
||||||
raise ValueError("Not a Netpbm file:\n%s" % data[:32])
|
|
||||||
try:
|
|
||||||
self._read_pam_header(data)
|
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
self._read_pnm_header(data)
|
|
||||||
except Exception:
|
|
||||||
raise ValueError("Not a Netpbm file:\n%s" % data[:32])
|
|
||||||
|
|
||||||
def _read_pam_header(self, data):
|
|
||||||
"""Read PAM header and initialize instance."""
|
|
||||||
regroups = re.search(
|
|
||||||
b"(^P7[\n\r]+(?:(?:[\n\r]+)|(?:#.*)|"
|
|
||||||
b"(HEIGHT\s+\d+)|(WIDTH\s+\d+)|(DEPTH\s+\d+)|(MAXVAL\s+\d+)|"
|
|
||||||
b"(?:TUPLTYPE\s+\w+))*ENDHDR\n)", data).groups()
|
|
||||||
self.header = regroups[0]
|
|
||||||
self.magicnum = b'P7'
|
|
||||||
for group in regroups[1:]:
|
|
||||||
key, value = group.split()
|
|
||||||
setattr(self, unicode(key).lower(), int(value))
|
|
||||||
matches = re.findall(b"(TUPLTYPE\s+\w+)", self.header)
|
|
||||||
self.tupltypes = [s.split(None, 1)[1] for s in matches]
|
|
||||||
|
|
||||||
def _read_pnm_header(self, data):
|
|
||||||
"""Read PNM header and initialize instance."""
|
|
||||||
bpm = data[1:2] in b"14"
|
|
||||||
regroups = re.search(b"".join((
|
|
||||||
b"(^(P[123456]|P7 332)\s+(?:#.*[\r\n])*",
|
|
||||||
b"\s*(\d+)\s+(?:#.*[\r\n])*",
|
|
||||||
b"\s*(\d+)\s+(?:#.*[\r\n])*" * (not bpm),
|
|
||||||
b"\s*(\d+)\s(?:\s*#.*[\r\n]\s)*)")), data).groups() + (1, ) * bpm
|
|
||||||
self.header = regroups[0]
|
|
||||||
self.magicnum = regroups[1]
|
|
||||||
self.width = int(regroups[2])
|
|
||||||
self.height = int(regroups[3])
|
|
||||||
self.maxval = int(regroups[4])
|
|
||||||
self.depth = 3 if self.magicnum in b"P3P6P7 332" else 1
|
|
||||||
self.tupltypes = [self._types[self.magicnum]]
|
|
||||||
|
|
||||||
def _read_data(self, fh, byteorder='>'):
|
|
||||||
"""Return image data from open file as numpy array."""
|
|
||||||
fh.seek(len(self.header))
|
|
||||||
data = fh.read()
|
|
||||||
dtype = 'u1' if self.maxval < 256 else byteorder + 'u2'
|
|
||||||
depth = 1 if self.magicnum == b"P7 332" else self.depth
|
|
||||||
shape = [-1, self.height, self.width, depth]
|
|
||||||
size = numpy.prod(shape[1:])
|
|
||||||
if self.magicnum in b"P1P2P3":
|
|
||||||
data = numpy.array(data.split(None, size)[:size], dtype)
|
|
||||||
data = data.reshape(shape)
|
|
||||||
elif self.maxval == 1:
|
|
||||||
shape[2] = int(math.ceil(self.width / 8))
|
|
||||||
data = numpy.frombuffer(data, dtype).reshape(shape)
|
|
||||||
data = numpy.unpackbits(data, axis=-2)[:, :, :self.width, :]
|
|
||||||
else:
|
|
||||||
data = numpy.frombuffer(data, dtype)
|
|
||||||
data = data[:size * (data.size // size)].reshape(shape)
|
|
||||||
if data.shape[0] < 2:
|
|
||||||
data = data.reshape(data.shape[1:])
|
|
||||||
if data.shape[-1] < 2:
|
|
||||||
data = data.reshape(data.shape[:-1])
|
|
||||||
if self.magicnum == b"P7 332":
|
|
||||||
rgb332 = numpy.array(list(numpy.ndindex(8, 8, 4)), numpy.uint8)
|
|
||||||
rgb332 *= [36, 36, 85]
|
|
||||||
data = numpy.take(rgb332, data, axis=0)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _fromdata(self, data, maxval=None):
|
|
||||||
"""Initialize instance from numpy array."""
|
|
||||||
data = numpy.array(data, ndmin=2, copy=True)
|
|
||||||
if data.dtype.kind not in "uib":
|
|
||||||
raise ValueError("not an integer type: %s" % data.dtype)
|
|
||||||
if data.dtype.kind == 'i' and numpy.min(data) < 0:
|
|
||||||
raise ValueError("data out of range: %i" % numpy.min(data))
|
|
||||||
if maxval is None:
|
|
||||||
maxval = numpy.max(data)
|
|
||||||
maxval = 255 if maxval < 256 else 65535
|
|
||||||
if maxval < 0 or maxval > 65535:
|
|
||||||
raise ValueError("data out of range: %i" % maxval)
|
|
||||||
data = data.astype('u1' if maxval < 256 else '>u2')
|
|
||||||
self._data = data
|
|
||||||
if data.ndim > 2 and data.shape[-1] in (3, 4):
|
|
||||||
self.depth = data.shape[-1]
|
|
||||||
self.width = data.shape[-2]
|
|
||||||
self.height = data.shape[-3]
|
|
||||||
self.magicnum = b'P7' if self.depth == 4 else b'P6'
|
|
||||||
else:
|
|
||||||
self.depth = 1
|
|
||||||
self.width = data.shape[-1]
|
|
||||||
self.height = data.shape[-2]
|
|
||||||
self.magicnum = b'P5' if maxval > 1 else b'P4'
|
|
||||||
self.maxval = maxval
|
|
||||||
self.tupltypes = [self._types[self.magicnum]]
|
|
||||||
self.header = self._header()
|
|
||||||
|
|
||||||
def _tofile(self, fh, pam=False):
|
|
||||||
"""Write Netbm file."""
|
|
||||||
fh.seek(0)
|
|
||||||
fh.write(self._header(pam))
|
|
||||||
data = self.asarray(copy=False)
|
|
||||||
if self.maxval == 1:
|
|
||||||
data = numpy.packbits(data, axis=-1)
|
|
||||||
data.tofile(fh)
|
|
||||||
|
|
||||||
def _header(self, pam=False):
|
|
||||||
"""Return file header as byte string."""
|
|
||||||
if pam or self.magicnum == b'P7':
|
|
||||||
header = "\n".join((
|
|
||||||
"P7",
|
|
||||||
"HEIGHT %i" % self.height,
|
|
||||||
"WIDTH %i" % self.width,
|
|
||||||
"DEPTH %i" % self.depth,
|
|
||||||
"MAXVAL %i" % self.maxval,
|
|
||||||
"\n".join("TUPLTYPE %s" % unicode(i) for i in self.tupltypes),
|
|
||||||
"ENDHDR\n"))
|
|
||||||
elif self.maxval == 1:
|
|
||||||
header = "P4 %i %i\n" % (self.width, self.height)
|
|
||||||
elif self.depth == 1:
|
|
||||||
header = "P5 %i %i %i\n" % (self.width, self.height, self.maxval)
|
|
||||||
else:
|
|
||||||
header = "P6 %i %i %i\n" % (self.width, self.height, self.maxval)
|
|
||||||
if sys.version_info[0] > 2:
|
|
||||||
header = bytes(header, 'ascii')
|
|
||||||
return header
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
"""Return information about instance."""
|
|
||||||
return unicode(self.header)
|
|
||||||
|
|
||||||
|
|
||||||
if sys.version_info[0] > 2:
|
|
||||||
basestring = str
|
|
||||||
unicode = lambda x: str(x, 'ascii')
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Show images specified on command line or all images in current directory
|
|
||||||
from glob import glob
|
|
||||||
from matplotlib import pyplot
|
|
||||||
files = sys.argv[1:] if len(sys.argv) > 1 else glob('*.p*m')
|
|
||||||
for fname in files:
|
|
||||||
try:
|
|
||||||
pam = NetpbmFile(fname)
|
|
||||||
img = pam.asarray(copy=False)
|
|
||||||
if False:
|
|
||||||
pam.write('_tmp.pgm.out', pam=True)
|
|
||||||
img2 = imread('_tmp.pgm.out')
|
|
||||||
assert numpy.all(img == img2)
|
|
||||||
imsave('_tmp.pgm.out', img)
|
|
||||||
img2 = imread('_tmp.pgm.out')
|
|
||||||
assert numpy.all(img == img2)
|
|
||||||
pam.close()
|
|
||||||
except ValueError as e:
|
|
||||||
print(fname, e)
|
|
||||||
continue
|
|
||||||
_shape = img.shape
|
|
||||||
if img.ndim > 3 or (img.ndim > 2 and img.shape[-1] not in (3, 4)):
|
|
||||||
img = img[0]
|
|
||||||
cmap = 'gray' if pam.maxval > 1 else 'binary'
|
|
||||||
pyplot.imshow(img, cmap, interpolation='nearest')
|
|
||||||
pyplot.title("%s %s %s %s" % (fname, unicode(pam.magicnum),
|
|
||||||
_shape, img.dtype))
|
|
||||||
pyplot.show()
|
|
||||||
2
setup.py
2
setup.py
|
|
@ -18,7 +18,7 @@
|
||||||
# this list of conditions and the following disclaimer in the documentation
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
# and/or other materials provided with the distribution.
|
# and/or other materials provided with the distribution.
|
||||||
#
|
#
|
||||||
# * Neither the name of paramax nor the names of its
|
# * Neither the name of GPy nor the names of its
|
||||||
# contributors may be used to endorse or promote products derived from
|
# contributors may be used to endorse or promote products derived from
|
||||||
# this software without specific prior written permission.
|
# this software without specific prior written permission.
|
||||||
#
|
#
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@
|
||||||
# this list of conditions and the following disclaimer in the documentation
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
# and/or other materials provided with the distribution.
|
# and/or other materials provided with the distribution.
|
||||||
#
|
#
|
||||||
# * Neither the name of paramax nor the names of its
|
# * Neither the name of GPy nor the names of its
|
||||||
# contributors may be used to endorse or promote products derived from
|
# contributors may be used to endorse or promote products derived from
|
||||||
# this software without specific prior written permission.
|
# this software without specific prior written permission.
|
||||||
#
|
#
|
||||||
|
|
@ -32,8 +32,8 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('svg')
|
matplotlib.use('agg')
|
||||||
|
|
||||||
import nose
|
import nose
|
||||||
nose.main('GPy', defaultTest='GPy/testing')
|
nose.main('GPy', defaultTest='GPy/testing/plotting_tests.py')
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue