Files relocated

This commit is contained in:
Ricardo 2014-01-28 13:44:10 +00:00
parent 73546f2408
commit 7782f62885
10 changed files with 0 additions and 1727 deletions

View file

@ -1,135 +0,0 @@
# #Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import Tango
import pylab as pb
import numpy as np
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],axes=None,**kwargs):
if axes is None:
axes = pb.gca()
mu = mu.flatten()
x = x.flatten()
lower = lower.flatten()
upper = upper.flatten()
#here's the mean
axes.plot(x,mu,color=edgecol,linewidth=2)
#here's the box
kwargs['linewidth']=0.5
if not 'alpha' in kwargs.keys():
kwargs['alpha'] = 0.3
axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)
#this is the edge:
axes.plot(x,upper,color=edgecol,linewidth=0.2)
axes.plot(x,lower,color=edgecol,linewidth=0.2)
def removeRightTicks(ax=None):
ax = ax or pb.gca()
for i, line in enumerate(ax.get_yticklines()):
if i%2 == 1: # odd indices
line.set_visible(False)
def removeUpperTicks(ax=None):
ax = ax or pb.gca()
for i, line in enumerate(ax.get_xticklines()):
if i%2 == 1: # odd indices
line.set_visible(False)
def fewerXticks(ax=None,divideby=2):
ax = ax or pb.gca()
ax.set_xticks(ax.get_xticks()[::divideby])
def align_subplots(N,M,xlim=None, ylim=None):
"""make all of the subplots have the same limits, turn off unnecessary ticks"""
#find sensible xlim,ylim
if xlim is None:
xlim = [np.inf,-np.inf]
for i in range(N*M):
pb.subplot(N,M,i+1)
xlim[0] = min(xlim[0],pb.xlim()[0])
xlim[1] = max(xlim[1],pb.xlim()[1])
if ylim is None:
ylim = [np.inf,-np.inf]
for i in range(N*M):
pb.subplot(N,M,i+1)
ylim[0] = min(ylim[0],pb.ylim()[0])
ylim[1] = max(ylim[1],pb.ylim()[1])
for i in range(N*M):
pb.subplot(N,M,i+1)
pb.xlim(xlim)
pb.ylim(ylim)
if (i)%M:
pb.yticks([])
else:
removeRightTicks()
if i<(M*(N-1)):
pb.xticks([])
else:
removeUpperTicks()
def align_subplot_array(axes,xlim=None, ylim=None):
"""make all of the axes in the array hae the same limits, turn off unnecessary ticks
use pb.subplots() to get an array of axes
"""
#find sensible xlim,ylim
if xlim is None:
xlim = [np.inf,-np.inf]
for ax in axes.flatten():
xlim[0] = min(xlim[0],ax.get_xlim()[0])
xlim[1] = max(xlim[1],ax.get_xlim()[1])
if ylim is None:
ylim = [np.inf,-np.inf]
for ax in axes.flatten():
ylim[0] = min(ylim[0],ax.get_ylim()[0])
ylim[1] = max(ylim[1],ax.get_ylim()[1])
N,M = axes.shape
for i,ax in enumerate(axes.flatten()):
ax.set_xlim(xlim)
ax.set_ylim(ylim)
if (i)%M:
ax.set_yticks([])
else:
removeRightTicks(ax)
if i<(M*(N-1)):
ax.set_xticks([])
else:
removeUpperTicks(ax)
def x_frame1D(X,plot_limits=None,resolution=None):
"""
Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits
"""
assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs"
if plot_limits is None:
xmin,xmax = X.min(0),X.max(0)
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
elif len(plot_limits)==2:
xmin, xmax = plot_limits
else:
raise ValueError, "Bad limits for plotting"
Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None]
return Xnew, xmin, xmax
def x_frame2D(X,plot_limits=None,resolution=None):
"""
Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits
"""
assert X.shape[1] ==2, "x_frame2D is defined for two-dimensional inputs"
if plot_limits is None:
xmin,xmax = X.min(0),X.max(0)
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
elif len(plot_limits)==2:
xmin, xmax = plot_limits
else:
raise ValueError, "Bad limits for plotting"
resolution = resolution or 50
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
return Xnew, xx, yy, xmin, xmax

View file

@ -1,166 +0,0 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import matplotlib as mpl
import pylab as pb
import sys
#sys.path.append('/home/james/mlprojects/sitran_cluster/')
#from switch_pylab_backend import *
#this stuff isn;t really Tango related: maybe it could be moved out? TODO
def removeRightTicks(ax=None):
ax = ax or pb.gca()
for i, line in enumerate(ax.get_yticklines()):
if i%2 == 1: # odd indices
line.set_visible(False)
def removeUpperTicks(ax=None):
ax = ax or pb.gca()
for i, line in enumerate(ax.get_xticklines()):
if i%2 == 1: # odd indices
line.set_visible(False)
def fewerXticks(ax=None,divideby=2):
ax = ax or pb.gca()
ax.set_xticks(ax.get_xticks()[::divideby])
colorsHex = {\
"Aluminium6":"#2e3436",\
"Aluminium5":"#555753",\
"Aluminium4":"#888a85",\
"Aluminium3":"#babdb6",\
"Aluminium2":"#d3d7cf",\
"Aluminium1":"#eeeeec",\
"lightPurple":"#ad7fa8",\
"mediumPurple":"#75507b",\
"darkPurple":"#5c3566",\
"lightBlue":"#729fcf",\
"mediumBlue":"#3465a4",\
"darkBlue": "#204a87",\
"lightGreen":"#8ae234",\
"mediumGreen":"#73d216",\
"darkGreen":"#4e9a06",\
"lightChocolate":"#e9b96e",\
"mediumChocolate":"#c17d11",\
"darkChocolate":"#8f5902",\
"lightRed":"#ef2929",\
"mediumRed":"#cc0000",\
"darkRed":"#a40000",\
"lightOrange":"#fcaf3e",\
"mediumOrange":"#f57900",\
"darkOrange":"#ce5c00",\
"lightButter":"#fce94f",\
"mediumButter":"#edd400",\
"darkButter":"#c4a000"}
darkList = [colorsHex['darkBlue'],colorsHex['darkRed'],colorsHex['darkGreen'], colorsHex['darkOrange'], colorsHex['darkButter'], colorsHex['darkPurple'], colorsHex['darkChocolate'], colorsHex['Aluminium6']]
mediumList = [colorsHex['mediumBlue'], colorsHex['mediumRed'],colorsHex['mediumGreen'], colorsHex['mediumOrange'], colorsHex['mediumButter'], colorsHex['mediumPurple'], colorsHex['mediumChocolate'], colorsHex['Aluminium5']]
lightList = [colorsHex['lightBlue'], colorsHex['lightRed'],colorsHex['lightGreen'], colorsHex['lightOrange'], colorsHex['lightButter'], colorsHex['lightPurple'], colorsHex['lightChocolate'], colorsHex['Aluminium4']]
def currentDark():
return darkList[-1]
def currentMedium():
return mediumList[-1]
def currentLight():
return lightList[-1]
def nextDark():
darkList.append(darkList.pop(0))
return darkList[-1]
def nextMedium():
mediumList.append(mediumList.pop(0))
return mediumList[-1]
def nextLight():
lightList.append(lightList.pop(0))
return lightList[-1]
def reset():
while not darkList[0]==colorsHex['darkBlue']:
darkList.append(darkList.pop(0))
while not mediumList[0]==colorsHex['mediumBlue']:
mediumList.append(mediumList.pop(0))
while not lightList[0]==colorsHex['lightBlue']:
lightList.append(lightList.pop(0))
def setLightFigures():
mpl.rcParams['axes.edgecolor']=colorsHex['Aluminium6']
mpl.rcParams['axes.facecolor']=colorsHex['Aluminium2']
mpl.rcParams['axes.labelcolor']=colorsHex['Aluminium6']
mpl.rcParams['figure.edgecolor']=colorsHex['Aluminium6']
mpl.rcParams['figure.facecolor']=colorsHex['Aluminium2']
mpl.rcParams['grid.color']=colorsHex['Aluminium6']
mpl.rcParams['savefig.edgecolor']=colorsHex['Aluminium2']
mpl.rcParams['savefig.facecolor']=colorsHex['Aluminium2']
mpl.rcParams['text.color']=colorsHex['Aluminium6']
mpl.rcParams['xtick.color']=colorsHex['Aluminium6']
mpl.rcParams['ytick.color']=colorsHex['Aluminium6']
def setDarkFigures():
mpl.rcParams['axes.edgecolor']=colorsHex['Aluminium2']
mpl.rcParams['axes.facecolor']=colorsHex['Aluminium6']
mpl.rcParams['axes.labelcolor']=colorsHex['Aluminium2']
mpl.rcParams['figure.edgecolor']=colorsHex['Aluminium2']
mpl.rcParams['figure.facecolor']=colorsHex['Aluminium6']
mpl.rcParams['grid.color']=colorsHex['Aluminium2']
mpl.rcParams['savefig.edgecolor']=colorsHex['Aluminium6']
mpl.rcParams['savefig.facecolor']=colorsHex['Aluminium6']
mpl.rcParams['text.color']=colorsHex['Aluminium2']
mpl.rcParams['xtick.color']=colorsHex['Aluminium2']
mpl.rcParams['ytick.color']=colorsHex['Aluminium2']
def hex2rgb(hexcolor):
hexcolor = [hexcolor[1+2*i:1+2*(i+1)] for i in range(3)]
r,g,b = [int(n,16) for n in hexcolor]
return (r,g,b)
colorsRGB = dict([(k,hex2rgb(i)) for k,i in colorsHex.items()])
cdict_RB = {'red' :((0.,colorsRGB['mediumRed'][0]/256.,colorsRGB['mediumRed'][0]/256.),
(.5,colorsRGB['mediumPurple'][0]/256.,colorsRGB['mediumPurple'][0]/256.),
(1.,colorsRGB['mediumBlue'][0]/256.,colorsRGB['mediumBlue'][0]/256.)),
'green':((0.,colorsRGB['mediumRed'][1]/256.,colorsRGB['mediumRed'][1]/256.),
(.5,colorsRGB['mediumPurple'][1]/256.,colorsRGB['mediumPurple'][1]/256.),
(1.,colorsRGB['mediumBlue'][1]/256.,colorsRGB['mediumBlue'][1]/256.)),
'blue':((0.,colorsRGB['mediumRed'][2]/256.,colorsRGB['mediumRed'][2]/256.),
(.5,colorsRGB['mediumPurple'][2]/256.,colorsRGB['mediumPurple'][2]/256.),
(1.,colorsRGB['mediumBlue'][2]/256.,colorsRGB['mediumBlue'][2]/256.))}
cdict_BGR = {'red' :((0.,colorsRGB['mediumBlue'][0]/256.,colorsRGB['mediumBlue'][0]/256.),
(.5,colorsRGB['mediumGreen'][0]/256.,colorsRGB['mediumGreen'][0]/256.),
(1.,colorsRGB['mediumRed'][0]/256.,colorsRGB['mediumRed'][0]/256.)),
'green':((0.,colorsRGB['mediumBlue'][1]/256.,colorsRGB['mediumBlue'][1]/256.),
(.5,colorsRGB['mediumGreen'][1]/256.,colorsRGB['mediumGreen'][1]/256.),
(1.,colorsRGB['mediumRed'][1]/256.,colorsRGB['mediumRed'][1]/256.)),
'blue':((0.,colorsRGB['mediumBlue'][2]/256.,colorsRGB['mediumBlue'][2]/256.),
(.5,colorsRGB['mediumGreen'][2]/256.,colorsRGB['mediumGreen'][2]/256.),
(1.,colorsRGB['mediumRed'][2]/256.,colorsRGB['mediumRed'][2]/256.))}
cdict_Alu = {'red' :((0./5,colorsRGB['Aluminium1'][0]/256.,colorsRGB['Aluminium1'][0]/256.),
(1./5,colorsRGB['Aluminium2'][0]/256.,colorsRGB['Aluminium2'][0]/256.),
(2./5,colorsRGB['Aluminium3'][0]/256.,colorsRGB['Aluminium3'][0]/256.),
(3./5,colorsRGB['Aluminium4'][0]/256.,colorsRGB['Aluminium4'][0]/256.),
(4./5,colorsRGB['Aluminium5'][0]/256.,colorsRGB['Aluminium5'][0]/256.),
(5./5,colorsRGB['Aluminium6'][0]/256.,colorsRGB['Aluminium6'][0]/256.)),
'green' :((0./5,colorsRGB['Aluminium1'][1]/256.,colorsRGB['Aluminium1'][1]/256.),
(1./5,colorsRGB['Aluminium2'][1]/256.,colorsRGB['Aluminium2'][1]/256.),
(2./5,colorsRGB['Aluminium3'][1]/256.,colorsRGB['Aluminium3'][1]/256.),
(3./5,colorsRGB['Aluminium4'][1]/256.,colorsRGB['Aluminium4'][1]/256.),
(4./5,colorsRGB['Aluminium5'][1]/256.,colorsRGB['Aluminium5'][1]/256.),
(5./5,colorsRGB['Aluminium6'][1]/256.,colorsRGB['Aluminium6'][1]/256.)),
'blue' :((0./5,colorsRGB['Aluminium1'][2]/256.,colorsRGB['Aluminium1'][2]/256.),
(1./5,colorsRGB['Aluminium2'][2]/256.,colorsRGB['Aluminium2'][2]/256.),
(2./5,colorsRGB['Aluminium3'][2]/256.,colorsRGB['Aluminium3'][2]/256.),
(3./5,colorsRGB['Aluminium4'][2]/256.,colorsRGB['Aluminium4'][2]/256.),
(4./5,colorsRGB['Aluminium5'][2]/256.,colorsRGB['Aluminium5'][2]/256.),
(5./5,colorsRGB['Aluminium6'][2]/256.,colorsRGB['Aluminium6'][2]/256.))}
# cmap_Alu = mpl.colors.LinearSegmentedColormap('TangoAluminium',cdict_Alu,256)
# cmap_BGR = mpl.colors.LinearSegmentedColormap('TangoRedBlue',cdict_BGR,256)
# cmap_RB = mpl.colors.LinearSegmentedColormap('TangoRedBlue',cdict_RB,256)
if __name__=='__main__':
import pylab as pb
pb.figure()
pb.pcolor(pb.rand(10,10),cmap=cmap_RB)
pb.colorbar()
pb.show()

View file

@ -1 +0,0 @@
import controllers

View file

@ -1 +0,0 @@
import axis_event_controller, imshow_controller

View file

@ -1,142 +0,0 @@
'''
Created on 24 Jul 2013
@author: maxz
'''
import numpy
class AxisEventController(object):
def __init__(self, ax):
self.ax = ax
self.activate()
def deactivate(self):
for cb_class in self.ax.callbacks.callbacks.values():
for cb_num in cb_class.keys():
self.ax.callbacks.disconnect(cb_num)
def activate(self):
self.ax.callbacks.connect('xlim_changed', self.xlim_changed)
self.ax.callbacks.connect('ylim_changed', self.ylim_changed)
def xlim_changed(self, ax):
pass
def ylim_changed(self, ax):
pass
class AxisChangedController(AxisEventController):
'''
Buffered control of axis limit changes
'''
_changing = False
def __init__(self, ax, update_lim=None):
'''
Constructor
'''
super(AxisChangedController, self).__init__(ax)
self._lim_ratio_threshold = update_lim or .8
self._x_lim = self.ax.get_xlim()
self._y_lim = self.ax.get_ylim()
def update(self, ax):
pass
def xlim_changed(self, ax):
super(AxisChangedController, self).xlim_changed(ax)
if not self._changing and self.lim_changed(ax.get_xlim(), self._x_lim):
self._changing = True
self._x_lim = ax.get_xlim()
self.update(ax)
self._changing = False
def ylim_changed(self, ax):
super(AxisChangedController, self).ylim_changed(ax)
if not self._changing and self.lim_changed(ax.get_ylim(), self._y_lim):
self._changing = True
self._y_lim = ax.get_ylim()
self.update(ax)
self._changing = False
def extent(self, lim):
return numpy.subtract(*lim)
def lim_changed(self, axlim, savedlim):
axextent = self.extent(axlim)
extent = self.extent(savedlim)
lim_changed = ((axextent / extent) < self._lim_ratio_threshold ** 2
or (extent / axextent) < self._lim_ratio_threshold ** 2
or ((1 - (self.extent((axlim[0], savedlim[0])) / self.extent((savedlim[0], axlim[1]))))
< self._lim_ratio_threshold)
or ((1 - (self.extent((savedlim[0], axlim[0])) / self.extent((axlim[0], savedlim[1]))))
< self._lim_ratio_threshold)
)
return lim_changed
def _buffer_lim(self, lim):
# buffer_size = 1 - self._lim_ratio_threshold
# extent = self.extent(lim)
return lim
class BufferedAxisChangedController(AxisChangedController):
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=None, **kwargs):
"""
:param plot_function:
function to use for creating image for plotting (return ndarray-like)
plot_function gets called with (2D!) Xtest grid if replotting required
:type plot_function: function
:param plot_limits:
beginning plot limits [xmin, ymin, xmax, ymax]
:param kwargs: additional kwargs are for pyplot.imshow(**kwargs)
"""
super(BufferedAxisChangedController, self).__init__(ax, update_lim=update_lim)
self.plot_function = plot_function
xmin, xmax = self._x_lim # self._compute_buffered(*self._x_lim)
ymin, ymax = self._y_lim # self._compute_buffered(*self._y_lim)
self.resolution = resolution
self._not_init = False
self.view = self._init_view(self.ax, self.recompute_X(), xmin, xmax, ymin, ymax, **kwargs)
self._not_init = True
def update(self, ax):
super(BufferedAxisChangedController, self).update(ax)
if self._not_init:
xmin, xmax = self._compute_buffered(*self._x_lim)
ymin, ymax = self._compute_buffered(*self._y_lim)
self.update_view(self.view, self.recompute_X(), xmin, xmax, ymin, ymax)
def _init_view(self, ax, X, xmin, xmax, ymin, ymax):
raise NotImplementedError('return view for this controller')
def update_view(self, view, X, xmin, xmax, ymin, ymax):
raise NotImplementedError('update view given in here')
def get_grid(self):
xmin, xmax = self._compute_buffered(*self._x_lim)
ymin, ymax = self._compute_buffered(*self._y_lim)
x, y = numpy.mgrid[xmin:xmax:1j * self.resolution, ymin:ymax:1j * self.resolution]
return numpy.hstack((x.flatten()[:, None], y.flatten()[:, None]))
def recompute_X(self):
X = self.plot_function(self.get_grid())
if isinstance(X, (tuple, list)):
for x in X:
x.shape = [self.resolution, self.resolution]
x[:, :] = x.T[::-1, :]
return X
return X.reshape(self.resolution, self.resolution).T[::-1, :]
def _compute_buffered(self, mi, ma):
buffersize = self._buffersize()
size = ma - mi
return mi - (buffersize * size), ma + (buffersize * size)
def _buffersize(self):
try:
buffersize = 1. - self._lim_ratio_threshold
except:
buffersize = .4
return buffersize

View file

@ -1,71 +0,0 @@
'''
Created on 24 Jul 2013
@author: maxz
'''
from GPy.util.latent_space_visualizations.controllers.axis_event_controller import BufferedAxisChangedController
import itertools
import numpy
class ImshowController(BufferedAxisChangedController):
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=.5, **kwargs):
"""
:param plot_function:
function to use for creating image for plotting (return ndarray-like)
plot_function gets called with (2D!) Xtest grid if replotting required
:type plot_function: function
:param plot_limits:
beginning plot limits [xmin, ymin, xmax, ymax]
:param kwargs: additional kwargs are for pyplot.imshow(**kwargs)
"""
super(ImshowController, self).__init__(ax, plot_function, plot_limits, resolution, update_lim, **kwargs)
def _init_view(self, ax, X, xmin, xmax, ymin, ymax, **kwargs):
return ax.imshow(X, extent=(xmin, xmax,
ymin, ymax),
vmin=X.min(),
vmax=X.max(),
**kwargs)
def update_view(self, view, X, xmin, xmax, ymin, ymax):
view.set_data(X)
view.set_extent((xmin, xmax, ymin, ymax))
class ImAnnotateController(ImshowController):
def __init__(self, ax, plot_function, plot_limits, resolution=20, update_lim=.99, **kwargs):
"""
:param plot_function:
function to use for creating image for plotting (return ndarray-like)
plot_function gets called with (2D!) Xtest grid if replotting required
:type plot_function: function
:param plot_limits:
beginning plot limits [xmin, ymin, xmax, ymax]
:param text_props: kwargs for pyplot.text(**text_props)
:param kwargs: additional kwargs are for pyplot.imshow(**kwargs)
"""
super(ImAnnotateController, self).__init__(ax, plot_function, plot_limits, resolution, update_lim, **kwargs)
def _init_view(self, ax, X, xmin, xmax, ymin, ymax, text_props={}, **kwargs):
view = [super(ImAnnotateController, self)._init_view(ax, X[0], xmin, xmax, ymin, ymax, **kwargs)]
xoffset, yoffset = self._offsets(xmin, xmax, ymin, ymax)
xlin = numpy.linspace(xmin, xmax, self.resolution, endpoint=False)
ylin = numpy.linspace(ymin, ymax, self.resolution, endpoint=False)
for [i, x], [j, y] in itertools.product(enumerate(xlin), enumerate(ylin[::-1])):
view.append(ax.text(x + xoffset, y + yoffset, "{}".format(X[1][j, i]), ha='center', va='center', **text_props))
return view
def update_view(self, view, X, xmin, xmax, ymin, ymax):
super(ImAnnotateController, self).update_view(view[0], X[0], xmin, xmax, ymin, ymax)
xoffset, yoffset = self._offsets(xmin, xmax, ymin, ymax)
xlin = numpy.linspace(xmin, xmax, self.resolution, endpoint=False)
ylin = numpy.linspace(ymin, ymax, self.resolution, endpoint=False)
for [[i, x], [j, y]], text in itertools.izip(itertools.product(enumerate(xlin), enumerate(ylin[::-1])), view[1:]):
text.set_x(x + xoffset)
text.set_y(y + yoffset)
text.set_text("{}".format(X[1][j, i]))
return view
def _offsets(self, xmin, xmax, ymin, ymax):
return (xmax - xmin) / (2 * self.resolution), (ymax - ymin) / (2 * self.resolution)

View file

@ -1,161 +0,0 @@
import numpy as np
import pylab as pb
import matplotlib.patches as patches
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
#from matplotlib import cm
import shapefile
import re
pb.ion()
def plot(shape_records,facecolor='w',edgecolor='k',linewidths=.5, ax=None,xlims=None,ylims=None):
"""
Plot the geometry of a shapefile
:param shape_records: geometry and attributes list
:type shape_records: ShapeRecord object (output of a shapeRecords() method)
:param facecolor: color to be used to fill in polygons
:param edgecolor: color to be used for lines
:param ax: axes to plot on.
:type ax: axes handle
"""
#Axes handle
if ax is None:
fig = pb.figure()
ax = fig.add_subplot(111)
#Iterate over shape_records
for srec in shape_records:
points = np.vstack(srec.shape.points)
sparts = srec.shape.parts
par = list(sparts) + [points.shape[0]]
polygs = []
for pj in xrange(len(sparts)):
polygs.append(Polygon(points[par[pj]:par[pj+1]]))
ax.add_collection(PatchCollection(polygs,facecolor=facecolor,edgecolor=edgecolor, linewidths=linewidths))
#Plot limits
_box = np.vstack([srec.shape.bbox for srec in shape_records])
minx,miny = np.min(_box[:,:2],0)
maxx,maxy = np.max(_box[:,2:],0)
if xlims is not None:
minx,maxx = xlims
if ylims is not None:
miny,maxy = ylims
ax.set_xlim(minx,maxx)
ax.set_ylim(miny,maxy)
def string_match(sf,regex,field=2):
"""
Return the geometry and attributes of a shapefile whose fields match a regular expression given
:param sf: shapefile
:type sf: shapefile object
:regex: regular expression to match
:type regex: string
:field: field number to be matched with the regex
:type field: integer
"""
index = []
shape_records = []
for rec in enumerate(sf.shapeRecords()):
m = re.search(regex,rec[1].record[field])
if m is not None:
index.append(rec[0])
shape_records.append(rec[1])
return index,shape_records
def bbox_match(sf,bbox,inside_only=True):
"""
Return the geometry and attributes of a shapefile that lie within (or intersect) a bounding box
:param sf: shapefile
:type sf: shapefile object
:param bbox: bounding box
:type bbox: list of floats [x_min,y_min,x_max,y_max]
:inside_only: True if the objects returned are those that lie within the bbox and False if the objects returned are any that intersect the bbox
:type inside_only: Boolean
"""
A,B,C,D = bbox
index = []
shape_records = []
for rec in enumerate(sf.shapeRecords()):
a,b,c,d = rec[1].shape.bbox
if inside_only:
if A <= a and B <= b and C >= c and D >= d:
index.append(rec[0])
shape_records.append(rec[1])
else:
cond1 = A <= a and B <= b and C >= a and D >= b
cond2 = A <= c and B <= d and C >= c and D >= d
cond3 = A <= a and D >= d and C >= a and B <= d
cond4 = A <= c and D >= b and C >= c and B <= b
cond5 = a <= C and b <= B and d >= D
cond6 = c <= A and b <= B and d >= D
cond7 = d <= B and a <= A and c >= C
cond8 = b <= D and a <= A and c >= C
if cond1 or cond2 or cond3 or cond4 or cond5 or cond6 or cond7 or cond8:
index.append(rec[0])
shape_records.append(rec[1])
return index,shape_records
def plot_bbox(sf,bbox,inside_only=True):
"""
Plot the geometry of a shapefile within a bbox
:param sf: shapefile
:type sf: shapefile object
:param bbox: bounding box
:type bbox: list of floats [x_min,y_min,x_max,y_max]
:inside_only: True if the objects returned are those that lie within the bbox and False if the objects returned are any that intersect the bbox
:type inside_only: Boolean
"""
index,shape_records = bbox_match(sf,bbox,inside_only)
A,B,C,D = bbox
plot(shape_records,xlims=[bbox[0],bbox[2]],ylims=[bbox[1],bbox[3]])
def plot_string_match(sf,regex,field):
"""
Plot the geometry of a shapefile whose fields match a regular expression given
:param sf: shapefile
:type sf: shapefile object
:regex: regular expression to match
:type regex: string
:field: field number to be matched with the regex
:type field: integer
"""
index,shape_records = string_match(sf,regex,field)
plot(shape_records)
def new_shape_string(sf,name,regex,field=2,type=shapefile.POINT):
newshp = shapefile.Writer(shapeType = sf.shapeType)
newshp.autoBalance = 1
index,shape_records = string_match(sf,regex,field)
_fi = [sf.fields[j] for j in index]
for f in _fi:
newshp.field(name=f[0],fieldType=f[1],size=f[2],decimal=f[3])
_shre = shape_records
for sr in _shre:
_points = []
_parts = []
for point in sr.shape.points:
_points.append(point)
_parts.append(_points)
newshp.line(parts=_parts)
newshp.records.append(sr.record)
print len(sr.record)
newshp.save(name)
print index

View file

@ -1,331 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# netpbmfile.py
# Copyright (c) 2011-2013, Christoph Gohlke
# Copyright (c) 2011-2013, The Regents of the University of California
# Produced at the Laboratory for Fluorescence Dynamics.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the copyright holders nor the names of any
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""Read and write image data from respectively to Netpbm files.
This implementation follows the Netpbm format specifications at
http://netpbm.sourceforge.net/doc/. No gamma correction is performed.
The following image formats are supported: PBM (bi-level), PGM (grayscale),
PPM (color), PAM (arbitrary), XV thumbnail (RGB332, read-only).
:Author:
`Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`_
:Organization:
Laboratory for Fluorescence Dynamics, University of California, Irvine
:Version: 2013.01.18
Requirements
------------
* `CPython 2.7, 3.2 or 3.3 <http://www.python.org>`_
* `Numpy 1.7 <http://www.numpy.org>`_
* `Matplotlib 1.2 <http://www.matplotlib.org>`_ (optional for plotting)
Examples
--------
>>> im1 = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
>>> imsave('_tmp.pgm', im1)
>>> im2 = imread('_tmp.pgm')
>>> assert numpy.all(im1 == im2)
"""
from __future__ import division, print_function
import sys
import re
import math
from copy import deepcopy
import numpy
__version__ = '2013.01.18'
__docformat__ = 'restructuredtext en'
__all__ = ['imread', 'imsave', 'NetpbmFile']
def imread(filename, *args, **kwargs):
"""Return image data from Netpbm file as numpy array.
`args` and `kwargs` are arguments to NetpbmFile.asarray().
Examples
--------
>>> image = imread('_tmp.pgm')
"""
try:
netpbm = NetpbmFile(filename)
image = netpbm.asarray()
finally:
netpbm.close()
return image
def imsave(filename, data, maxval=None, pam=False):
"""Write image data to Netpbm file.
Examples
--------
>>> image = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
>>> imsave('_tmp.pgm', image)
"""
try:
netpbm = NetpbmFile(data, maxval=maxval)
netpbm.write(filename, pam=pam)
finally:
netpbm.close()
class NetpbmFile(object):
"""Read and write Netpbm PAM, PBM, PGM, PPM, files."""
_types = {b'P1': b'BLACKANDWHITE', b'P2': b'GRAYSCALE', b'P3': b'RGB',
b'P4': b'BLACKANDWHITE', b'P5': b'GRAYSCALE', b'P6': b'RGB',
b'P7 332': b'RGB', b'P7': b'RGB_ALPHA'}
def __init__(self, arg=None, **kwargs):
"""Initialize instance from filename, open file, or numpy array."""
for attr in ('header', 'magicnum', 'width', 'height', 'maxval',
'depth', 'tupltypes', '_filename', '_fh', '_data'):
setattr(self, attr, None)
if arg is None:
self._fromdata([], **kwargs)
elif isinstance(arg, basestring):
self._fh = open(arg, 'rb')
self._filename = arg
self._fromfile(self._fh, **kwargs)
elif hasattr(arg, 'seek'):
self._fromfile(arg, **kwargs)
self._fh = arg
else:
self._fromdata(arg, **kwargs)
def asarray(self, copy=True, cache=False, **kwargs):
"""Return image data from file as numpy array."""
data = self._data
if data is None:
data = self._read_data(self._fh, **kwargs)
if cache:
self._data = data
else:
return data
return deepcopy(data) if copy else data
def write(self, arg, **kwargs):
"""Write instance to file."""
if hasattr(arg, 'seek'):
self._tofile(arg, **kwargs)
else:
with open(arg, 'wb') as fid:
self._tofile(fid, **kwargs)
def close(self):
"""Close open file. Future asarray calls might fail."""
if self._filename and self._fh:
self._fh.close()
self._fh = None
def __del__(self):
self.close()
def _fromfile(self, fh):
"""Initialize instance from open file."""
fh.seek(0)
data = fh.read(4096)
if (len(data) < 7) or not (b'0' < data[1:2] < b'8'):
raise ValueError("Not a Netpbm file:\n%s" % data[:32])
try:
self._read_pam_header(data)
except Exception:
try:
self._read_pnm_header(data)
except Exception:
raise ValueError("Not a Netpbm file:\n%s" % data[:32])
def _read_pam_header(self, data):
"""Read PAM header and initialize instance."""
regroups = re.search(
b"(^P7[\n\r]+(?:(?:[\n\r]+)|(?:#.*)|"
b"(HEIGHT\s+\d+)|(WIDTH\s+\d+)|(DEPTH\s+\d+)|(MAXVAL\s+\d+)|"
b"(?:TUPLTYPE\s+\w+))*ENDHDR\n)", data).groups()
self.header = regroups[0]
self.magicnum = b'P7'
for group in regroups[1:]:
key, value = group.split()
setattr(self, unicode(key).lower(), int(value))
matches = re.findall(b"(TUPLTYPE\s+\w+)", self.header)
self.tupltypes = [s.split(None, 1)[1] for s in matches]
def _read_pnm_header(self, data):
"""Read PNM header and initialize instance."""
bpm = data[1:2] in b"14"
regroups = re.search(b"".join((
b"(^(P[123456]|P7 332)\s+(?:#.*[\r\n])*",
b"\s*(\d+)\s+(?:#.*[\r\n])*",
b"\s*(\d+)\s+(?:#.*[\r\n])*" * (not bpm),
b"\s*(\d+)\s(?:\s*#.*[\r\n]\s)*)")), data).groups() + (1, ) * bpm
self.header = regroups[0]
self.magicnum = regroups[1]
self.width = int(regroups[2])
self.height = int(regroups[3])
self.maxval = int(regroups[4])
self.depth = 3 if self.magicnum in b"P3P6P7 332" else 1
self.tupltypes = [self._types[self.magicnum]]
def _read_data(self, fh, byteorder='>'):
"""Return image data from open file as numpy array."""
fh.seek(len(self.header))
data = fh.read()
dtype = 'u1' if self.maxval < 256 else byteorder + 'u2'
depth = 1 if self.magicnum == b"P7 332" else self.depth
shape = [-1, self.height, self.width, depth]
size = numpy.prod(shape[1:])
if self.magicnum in b"P1P2P3":
data = numpy.array(data.split(None, size)[:size], dtype)
data = data.reshape(shape)
elif self.maxval == 1:
shape[2] = int(math.ceil(self.width / 8))
data = numpy.frombuffer(data, dtype).reshape(shape)
data = numpy.unpackbits(data, axis=-2)[:, :, :self.width, :]
else:
data = numpy.frombuffer(data, dtype)
data = data[:size * (data.size // size)].reshape(shape)
if data.shape[0] < 2:
data = data.reshape(data.shape[1:])
if data.shape[-1] < 2:
data = data.reshape(data.shape[:-1])
if self.magicnum == b"P7 332":
rgb332 = numpy.array(list(numpy.ndindex(8, 8, 4)), numpy.uint8)
rgb332 *= [36, 36, 85]
data = numpy.take(rgb332, data, axis=0)
return data
def _fromdata(self, data, maxval=None):
"""Initialize instance from numpy array."""
data = numpy.array(data, ndmin=2, copy=True)
if data.dtype.kind not in "uib":
raise ValueError("not an integer type: %s" % data.dtype)
if data.dtype.kind == 'i' and numpy.min(data) < 0:
raise ValueError("data out of range: %i" % numpy.min(data))
if maxval is None:
maxval = numpy.max(data)
maxval = 255 if maxval < 256 else 65535
if maxval < 0 or maxval > 65535:
raise ValueError("data out of range: %i" % maxval)
data = data.astype('u1' if maxval < 256 else '>u2')
self._data = data
if data.ndim > 2 and data.shape[-1] in (3, 4):
self.depth = data.shape[-1]
self.width = data.shape[-2]
self.height = data.shape[-3]
self.magicnum = b'P7' if self.depth == 4 else b'P6'
else:
self.depth = 1
self.width = data.shape[-1]
self.height = data.shape[-2]
self.magicnum = b'P5' if maxval > 1 else b'P4'
self.maxval = maxval
self.tupltypes = [self._types[self.magicnum]]
self.header = self._header()
def _tofile(self, fh, pam=False):
"""Write Netbm file."""
fh.seek(0)
fh.write(self._header(pam))
data = self.asarray(copy=False)
if self.maxval == 1:
data = numpy.packbits(data, axis=-1)
data.tofile(fh)
def _header(self, pam=False):
"""Return file header as byte string."""
if pam or self.magicnum == b'P7':
header = "\n".join((
"P7",
"HEIGHT %i" % self.height,
"WIDTH %i" % self.width,
"DEPTH %i" % self.depth,
"MAXVAL %i" % self.maxval,
"\n".join("TUPLTYPE %s" % unicode(i) for i in self.tupltypes),
"ENDHDR\n"))
elif self.maxval == 1:
header = "P4 %i %i\n" % (self.width, self.height)
elif self.depth == 1:
header = "P5 %i %i %i\n" % (self.width, self.height, self.maxval)
else:
header = "P6 %i %i %i\n" % (self.width, self.height, self.maxval)
if sys.version_info[0] > 2:
header = bytes(header, 'ascii')
return header
def __str__(self):
"""Return information about instance."""
return unicode(self.header)
if sys.version_info[0] > 2:
basestring = str
unicode = lambda x: str(x, 'ascii')
if __name__ == "__main__":
# Show images specified on command line or all images in current directory
from glob import glob
from matplotlib import pyplot
files = sys.argv[1:] if len(sys.argv) > 1 else glob('*.p*m')
for fname in files:
try:
pam = NetpbmFile(fname)
img = pam.asarray(copy=False)
if False:
pam.write('_tmp.pgm.out', pam=True)
img2 = imread('_tmp.pgm.out')
assert numpy.all(img == img2)
imsave('_tmp.pgm.out', img)
img2 = imread('_tmp.pgm.out')
assert numpy.all(img == img2)
pam.close()
except ValueError as e:
print(fname, e)
continue
_shape = img.shape
if img.ndim > 3 or (img.ndim > 2 and img.shape[-1] not in (3, 4)):
img = img[0]
cmap = 'gray' if pam.maxval > 1 else 'binary'
pyplot.imshow(img, cmap, interpolation='nearest')
pyplot.title("%s %s %s %s" % (fname, unicode(pam.magicnum),
_shape, img.dtype))
pyplot.show()

View file

@ -1,181 +0,0 @@
import pylab as pb
import numpy as np
from .. import util
from GPy.util.latent_space_visualizations.controllers.imshow_controller import ImshowController
from misc import param_to_array
import itertools
def most_significant_input_dimensions(model, which_indices):
if which_indices is None:
if model.input_dim == 1:
input_1 = 0
input_2 = None
if model.input_dim == 2:
input_1, input_2 = 0, 1
else:
try:
input_1, input_2 = np.argsort(model.input_sensitivity())[::-1][:2]
except:
raise ValueError, "cannot automatically determine which dimensions to plot, please pass 'which_indices'"
else:
input_1, input_2 = which_indices
return input_1, input_2
def plot_latent(model, labels=None, which_indices=None,
resolution=50, ax=None, marker='o', s=40,
fignum=None, plot_inducing=False, legend=True,
aspect='auto', updates=False):
"""
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance
"""
if ax is None:
fig = pb.figure(num=fignum)
ax = fig.add_subplot(111)
util.plot.Tango.reset()
if labels is None:
labels = np.ones(model.num_data)
input_1, input_2 = most_significant_input_dimensions(model, which_indices)
X = param_to_array(model.X)
# first, plot the output variance as a function of the latent space
Xtest, xx, yy, xmin, xmax = util.plot.x_frame2D(X[:, [input_1, input_2]], resolution=resolution)
Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
def plot_function(x):
Xtest_full[:, [input_1, input_2]] = x
mu, var, low, up = model.predict(Xtest_full)
var = var[:, :1]
return np.log(var)
view = ImshowController(ax, plot_function,
tuple(X[:, [input_1, input_2]].min(0)) + tuple(X[:, [input_1, input_2]].max(0)),
resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.binary)
# ax.imshow(var.reshape(resolution, resolution).T,
# extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary, interpolation='bilinear', origin='lower')
# make sure labels are in order of input:
ulabels = []
for lab in labels:
if not lab in ulabels:
ulabels.append(lab)
marker = itertools.cycle(list(marker))
for i, ul in enumerate(ulabels):
if type(ul) is np.string_:
this_label = ul
elif type(ul) is np.int64:
this_label = 'class %i' % ul
else:
this_label = 'class %i' % i
m = marker.next()
index = np.nonzero(labels == ul)[0]
if model.input_dim == 1:
x = X[index, input_1]
y = np.zeros(index.size)
else:
x = X[index, input_1]
y = X[index, input_2]
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
ax.set_xlabel('latent dimension %i' % input_1)
ax.set_ylabel('latent dimension %i' % input_2)
if not np.all(labels == 1.) and legend:
ax.legend(loc=0, numpoints=1)
ax.set_xlim(xmin[0], xmax[0])
ax.set_ylim(xmin[1], xmax[1])
ax.grid(b=False) # remove the grid if present, it doesn't look good
ax.set_aspect('auto') # set a nice aspect ratio
if plot_inducing:
Z = param_to_array(model.Z)
ax.plot(Z[:, input_1], Z[:, input_2], '^w')
if updates:
ax.figure.canvas.show()
raw_input('Enter to continue')
return ax
def plot_magnification(model, labels=None, which_indices=None,
resolution=60, ax=None, marker='o', s=40,
fignum=None, plot_inducing=False, legend=True,
aspect='auto', updates=False):
"""
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
:param resolution: the resolution of the grid on which to evaluate the predictive variance
"""
if ax is None:
fig = pb.figure(num=fignum)
ax = fig.add_subplot(111)
util.plot.Tango.reset()
if labels is None:
labels = np.ones(model.num_data)
input_1, input_2 = most_significant_input_dimensions(model, which_indices)
# first, plot the output variance as a function of the latent space
Xtest, xx, yy, xmin, xmax = util.plot.x_frame2D(model.X[:, [input_1, input_2]], resolution=resolution)
Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
def plot_function(x):
Xtest_full[:, [input_1, input_2]] = x
mf=model.magnification(Xtest_full)
return mf
view = ImshowController(ax, plot_function,
tuple(model.X.min(0)[:, [input_1, input_2]]) + tuple(model.X.max(0)[:, [input_1, input_2]]),
resolution, aspect=aspect, interpolation='bilinear',
cmap=pb.cm.gray)
# make sure labels are in order of input:
ulabels = []
for lab in labels:
if not lab in ulabels:
ulabels.append(lab)
marker = itertools.cycle(list(marker))
for i, ul in enumerate(ulabels):
if type(ul) is np.string_:
this_label = ul
elif type(ul) is np.int64:
this_label = 'class %i' % ul
else:
this_label = 'class %i' % i
m = marker.next()
index = np.nonzero(labels == ul)[0]
if model.input_dim == 1:
x = model.X[index, input_1]
y = np.zeros(index.size)
else:
x = model.X[index, input_1]
y = model.X[index, input_2]
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
ax.set_xlabel('latent dimension %i' % input_1)
ax.set_ylabel('latent dimension %i' % input_2)
if not np.all(labels == 1.) and legend:
ax.legend(loc=0, numpoints=1)
ax.set_xlim(xmin[0], xmax[0])
ax.set_ylim(xmin[1], xmax[1])
ax.grid(b=False) # remove the grid if present, it doesn't look good
ax.set_aspect('auto') # set a nice aspect ratio
if plot_inducing:
ax.plot(model.Z[:, input_1], model.Z[:, input_2], '^w')
if updates:
ax.figure.canvas.show()
raw_input('Enter to continue')
pb.title('Magnification Factor')
return ax

View file

@ -1,538 +0,0 @@
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import GPy
import numpy as np
import matplotlib as mpl
import time
import Image
try:
import visual
visual_available = True
except ImportError:
visual_available = False
class data_show:
"""
The data_show class is a base class which describes how to visualize a
particular data set. For example, motion capture data can be plotted as a
stick figure, or images are shown using imshow. This class enables latent
to data visualizations for the GP-LVM.
"""
def __init__(self, vals):
self.vals = vals.copy()
# If no axes are defined, create some.
def modify(self, vals):
raise NotImplementedError, "this needs to be implemented to use the data_show class"
def close(self):
raise NotImplementedError, "this needs to be implemented to use the data_show class"
class vpython_show(data_show):
"""
the vpython_show class is a base class for all visualization methods that use vpython to display. It is initialized with a scene. If the scene is set to None it creates a scene window.
"""
def __init__(self, vals, scene=None):
data_show.__init__(self, vals)
# If no axes are defined, create some.
if scene==None:
self.scene = visual.display(title='Data Visualization')
else:
self.scene = scene
def close(self):
self.scene.exit()
class matplotlib_show(data_show):
"""
the matplotlib_show class is a base class for all visualization methods that use matplotlib. It is initialized with an axis. If the axis is set to None it creates a figure window.
"""
def __init__(self, vals, axes=None):
data_show.__init__(self, vals)
# If no axes are defined, create some.
if axes==None:
fig = plt.figure()
self.axes = fig.add_subplot(111)
else:
self.axes = axes
def close(self):
plt.close(self.axes.get_figure())
class vector_show(matplotlib_show):
"""
A base visualization class that just shows a data vector as a plot of
vector elements alongside their indices.
"""
def __init__(self, vals, axes=None):
matplotlib_show.__init__(self, vals, axes)
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals.T)[0]
def modify(self, vals):
self.vals = vals.copy()
xdata, ydata = self.handle.get_data()
self.handle.set_data(xdata, self.vals.T)
self.axes.figure.canvas.draw()
class lvm(matplotlib_show):
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
"""Visualize a latent variable model
:param model: the latent variable model to visualize.
:param data_visualize: the object used to visualize the data which has been modelled.
:type data_visualize: visualize.data_show type.
:param latent_axes: the axes where the latent visualization should be plotted.
"""
if vals == None:
vals = model.X[0]
matplotlib_show.__init__(self, vals, axes=latent_axes)
if isinstance(latent_axes,mpl.axes.Axes):
self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click)
self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
self.cid = latent_axes.figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
self.cid = latent_axes.figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
else:
self.cid = latent_axes[0].figure.canvas.mpl_connect('button_press_event', self.on_click)
self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move)
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
self.data_visualize = data_visualize
self.model = model
self.latent_axes = latent_axes
self.sense_axes = sense_axes
self.called = False
self.move_on = False
self.latent_index = latent_index
self.latent_dim = model.input_dim
# The red cross which shows current latent point.
self.latent_values = vals
self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
self.modify(vals)
self.show_sensitivities()
def modify(self, vals):
"""When latent values are modified update the latent representation and ulso update the output visualization."""
self.vals = vals.copy()
y = self.model.predict(self.vals)[0]
self.data_visualize.modify(y)
self.latent_handle.set_data(self.vals[self.latent_index[0]], self.vals[self.latent_index[1]])
self.axes.figure.canvas.draw()
def on_enter(self,event):
pass
def on_leave(self,event):
pass
def on_click(self, event):
print 'click!'
if event.inaxes!=self.latent_axes: return
self.move_on = not self.move_on
self.called = True
def on_move(self, event):
if event.inaxes!=self.latent_axes: return
if self.called and self.move_on:
# Call modify code on move
self.latent_values[self.latent_index[0]]=event.xdata
self.latent_values[self.latent_index[1]]=event.ydata
self.modify(self.latent_values)
def show_sensitivities(self):
# A click in the bar chart axis for selection a dimension.
if self.sense_axes != None:
self.sense_axes.cla()
self.sense_axes.bar(np.arange(self.model.input_dim), self.model.input_sensitivity(), color='b')
if self.latent_index[1] == self.latent_index[0]:
self.sense_axes.bar(np.array(self.latent_index[0]), self.model.input_sensitivity()[self.latent_index[0]], color='y')
self.sense_axes.bar(np.array(self.latent_index[1]), self.model.input_sensitivity()[self.latent_index[1]], color='y')
else:
self.sense_axes.bar(np.array(self.latent_index[0]), self.model.input_sensitivity()[self.latent_index[0]], color='g')
self.sense_axes.bar(np.array(self.latent_index[1]), self.model.input_sensitivity()[self.latent_index[1]], color='r')
self.sense_axes.figure.canvas.draw()
class lvm_subplots(lvm):
"""
latent_axes is a np array of dimension np.ceil(input_dim/2),
one for each pair of the latent dimensions.
"""
def __init__(self, vals, Model, data_visualize, latent_axes=None, sense_axes=None):
self.nplots = int(np.ceil(Model.input_dim/2.))+1
assert len(latent_axes)==self.nplots
if vals==None:
vals = Model.X[0, :]
self.latent_values = vals
for i, axis in enumerate(latent_axes):
if i == self.nplots-1:
if self.nplots*2!=Model.input_dim:
latent_index = [i*2, i*2]
lvm.__init__(self, self.latent_vals, Model, data_visualize, axis, sense_axes, latent_index=latent_index)
else:
latent_index = [i*2, i*2+1]
lvm.__init__(self, self.latent_vals, Model, data_visualize, axis, latent_index=latent_index)
class lvm_dimselect(lvm):
"""
A visualizer for latent variable models which allows selection of the latent dimensions to use by clicking on a bar chart of their length scales.
For an example of the visualizer's use try:
GPy.examples.dimensionality_reduction.BGPVLM_oil()
"""
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0, 1], labels=None):
if latent_axes==None and sense_axes==None:
self.fig,(latent_axes,self.sense_axes) = plt.subplots(1,2)
elif sense_axes==None:
fig=plt.figure()
self.sense_axes = fig.add_subplot(111)
else:
self.sense_axes = sense_axes
self.labels = labels
lvm.__init__(self,vals,model,data_visualize,latent_axes,sense_axes,latent_index)
self.show_sensitivities()
print "use left and right mouse butons to select dimensions"
def on_click(self, event):
if event.inaxes==self.sense_axes:
new_index = max(0,min(int(np.round(event.xdata-0.5)),self.model.input_dim-1))
if event.button == 1:
# Make it red if and y-axis (red=port=left) if it is a left button click
self.latent_index[1] = new_index
else:
# Make it green and x-axis (green=starboard=right) if it is a right button click
self.latent_index[0] = new_index
self.show_sensitivities()
self.latent_axes.cla()
self.model.plot_latent(which_indices=self.latent_index,
ax=self.latent_axes, labels=self.labels)
self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
self.modify(self.latent_values)
elif event.inaxes==self.latent_axes:
self.move_on = not self.move_on
self.called = True
def on_leave(self,event):
latent_values = self.latent_values.copy()
y = self.model.predict(latent_values[None,:])[0]
self.data_visualize.modify(y)
class image_show(matplotlib_show):
"""Show a data vector as an image. This visualizer rehapes the output vector and displays it as an image.
:param vals: the values of the output to display.
:type vals: ndarray
:param axes: the axes to show the output on.
:type vals: axes handle
:param dimensions: the dimensions that the image needs to be transposed to for display.
:type dimensions: tuple
:param transpose: whether to transpose the image before display.
:type bool: default is False.
:param order: whether array is in Fortan ordering ('F') or Python ordering ('C'). Default is python ('C').
:type order: string
:param invert: whether to invert the pixels or not (default False).
:type invert: bool
:param palette: a palette to use for the image.
:param preset_mean: the preset mean of a scaled image.
:type preset_mean: double
:param preset_std: the preset standard deviation of a scaled image.
:type preset_std: double"""
def __init__(self, vals, axes=None, dimensions=(16,16), transpose=False, order='C', invert=False, scale=False, palette=[], preset_mean = 0., preset_std = -1., select_image=0):
matplotlib_show.__init__(self, vals, axes)
self.dimensions = dimensions
self.transpose = transpose
self.order = order
self.invert = invert
self.scale = scale
self.palette = palette
self.preset_mean = preset_mean
self.preset_std = preset_std
self.select_image = select_image # This is used when the y vector contains multiple images concatenated.
self.set_image(self.vals)
if not self.palette == []: # Can just show the image (self.set_image() took care of setting the palette)
self.handle = self.axes.imshow(self.vals, interpolation='nearest')
else: # Use a boring gray map.
self.handle = self.axes.imshow(self.vals, cmap=plt.cm.gray, interpolation='nearest') # @UndefinedVariable
plt.show()
def modify(self, vals):
self.set_image(vals.copy())
self.handle.set_array(self.vals)
self.axes.figure.canvas.draw()
def set_image(self, vals):
dim = self.dimensions[0] * self.dimensions[1]
num_images = np.sqrt(vals[0,].size/dim)
if num_images > 1 and num_images.is_integer(): # Show a mosaic of images
num_images = np.int(num_images)
self.vals = np.zeros((self.dimensions[0]*num_images, self.dimensions[1]*num_images))
for iR in range(num_images):
for iC in range(num_images):
cur_img_id = iR*num_images + iC
cur_img = np.reshape(vals[0,dim*cur_img_id+np.array(range(dim))], self.dimensions, order=self.order)
first_row = iR*self.dimensions[0]
last_row = (iR+1)*self.dimensions[0]
first_col = iC*self.dimensions[1]
last_col = (iC+1)*self.dimensions[1]
self.vals[first_row:last_row, first_col:last_col] = cur_img
else:
self.vals = np.reshape(vals[0,dim*self.select_image+np.array(range(dim))], self.dimensions, order=self.order)
if self.transpose:
self.vals = self.vals.T
# if not self.scale:
# self.vals = self.vals
if self.invert:
self.vals = -self.vals
# un-normalizing, for visualisation purposes:
if self.preset_std >= 0: # The Mean is assumed to be in the range (0,255)
self.vals = self.vals*self.preset_std + self.preset_mean
# Clipping the values:
self.vals[self.vals < 0] = 0
self.vals[self.vals > 255] = 255
else:
self.vals = 255*(self.vals - self.vals.min())/(self.vals.max() - self.vals.min())
if not self.palette == []: # applying using an image palette (e.g. if the image has been quantized)
self.vals = Image.fromarray(self.vals.astype('uint8'))
self.vals.putpalette(self.palette) # palette is a list, must be loaded before calling this function
class mocap_data_show_vpython(vpython_show):
"""Base class for visualizing motion capture data using visual module."""
def __init__(self, vals, scene=None, connect=None, radius=0.1):
vpython_show.__init__(self, vals, scene)
self.radius = radius
self.connect = connect
self.process_values()
self.draw_edges()
self.draw_vertices()
def draw_vertices(self):
self.spheres = []
for i in range(self.vals.shape[0]):
self.spheres.append(visual.sphere(pos=(self.vals[i, 0], self.vals[i, 2], self.vals[i, 1]), radius=self.radius))
self.scene.visible=True
def draw_edges(self):
self.rods = []
self.line_handle = []
if not self.connect==None:
self.I, self.J = np.nonzero(self.connect)
for i, j in zip(self.I, self.J):
pos, axis = self.pos_axis(i, j)
self.rods.append(visual.cylinder(pos=pos, axis=axis, radius=self.radius))
def modify_vertices(self):
for i in range(self.vals.shape[0]):
self.spheres[i].pos = (self.vals[i, 0], self.vals[i, 2], self.vals[i, 1])
def modify_edges(self):
self.line_handle = []
if not self.connect==None:
self.I, self.J = np.nonzero(self.connect)
for rod, i, j in zip(self.rods, self.I, self.J):
rod.pos, rod.axis = self.pos_axis(i, j)
def pos_axis(self, i, j):
pos = []
axis = []
pos.append(self.vals[i, 0])
axis.append(self.vals[j, 0]-self.vals[i,0])
pos.append(self.vals[i, 2])
axis.append(self.vals[j, 2]-self.vals[i,2])
pos.append(self.vals[i, 1])
axis.append(self.vals[j, 1]-self.vals[i,1])
return pos, axis
def modify(self, vals):
self.vals = vals.copy()
self.process_values()
self.modify_edges()
self.modify_vertices()
def process_values(self):
raise NotImplementedError, "this needs to be implemented to use the data_show class"
class mocap_data_show(matplotlib_show):
"""Base class for visualizing motion capture data."""
def __init__(self, vals, axes=None, connect=None):
if axes==None:
fig = plt.figure()
axes = fig.add_subplot(111, projection='3d')
matplotlib_show.__init__(self, vals, axes)
self.connect = connect
self.process_values()
self.initialize_axes()
self.draw_vertices()
self.finalize_axes()
self.draw_edges()
self.axes.figure.canvas.draw()
def draw_vertices(self):
self.points_handle = self.axes.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
def draw_edges(self):
self.line_handle = []
if not self.connect==None:
x = []
y = []
z = []
self.I, self.J = np.nonzero(self.connect)
for i, j in zip(self.I, self.J):
x.append(self.vals[i, 0])
x.append(self.vals[j, 0])
x.append(np.NaN)
y.append(self.vals[i, 1])
y.append(self.vals[j, 1])
y.append(np.NaN)
z.append(self.vals[i, 2])
z.append(self.vals[j, 2])
z.append(np.NaN)
self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-')
def modify(self, vals):
self.vals = vals.copy()
self.process_values()
self.initialize_axes_modify()
self.draw_vertices()
self.finalize_axes_modify()
self.draw_edges()
self.axes.figure.canvas.draw()
def process_values(self):
raise NotImplementedError, "this needs to be implemented to use the data_show class"
def initialize_axes(self):
"""Set up the axes with the right limits and scaling."""
self.x_lim = np.array([self.vals[:, 0].min(), self.vals[:, 0].max()])
self.y_lim = np.array([self.vals[:, 1].min(), self.vals[:, 1].max()])
self.z_lim = np.array([self.vals[:, 2].min(), self.vals[:, 2].max()])
def initialize_axes_modify(self):
self.points_handle.remove()
self.line_handle[0].remove()
def finalize_axes(self):
self.axes.set_xlim(self.x_lim)
self.axes.set_ylim(self.y_lim)
self.axes.set_zlim(self.z_lim)
self.axes.auto_scale_xyz([-1., 1.], [-1., 1.], [-1.5, 1.5])
#self.axes.set_aspect('equal')
self.axes.autoscale(enable=False)
def finalize_axes_modify(self):
self.axes.set_xlim(self.x_lim)
self.axes.set_ylim(self.y_lim)
self.axes.set_zlim(self.z_lim)
class stick_show(mocap_data_show):
"""Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect."""
def __init__(self, vals, connect=None, axes=None):
mocap_data_show.__init__(self, vals, axes=axes, connect=connect)
def process_values(self):
self.vals = self.vals.reshape((3, self.vals.shape[1]/3)).T
class skeleton_show(mocap_data_show):
"""data_show class for visualizing motion capture data encoded as a skeleton with angles."""
def __init__(self, vals, skel, axes=None, padding=0):
"""data_show class for visualizing motion capture data encoded as a skeleton with angles.
:param vals: set of modeled angles to use for printing in the axis when it's first created.
:type vals: np.array
:param skel: skeleton object that has the parameters of the motion capture skeleton associated with it.
:type skel: mocap.skeleton object
:param padding:
:type int
"""
self.skel = skel
self.padding = padding
connect = skel.connection_matrix()
mocap_data_show.__init__(self, vals, axes=axes, connect=connect)
def process_values(self):
"""Takes a set of angles and converts them to the x,y,z coordinates in the internal prepresentation of the class, ready for plotting.
:param vals: the values that are being modelled."""
if self.padding>0:
channels = np.zeros((self.vals.shape[0], self.vals.shape[1]+self.padding))
channels[:, 0:self.vals.shape[0]] = self.vals
else:
channels = self.vals
vals_mat = self.skel.to_xyz(channels.flatten())
self.vals = np.zeros_like(vals_mat)
# Flip the Y and Z axes
self.vals[:, 0] = vals_mat[:, 0].copy()
self.vals[:, 1] = vals_mat[:, 2].copy()
self.vals[:, 2] = vals_mat[:, 1].copy()
def wrap_around(self, lim, connect):
quot = lim[1] - lim[0]
self.vals = rem(self.vals, quot)+lim[0]
nVals = floor(self.vals/quot)
for i in range(connect.shape[0]):
for j in find(connect[i, :]):
if nVals[i] != nVals[j]:
connect[i, j] = False
return connect
def data_play(Y, visualizer, frame_rate=30):
"""Play a data set using the data_show object given.
:Y: the data set to be visualized.
:param visualizer: the data show objectwhether to display during optimisation
:type visualizer: data_show
Example usage:
This example loads in the CMU mocap database (http://mocap.cs.cmu.edu) subject number 35 motion number 01. It then plays it using the mocap_show visualize object.
.. code-block:: python
data = GPy.util.datasets.cmu_mocap(subject='35', train_motions=['01'])
Y = data['Y']
Y[:, 0:3] = 0. # Make figure walk in place
visualize = GPy.util.visualize.skeleton_show(Y[0, :], data['skel'])
GPy.util.visualize.data_play(Y, visualize)
"""
for y in Y:
visualizer.modify(y[None, :])
time.sleep(1./float(frame_rate))