mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
Files relocated
This commit is contained in:
parent
73546f2408
commit
7782f62885
10 changed files with 0 additions and 1727 deletions
|
|
@ -1,135 +0,0 @@
|
|||
# #Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
|
||||
import Tango
|
||||
import pylab as pb
|
||||
import numpy as np
|
||||
|
||||
def gpplot(x,mu,lower,upper,edgecol=Tango.colorsHex['darkBlue'],fillcol=Tango.colorsHex['lightBlue'],axes=None,**kwargs):
|
||||
if axes is None:
|
||||
axes = pb.gca()
|
||||
mu = mu.flatten()
|
||||
x = x.flatten()
|
||||
lower = lower.flatten()
|
||||
upper = upper.flatten()
|
||||
|
||||
#here's the mean
|
||||
axes.plot(x,mu,color=edgecol,linewidth=2)
|
||||
|
||||
#here's the box
|
||||
kwargs['linewidth']=0.5
|
||||
if not 'alpha' in kwargs.keys():
|
||||
kwargs['alpha'] = 0.3
|
||||
axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)
|
||||
|
||||
#this is the edge:
|
||||
axes.plot(x,upper,color=edgecol,linewidth=0.2)
|
||||
axes.plot(x,lower,color=edgecol,linewidth=0.2)
|
||||
|
||||
def removeRightTicks(ax=None):
|
||||
ax = ax or pb.gca()
|
||||
for i, line in enumerate(ax.get_yticklines()):
|
||||
if i%2 == 1: # odd indices
|
||||
line.set_visible(False)
|
||||
def removeUpperTicks(ax=None):
|
||||
ax = ax or pb.gca()
|
||||
for i, line in enumerate(ax.get_xticklines()):
|
||||
if i%2 == 1: # odd indices
|
||||
line.set_visible(False)
|
||||
def fewerXticks(ax=None,divideby=2):
|
||||
ax = ax or pb.gca()
|
||||
ax.set_xticks(ax.get_xticks()[::divideby])
|
||||
|
||||
def align_subplots(N,M,xlim=None, ylim=None):
|
||||
"""make all of the subplots have the same limits, turn off unnecessary ticks"""
|
||||
#find sensible xlim,ylim
|
||||
if xlim is None:
|
||||
xlim = [np.inf,-np.inf]
|
||||
for i in range(N*M):
|
||||
pb.subplot(N,M,i+1)
|
||||
xlim[0] = min(xlim[0],pb.xlim()[0])
|
||||
xlim[1] = max(xlim[1],pb.xlim()[1])
|
||||
if ylim is None:
|
||||
ylim = [np.inf,-np.inf]
|
||||
for i in range(N*M):
|
||||
pb.subplot(N,M,i+1)
|
||||
ylim[0] = min(ylim[0],pb.ylim()[0])
|
||||
ylim[1] = max(ylim[1],pb.ylim()[1])
|
||||
|
||||
for i in range(N*M):
|
||||
pb.subplot(N,M,i+1)
|
||||
pb.xlim(xlim)
|
||||
pb.ylim(ylim)
|
||||
if (i)%M:
|
||||
pb.yticks([])
|
||||
else:
|
||||
removeRightTicks()
|
||||
if i<(M*(N-1)):
|
||||
pb.xticks([])
|
||||
else:
|
||||
removeUpperTicks()
|
||||
|
||||
def align_subplot_array(axes,xlim=None, ylim=None):
|
||||
"""make all of the axes in the array hae the same limits, turn off unnecessary ticks
|
||||
|
||||
use pb.subplots() to get an array of axes
|
||||
"""
|
||||
#find sensible xlim,ylim
|
||||
if xlim is None:
|
||||
xlim = [np.inf,-np.inf]
|
||||
for ax in axes.flatten():
|
||||
xlim[0] = min(xlim[0],ax.get_xlim()[0])
|
||||
xlim[1] = max(xlim[1],ax.get_xlim()[1])
|
||||
if ylim is None:
|
||||
ylim = [np.inf,-np.inf]
|
||||
for ax in axes.flatten():
|
||||
ylim[0] = min(ylim[0],ax.get_ylim()[0])
|
||||
ylim[1] = max(ylim[1],ax.get_ylim()[1])
|
||||
|
||||
N,M = axes.shape
|
||||
for i,ax in enumerate(axes.flatten()):
|
||||
ax.set_xlim(xlim)
|
||||
ax.set_ylim(ylim)
|
||||
if (i)%M:
|
||||
ax.set_yticks([])
|
||||
else:
|
||||
removeRightTicks(ax)
|
||||
if i<(M*(N-1)):
|
||||
ax.set_xticks([])
|
||||
else:
|
||||
removeUpperTicks(ax)
|
||||
|
||||
def x_frame1D(X,plot_limits=None,resolution=None):
|
||||
"""
|
||||
Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits
|
||||
"""
|
||||
assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs"
|
||||
if plot_limits is None:
|
||||
xmin,xmax = X.min(0),X.max(0)
|
||||
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
|
||||
elif len(plot_limits)==2:
|
||||
xmin, xmax = plot_limits
|
||||
else:
|
||||
raise ValueError, "Bad limits for plotting"
|
||||
|
||||
Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None]
|
||||
return Xnew, xmin, xmax
|
||||
|
||||
def x_frame2D(X,plot_limits=None,resolution=None):
|
||||
"""
|
||||
Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits
|
||||
"""
|
||||
assert X.shape[1] ==2, "x_frame2D is defined for two-dimensional inputs"
|
||||
if plot_limits is None:
|
||||
xmin,xmax = X.min(0),X.max(0)
|
||||
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
|
||||
elif len(plot_limits)==2:
|
||||
xmin, xmax = plot_limits
|
||||
else:
|
||||
raise ValueError, "Bad limits for plotting"
|
||||
|
||||
resolution = resolution or 50
|
||||
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
|
||||
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
|
||||
return Xnew, xx, yy, xmin, xmax
|
||||
|
|
@ -1,166 +0,0 @@
|
|||
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
|
||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
|
||||
import matplotlib as mpl
|
||||
import pylab as pb
|
||||
import sys
|
||||
#sys.path.append('/home/james/mlprojects/sitran_cluster/')
|
||||
#from switch_pylab_backend import *
|
||||
|
||||
|
||||
#this stuff isn;t really Tango related: maybe it could be moved out? TODO
|
||||
def removeRightTicks(ax=None):
|
||||
ax = ax or pb.gca()
|
||||
for i, line in enumerate(ax.get_yticklines()):
|
||||
if i%2 == 1: # odd indices
|
||||
line.set_visible(False)
|
||||
def removeUpperTicks(ax=None):
|
||||
ax = ax or pb.gca()
|
||||
for i, line in enumerate(ax.get_xticklines()):
|
||||
if i%2 == 1: # odd indices
|
||||
line.set_visible(False)
|
||||
def fewerXticks(ax=None,divideby=2):
|
||||
ax = ax or pb.gca()
|
||||
ax.set_xticks(ax.get_xticks()[::divideby])
|
||||
|
||||
|
||||
colorsHex = {\
|
||||
"Aluminium6":"#2e3436",\
|
||||
"Aluminium5":"#555753",\
|
||||
"Aluminium4":"#888a85",\
|
||||
"Aluminium3":"#babdb6",\
|
||||
"Aluminium2":"#d3d7cf",\
|
||||
"Aluminium1":"#eeeeec",\
|
||||
"lightPurple":"#ad7fa8",\
|
||||
"mediumPurple":"#75507b",\
|
||||
"darkPurple":"#5c3566",\
|
||||
"lightBlue":"#729fcf",\
|
||||
"mediumBlue":"#3465a4",\
|
||||
"darkBlue": "#204a87",\
|
||||
"lightGreen":"#8ae234",\
|
||||
"mediumGreen":"#73d216",\
|
||||
"darkGreen":"#4e9a06",\
|
||||
"lightChocolate":"#e9b96e",\
|
||||
"mediumChocolate":"#c17d11",\
|
||||
"darkChocolate":"#8f5902",\
|
||||
"lightRed":"#ef2929",\
|
||||
"mediumRed":"#cc0000",\
|
||||
"darkRed":"#a40000",\
|
||||
"lightOrange":"#fcaf3e",\
|
||||
"mediumOrange":"#f57900",\
|
||||
"darkOrange":"#ce5c00",\
|
||||
"lightButter":"#fce94f",\
|
||||
"mediumButter":"#edd400",\
|
||||
"darkButter":"#c4a000"}
|
||||
|
||||
darkList = [colorsHex['darkBlue'],colorsHex['darkRed'],colorsHex['darkGreen'], colorsHex['darkOrange'], colorsHex['darkButter'], colorsHex['darkPurple'], colorsHex['darkChocolate'], colorsHex['Aluminium6']]
|
||||
mediumList = [colorsHex['mediumBlue'], colorsHex['mediumRed'],colorsHex['mediumGreen'], colorsHex['mediumOrange'], colorsHex['mediumButter'], colorsHex['mediumPurple'], colorsHex['mediumChocolate'], colorsHex['Aluminium5']]
|
||||
lightList = [colorsHex['lightBlue'], colorsHex['lightRed'],colorsHex['lightGreen'], colorsHex['lightOrange'], colorsHex['lightButter'], colorsHex['lightPurple'], colorsHex['lightChocolate'], colorsHex['Aluminium4']]
|
||||
|
||||
def currentDark():
|
||||
return darkList[-1]
|
||||
def currentMedium():
|
||||
return mediumList[-1]
|
||||
def currentLight():
|
||||
return lightList[-1]
|
||||
|
||||
def nextDark():
|
||||
darkList.append(darkList.pop(0))
|
||||
return darkList[-1]
|
||||
def nextMedium():
|
||||
mediumList.append(mediumList.pop(0))
|
||||
return mediumList[-1]
|
||||
def nextLight():
|
||||
lightList.append(lightList.pop(0))
|
||||
return lightList[-1]
|
||||
|
||||
def reset():
|
||||
while not darkList[0]==colorsHex['darkBlue']:
|
||||
darkList.append(darkList.pop(0))
|
||||
while not mediumList[0]==colorsHex['mediumBlue']:
|
||||
mediumList.append(mediumList.pop(0))
|
||||
while not lightList[0]==colorsHex['lightBlue']:
|
||||
lightList.append(lightList.pop(0))
|
||||
|
||||
def setLightFigures():
|
||||
mpl.rcParams['axes.edgecolor']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['axes.facecolor']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['axes.labelcolor']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['figure.edgecolor']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['figure.facecolor']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['grid.color']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['savefig.edgecolor']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['savefig.facecolor']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['text.color']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['xtick.color']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['ytick.color']=colorsHex['Aluminium6']
|
||||
|
||||
def setDarkFigures():
|
||||
mpl.rcParams['axes.edgecolor']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['axes.facecolor']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['axes.labelcolor']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['figure.edgecolor']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['figure.facecolor']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['grid.color']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['savefig.edgecolor']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['savefig.facecolor']=colorsHex['Aluminium6']
|
||||
mpl.rcParams['text.color']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['xtick.color']=colorsHex['Aluminium2']
|
||||
mpl.rcParams['ytick.color']=colorsHex['Aluminium2']
|
||||
|
||||
def hex2rgb(hexcolor):
|
||||
hexcolor = [hexcolor[1+2*i:1+2*(i+1)] for i in range(3)]
|
||||
r,g,b = [int(n,16) for n in hexcolor]
|
||||
return (r,g,b)
|
||||
|
||||
colorsRGB = dict([(k,hex2rgb(i)) for k,i in colorsHex.items()])
|
||||
|
||||
cdict_RB = {'red' :((0.,colorsRGB['mediumRed'][0]/256.,colorsRGB['mediumRed'][0]/256.),
|
||||
(.5,colorsRGB['mediumPurple'][0]/256.,colorsRGB['mediumPurple'][0]/256.),
|
||||
(1.,colorsRGB['mediumBlue'][0]/256.,colorsRGB['mediumBlue'][0]/256.)),
|
||||
'green':((0.,colorsRGB['mediumRed'][1]/256.,colorsRGB['mediumRed'][1]/256.),
|
||||
(.5,colorsRGB['mediumPurple'][1]/256.,colorsRGB['mediumPurple'][1]/256.),
|
||||
(1.,colorsRGB['mediumBlue'][1]/256.,colorsRGB['mediumBlue'][1]/256.)),
|
||||
'blue':((0.,colorsRGB['mediumRed'][2]/256.,colorsRGB['mediumRed'][2]/256.),
|
||||
(.5,colorsRGB['mediumPurple'][2]/256.,colorsRGB['mediumPurple'][2]/256.),
|
||||
(1.,colorsRGB['mediumBlue'][2]/256.,colorsRGB['mediumBlue'][2]/256.))}
|
||||
|
||||
cdict_BGR = {'red' :((0.,colorsRGB['mediumBlue'][0]/256.,colorsRGB['mediumBlue'][0]/256.),
|
||||
(.5,colorsRGB['mediumGreen'][0]/256.,colorsRGB['mediumGreen'][0]/256.),
|
||||
(1.,colorsRGB['mediumRed'][0]/256.,colorsRGB['mediumRed'][0]/256.)),
|
||||
'green':((0.,colorsRGB['mediumBlue'][1]/256.,colorsRGB['mediumBlue'][1]/256.),
|
||||
(.5,colorsRGB['mediumGreen'][1]/256.,colorsRGB['mediumGreen'][1]/256.),
|
||||
(1.,colorsRGB['mediumRed'][1]/256.,colorsRGB['mediumRed'][1]/256.)),
|
||||
'blue':((0.,colorsRGB['mediumBlue'][2]/256.,colorsRGB['mediumBlue'][2]/256.),
|
||||
(.5,colorsRGB['mediumGreen'][2]/256.,colorsRGB['mediumGreen'][2]/256.),
|
||||
(1.,colorsRGB['mediumRed'][2]/256.,colorsRGB['mediumRed'][2]/256.))}
|
||||
|
||||
|
||||
cdict_Alu = {'red' :((0./5,colorsRGB['Aluminium1'][0]/256.,colorsRGB['Aluminium1'][0]/256.),
|
||||
(1./5,colorsRGB['Aluminium2'][0]/256.,colorsRGB['Aluminium2'][0]/256.),
|
||||
(2./5,colorsRGB['Aluminium3'][0]/256.,colorsRGB['Aluminium3'][0]/256.),
|
||||
(3./5,colorsRGB['Aluminium4'][0]/256.,colorsRGB['Aluminium4'][0]/256.),
|
||||
(4./5,colorsRGB['Aluminium5'][0]/256.,colorsRGB['Aluminium5'][0]/256.),
|
||||
(5./5,colorsRGB['Aluminium6'][0]/256.,colorsRGB['Aluminium6'][0]/256.)),
|
||||
'green' :((0./5,colorsRGB['Aluminium1'][1]/256.,colorsRGB['Aluminium1'][1]/256.),
|
||||
(1./5,colorsRGB['Aluminium2'][1]/256.,colorsRGB['Aluminium2'][1]/256.),
|
||||
(2./5,colorsRGB['Aluminium3'][1]/256.,colorsRGB['Aluminium3'][1]/256.),
|
||||
(3./5,colorsRGB['Aluminium4'][1]/256.,colorsRGB['Aluminium4'][1]/256.),
|
||||
(4./5,colorsRGB['Aluminium5'][1]/256.,colorsRGB['Aluminium5'][1]/256.),
|
||||
(5./5,colorsRGB['Aluminium6'][1]/256.,colorsRGB['Aluminium6'][1]/256.)),
|
||||
'blue' :((0./5,colorsRGB['Aluminium1'][2]/256.,colorsRGB['Aluminium1'][2]/256.),
|
||||
(1./5,colorsRGB['Aluminium2'][2]/256.,colorsRGB['Aluminium2'][2]/256.),
|
||||
(2./5,colorsRGB['Aluminium3'][2]/256.,colorsRGB['Aluminium3'][2]/256.),
|
||||
(3./5,colorsRGB['Aluminium4'][2]/256.,colorsRGB['Aluminium4'][2]/256.),
|
||||
(4./5,colorsRGB['Aluminium5'][2]/256.,colorsRGB['Aluminium5'][2]/256.),
|
||||
(5./5,colorsRGB['Aluminium6'][2]/256.,colorsRGB['Aluminium6'][2]/256.))}
|
||||
# cmap_Alu = mpl.colors.LinearSegmentedColormap('TangoAluminium',cdict_Alu,256)
|
||||
# cmap_BGR = mpl.colors.LinearSegmentedColormap('TangoRedBlue',cdict_BGR,256)
|
||||
# cmap_RB = mpl.colors.LinearSegmentedColormap('TangoRedBlue',cdict_RB,256)
|
||||
if __name__=='__main__':
|
||||
import pylab as pb
|
||||
pb.figure()
|
||||
pb.pcolor(pb.rand(10,10),cmap=cmap_RB)
|
||||
pb.colorbar()
|
||||
pb.show()
|
||||
|
|
@ -1 +0,0 @@
|
|||
import controllers
|
||||
|
|
@ -1 +0,0 @@
|
|||
import axis_event_controller, imshow_controller
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
'''
|
||||
Created on 24 Jul 2013
|
||||
|
||||
@author: maxz
|
||||
'''
|
||||
import numpy
|
||||
|
||||
class AxisEventController(object):
|
||||
def __init__(self, ax):
|
||||
self.ax = ax
|
||||
self.activate()
|
||||
def deactivate(self):
|
||||
for cb_class in self.ax.callbacks.callbacks.values():
|
||||
for cb_num in cb_class.keys():
|
||||
self.ax.callbacks.disconnect(cb_num)
|
||||
def activate(self):
|
||||
self.ax.callbacks.connect('xlim_changed', self.xlim_changed)
|
||||
self.ax.callbacks.connect('ylim_changed', self.ylim_changed)
|
||||
def xlim_changed(self, ax):
|
||||
pass
|
||||
def ylim_changed(self, ax):
|
||||
pass
|
||||
|
||||
|
||||
class AxisChangedController(AxisEventController):
|
||||
'''
|
||||
Buffered control of axis limit changes
|
||||
'''
|
||||
_changing = False
|
||||
|
||||
def __init__(self, ax, update_lim=None):
|
||||
'''
|
||||
Constructor
|
||||
'''
|
||||
super(AxisChangedController, self).__init__(ax)
|
||||
self._lim_ratio_threshold = update_lim or .8
|
||||
self._x_lim = self.ax.get_xlim()
|
||||
self._y_lim = self.ax.get_ylim()
|
||||
|
||||
def update(self, ax):
|
||||
pass
|
||||
|
||||
def xlim_changed(self, ax):
|
||||
super(AxisChangedController, self).xlim_changed(ax)
|
||||
if not self._changing and self.lim_changed(ax.get_xlim(), self._x_lim):
|
||||
self._changing = True
|
||||
self._x_lim = ax.get_xlim()
|
||||
self.update(ax)
|
||||
self._changing = False
|
||||
|
||||
def ylim_changed(self, ax):
|
||||
super(AxisChangedController, self).ylim_changed(ax)
|
||||
if not self._changing and self.lim_changed(ax.get_ylim(), self._y_lim):
|
||||
self._changing = True
|
||||
self._y_lim = ax.get_ylim()
|
||||
self.update(ax)
|
||||
self._changing = False
|
||||
|
||||
def extent(self, lim):
|
||||
return numpy.subtract(*lim)
|
||||
|
||||
def lim_changed(self, axlim, savedlim):
|
||||
axextent = self.extent(axlim)
|
||||
extent = self.extent(savedlim)
|
||||
lim_changed = ((axextent / extent) < self._lim_ratio_threshold ** 2
|
||||
or (extent / axextent) < self._lim_ratio_threshold ** 2
|
||||
or ((1 - (self.extent((axlim[0], savedlim[0])) / self.extent((savedlim[0], axlim[1]))))
|
||||
< self._lim_ratio_threshold)
|
||||
or ((1 - (self.extent((savedlim[0], axlim[0])) / self.extent((axlim[0], savedlim[1]))))
|
||||
< self._lim_ratio_threshold)
|
||||
)
|
||||
return lim_changed
|
||||
|
||||
def _buffer_lim(self, lim):
|
||||
# buffer_size = 1 - self._lim_ratio_threshold
|
||||
# extent = self.extent(lim)
|
||||
return lim
|
||||
|
||||
|
||||
class BufferedAxisChangedController(AxisChangedController):
|
||||
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=None, **kwargs):
|
||||
"""
|
||||
:param plot_function:
|
||||
function to use for creating image for plotting (return ndarray-like)
|
||||
plot_function gets called with (2D!) Xtest grid if replotting required
|
||||
:type plot_function: function
|
||||
:param plot_limits:
|
||||
beginning plot limits [xmin, ymin, xmax, ymax]
|
||||
|
||||
:param kwargs: additional kwargs are for pyplot.imshow(**kwargs)
|
||||
"""
|
||||
super(BufferedAxisChangedController, self).__init__(ax, update_lim=update_lim)
|
||||
self.plot_function = plot_function
|
||||
xmin, xmax = self._x_lim # self._compute_buffered(*self._x_lim)
|
||||
ymin, ymax = self._y_lim # self._compute_buffered(*self._y_lim)
|
||||
self.resolution = resolution
|
||||
self._not_init = False
|
||||
self.view = self._init_view(self.ax, self.recompute_X(), xmin, xmax, ymin, ymax, **kwargs)
|
||||
self._not_init = True
|
||||
|
||||
def update(self, ax):
|
||||
super(BufferedAxisChangedController, self).update(ax)
|
||||
if self._not_init:
|
||||
xmin, xmax = self._compute_buffered(*self._x_lim)
|
||||
ymin, ymax = self._compute_buffered(*self._y_lim)
|
||||
self.update_view(self.view, self.recompute_X(), xmin, xmax, ymin, ymax)
|
||||
|
||||
def _init_view(self, ax, X, xmin, xmax, ymin, ymax):
|
||||
raise NotImplementedError('return view for this controller')
|
||||
|
||||
def update_view(self, view, X, xmin, xmax, ymin, ymax):
|
||||
raise NotImplementedError('update view given in here')
|
||||
|
||||
def get_grid(self):
|
||||
xmin, xmax = self._compute_buffered(*self._x_lim)
|
||||
ymin, ymax = self._compute_buffered(*self._y_lim)
|
||||
x, y = numpy.mgrid[xmin:xmax:1j * self.resolution, ymin:ymax:1j * self.resolution]
|
||||
return numpy.hstack((x.flatten()[:, None], y.flatten()[:, None]))
|
||||
|
||||
def recompute_X(self):
|
||||
X = self.plot_function(self.get_grid())
|
||||
if isinstance(X, (tuple, list)):
|
||||
for x in X:
|
||||
x.shape = [self.resolution, self.resolution]
|
||||
x[:, :] = x.T[::-1, :]
|
||||
return X
|
||||
return X.reshape(self.resolution, self.resolution).T[::-1, :]
|
||||
|
||||
def _compute_buffered(self, mi, ma):
|
||||
buffersize = self._buffersize()
|
||||
size = ma - mi
|
||||
return mi - (buffersize * size), ma + (buffersize * size)
|
||||
|
||||
def _buffersize(self):
|
||||
try:
|
||||
buffersize = 1. - self._lim_ratio_threshold
|
||||
except:
|
||||
buffersize = .4
|
||||
return buffersize
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
'''
|
||||
Created on 24 Jul 2013
|
||||
|
||||
@author: maxz
|
||||
'''
|
||||
from GPy.util.latent_space_visualizations.controllers.axis_event_controller import BufferedAxisChangedController
|
||||
import itertools
|
||||
import numpy
|
||||
|
||||
|
||||
class ImshowController(BufferedAxisChangedController):
|
||||
def __init__(self, ax, plot_function, plot_limits, resolution=50, update_lim=.5, **kwargs):
|
||||
"""
|
||||
:param plot_function:
|
||||
function to use for creating image for plotting (return ndarray-like)
|
||||
plot_function gets called with (2D!) Xtest grid if replotting required
|
||||
:type plot_function: function
|
||||
:param plot_limits:
|
||||
beginning plot limits [xmin, ymin, xmax, ymax]
|
||||
|
||||
:param kwargs: additional kwargs are for pyplot.imshow(**kwargs)
|
||||
"""
|
||||
super(ImshowController, self).__init__(ax, plot_function, plot_limits, resolution, update_lim, **kwargs)
|
||||
|
||||
def _init_view(self, ax, X, xmin, xmax, ymin, ymax, **kwargs):
|
||||
return ax.imshow(X, extent=(xmin, xmax,
|
||||
ymin, ymax),
|
||||
vmin=X.min(),
|
||||
vmax=X.max(),
|
||||
**kwargs)
|
||||
|
||||
def update_view(self, view, X, xmin, xmax, ymin, ymax):
|
||||
view.set_data(X)
|
||||
view.set_extent((xmin, xmax, ymin, ymax))
|
||||
|
||||
class ImAnnotateController(ImshowController):
|
||||
def __init__(self, ax, plot_function, plot_limits, resolution=20, update_lim=.99, **kwargs):
|
||||
"""
|
||||
:param plot_function:
|
||||
function to use for creating image for plotting (return ndarray-like)
|
||||
plot_function gets called with (2D!) Xtest grid if replotting required
|
||||
:type plot_function: function
|
||||
:param plot_limits:
|
||||
beginning plot limits [xmin, ymin, xmax, ymax]
|
||||
:param text_props: kwargs for pyplot.text(**text_props)
|
||||
:param kwargs: additional kwargs are for pyplot.imshow(**kwargs)
|
||||
"""
|
||||
super(ImAnnotateController, self).__init__(ax, plot_function, plot_limits, resolution, update_lim, **kwargs)
|
||||
|
||||
def _init_view(self, ax, X, xmin, xmax, ymin, ymax, text_props={}, **kwargs):
|
||||
view = [super(ImAnnotateController, self)._init_view(ax, X[0], xmin, xmax, ymin, ymax, **kwargs)]
|
||||
xoffset, yoffset = self._offsets(xmin, xmax, ymin, ymax)
|
||||
xlin = numpy.linspace(xmin, xmax, self.resolution, endpoint=False)
|
||||
ylin = numpy.linspace(ymin, ymax, self.resolution, endpoint=False)
|
||||
for [i, x], [j, y] in itertools.product(enumerate(xlin), enumerate(ylin[::-1])):
|
||||
view.append(ax.text(x + xoffset, y + yoffset, "{}".format(X[1][j, i]), ha='center', va='center', **text_props))
|
||||
return view
|
||||
|
||||
def update_view(self, view, X, xmin, xmax, ymin, ymax):
|
||||
super(ImAnnotateController, self).update_view(view[0], X[0], xmin, xmax, ymin, ymax)
|
||||
xoffset, yoffset = self._offsets(xmin, xmax, ymin, ymax)
|
||||
xlin = numpy.linspace(xmin, xmax, self.resolution, endpoint=False)
|
||||
ylin = numpy.linspace(ymin, ymax, self.resolution, endpoint=False)
|
||||
for [[i, x], [j, y]], text in itertools.izip(itertools.product(enumerate(xlin), enumerate(ylin[::-1])), view[1:]):
|
||||
text.set_x(x + xoffset)
|
||||
text.set_y(y + yoffset)
|
||||
text.set_text("{}".format(X[1][j, i]))
|
||||
return view
|
||||
|
||||
def _offsets(self, xmin, xmax, ymin, ymax):
|
||||
return (xmax - xmin) / (2 * self.resolution), (ymax - ymin) / (2 * self.resolution)
|
||||
161
GPy/util/maps.py
161
GPy/util/maps.py
|
|
@ -1,161 +0,0 @@
|
|||
import numpy as np
|
||||
import pylab as pb
|
||||
import matplotlib.patches as patches
|
||||
from matplotlib.patches import Polygon
|
||||
from matplotlib.collections import PatchCollection
|
||||
#from matplotlib import cm
|
||||
import shapefile
|
||||
import re
|
||||
|
||||
pb.ion()
|
||||
|
||||
def plot(shape_records,facecolor='w',edgecolor='k',linewidths=.5, ax=None,xlims=None,ylims=None):
|
||||
"""
|
||||
Plot the geometry of a shapefile
|
||||
|
||||
:param shape_records: geometry and attributes list
|
||||
:type shape_records: ShapeRecord object (output of a shapeRecords() method)
|
||||
:param facecolor: color to be used to fill in polygons
|
||||
:param edgecolor: color to be used for lines
|
||||
:param ax: axes to plot on.
|
||||
:type ax: axes handle
|
||||
"""
|
||||
#Axes handle
|
||||
if ax is None:
|
||||
fig = pb.figure()
|
||||
ax = fig.add_subplot(111)
|
||||
|
||||
#Iterate over shape_records
|
||||
for srec in shape_records:
|
||||
points = np.vstack(srec.shape.points)
|
||||
sparts = srec.shape.parts
|
||||
par = list(sparts) + [points.shape[0]]
|
||||
|
||||
polygs = []
|
||||
for pj in xrange(len(sparts)):
|
||||
polygs.append(Polygon(points[par[pj]:par[pj+1]]))
|
||||
ax.add_collection(PatchCollection(polygs,facecolor=facecolor,edgecolor=edgecolor, linewidths=linewidths))
|
||||
|
||||
#Plot limits
|
||||
_box = np.vstack([srec.shape.bbox for srec in shape_records])
|
||||
minx,miny = np.min(_box[:,:2],0)
|
||||
maxx,maxy = np.max(_box[:,2:],0)
|
||||
|
||||
if xlims is not None:
|
||||
minx,maxx = xlims
|
||||
if ylims is not None:
|
||||
miny,maxy = ylims
|
||||
ax.set_xlim(minx,maxx)
|
||||
ax.set_ylim(miny,maxy)
|
||||
|
||||
|
||||
def string_match(sf,regex,field=2):
|
||||
"""
|
||||
Return the geometry and attributes of a shapefile whose fields match a regular expression given
|
||||
|
||||
:param sf: shapefile
|
||||
:type sf: shapefile object
|
||||
:regex: regular expression to match
|
||||
:type regex: string
|
||||
:field: field number to be matched with the regex
|
||||
:type field: integer
|
||||
"""
|
||||
index = []
|
||||
shape_records = []
|
||||
for rec in enumerate(sf.shapeRecords()):
|
||||
m = re.search(regex,rec[1].record[field])
|
||||
if m is not None:
|
||||
index.append(rec[0])
|
||||
shape_records.append(rec[1])
|
||||
return index,shape_records
|
||||
|
||||
def bbox_match(sf,bbox,inside_only=True):
|
||||
"""
|
||||
Return the geometry and attributes of a shapefile that lie within (or intersect) a bounding box
|
||||
|
||||
:param sf: shapefile
|
||||
:type sf: shapefile object
|
||||
:param bbox: bounding box
|
||||
:type bbox: list of floats [x_min,y_min,x_max,y_max]
|
||||
:inside_only: True if the objects returned are those that lie within the bbox and False if the objects returned are any that intersect the bbox
|
||||
:type inside_only: Boolean
|
||||
"""
|
||||
A,B,C,D = bbox
|
||||
index = []
|
||||
shape_records = []
|
||||
for rec in enumerate(sf.shapeRecords()):
|
||||
a,b,c,d = rec[1].shape.bbox
|
||||
if inside_only:
|
||||
if A <= a and B <= b and C >= c and D >= d:
|
||||
index.append(rec[0])
|
||||
shape_records.append(rec[1])
|
||||
else:
|
||||
cond1 = A <= a and B <= b and C >= a and D >= b
|
||||
cond2 = A <= c and B <= d and C >= c and D >= d
|
||||
cond3 = A <= a and D >= d and C >= a and B <= d
|
||||
cond4 = A <= c and D >= b and C >= c and B <= b
|
||||
cond5 = a <= C and b <= B and d >= D
|
||||
cond6 = c <= A and b <= B and d >= D
|
||||
cond7 = d <= B and a <= A and c >= C
|
||||
cond8 = b <= D and a <= A and c >= C
|
||||
if cond1 or cond2 or cond3 or cond4 or cond5 or cond6 or cond7 or cond8:
|
||||
index.append(rec[0])
|
||||
shape_records.append(rec[1])
|
||||
return index,shape_records
|
||||
|
||||
|
||||
def plot_bbox(sf,bbox,inside_only=True):
|
||||
"""
|
||||
Plot the geometry of a shapefile within a bbox
|
||||
|
||||
:param sf: shapefile
|
||||
:type sf: shapefile object
|
||||
:param bbox: bounding box
|
||||
:type bbox: list of floats [x_min,y_min,x_max,y_max]
|
||||
:inside_only: True if the objects returned are those that lie within the bbox and False if the objects returned are any that intersect the bbox
|
||||
:type inside_only: Boolean
|
||||
"""
|
||||
index,shape_records = bbox_match(sf,bbox,inside_only)
|
||||
A,B,C,D = bbox
|
||||
plot(shape_records,xlims=[bbox[0],bbox[2]],ylims=[bbox[1],bbox[3]])
|
||||
|
||||
def plot_string_match(sf,regex,field):
|
||||
"""
|
||||
Plot the geometry of a shapefile whose fields match a regular expression given
|
||||
|
||||
:param sf: shapefile
|
||||
:type sf: shapefile object
|
||||
:regex: regular expression to match
|
||||
:type regex: string
|
||||
:field: field number to be matched with the regex
|
||||
:type field: integer
|
||||
"""
|
||||
index,shape_records = string_match(sf,regex,field)
|
||||
plot(shape_records)
|
||||
|
||||
|
||||
def new_shape_string(sf,name,regex,field=2,type=shapefile.POINT):
|
||||
|
||||
newshp = shapefile.Writer(shapeType = sf.shapeType)
|
||||
newshp.autoBalance = 1
|
||||
|
||||
index,shape_records = string_match(sf,regex,field)
|
||||
|
||||
_fi = [sf.fields[j] for j in index]
|
||||
for f in _fi:
|
||||
newshp.field(name=f[0],fieldType=f[1],size=f[2],decimal=f[3])
|
||||
|
||||
_shre = shape_records
|
||||
for sr in _shre:
|
||||
_points = []
|
||||
_parts = []
|
||||
for point in sr.shape.points:
|
||||
_points.append(point)
|
||||
_parts.append(_points)
|
||||
|
||||
newshp.line(parts=_parts)
|
||||
newshp.records.append(sr.record)
|
||||
print len(sr.record)
|
||||
|
||||
newshp.save(name)
|
||||
print index
|
||||
|
|
@ -1,331 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# netpbmfile.py
|
||||
|
||||
# Copyright (c) 2011-2013, Christoph Gohlke
|
||||
# Copyright (c) 2011-2013, The Regents of the University of California
|
||||
# Produced at the Laboratory for Fluorescence Dynamics.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
# * Neither the name of the copyright holders nor the names of any
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""Read and write image data from respectively to Netpbm files.
|
||||
|
||||
This implementation follows the Netpbm format specifications at
|
||||
http://netpbm.sourceforge.net/doc/. No gamma correction is performed.
|
||||
|
||||
The following image formats are supported: PBM (bi-level), PGM (grayscale),
|
||||
PPM (color), PAM (arbitrary), XV thumbnail (RGB332, read-only).
|
||||
|
||||
:Author:
|
||||
`Christoph Gohlke <http://www.lfd.uci.edu/~gohlke/>`_
|
||||
|
||||
:Organization:
|
||||
Laboratory for Fluorescence Dynamics, University of California, Irvine
|
||||
|
||||
:Version: 2013.01.18
|
||||
|
||||
Requirements
|
||||
------------
|
||||
* `CPython 2.7, 3.2 or 3.3 <http://www.python.org>`_
|
||||
* `Numpy 1.7 <http://www.numpy.org>`_
|
||||
* `Matplotlib 1.2 <http://www.matplotlib.org>`_ (optional for plotting)
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> im1 = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
|
||||
>>> imsave('_tmp.pgm', im1)
|
||||
>>> im2 = imread('_tmp.pgm')
|
||||
>>> assert numpy.all(im1 == im2)
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function
|
||||
|
||||
import sys
|
||||
import re
|
||||
import math
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy
|
||||
|
||||
__version__ = '2013.01.18'
|
||||
__docformat__ = 'restructuredtext en'
|
||||
__all__ = ['imread', 'imsave', 'NetpbmFile']
|
||||
|
||||
|
||||
def imread(filename, *args, **kwargs):
|
||||
"""Return image data from Netpbm file as numpy array.
|
||||
|
||||
`args` and `kwargs` are arguments to NetpbmFile.asarray().
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> image = imread('_tmp.pgm')
|
||||
|
||||
"""
|
||||
try:
|
||||
netpbm = NetpbmFile(filename)
|
||||
image = netpbm.asarray()
|
||||
finally:
|
||||
netpbm.close()
|
||||
return image
|
||||
|
||||
|
||||
def imsave(filename, data, maxval=None, pam=False):
|
||||
"""Write image data to Netpbm file.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> image = numpy.array([[0, 1],[65534, 65535]], dtype=numpy.uint16)
|
||||
>>> imsave('_tmp.pgm', image)
|
||||
|
||||
"""
|
||||
try:
|
||||
netpbm = NetpbmFile(data, maxval=maxval)
|
||||
netpbm.write(filename, pam=pam)
|
||||
finally:
|
||||
netpbm.close()
|
||||
|
||||
|
||||
class NetpbmFile(object):
|
||||
"""Read and write Netpbm PAM, PBM, PGM, PPM, files."""
|
||||
|
||||
_types = {b'P1': b'BLACKANDWHITE', b'P2': b'GRAYSCALE', b'P3': b'RGB',
|
||||
b'P4': b'BLACKANDWHITE', b'P5': b'GRAYSCALE', b'P6': b'RGB',
|
||||
b'P7 332': b'RGB', b'P7': b'RGB_ALPHA'}
|
||||
|
||||
def __init__(self, arg=None, **kwargs):
|
||||
"""Initialize instance from filename, open file, or numpy array."""
|
||||
for attr in ('header', 'magicnum', 'width', 'height', 'maxval',
|
||||
'depth', 'tupltypes', '_filename', '_fh', '_data'):
|
||||
setattr(self, attr, None)
|
||||
if arg is None:
|
||||
self._fromdata([], **kwargs)
|
||||
elif isinstance(arg, basestring):
|
||||
self._fh = open(arg, 'rb')
|
||||
self._filename = arg
|
||||
self._fromfile(self._fh, **kwargs)
|
||||
elif hasattr(arg, 'seek'):
|
||||
self._fromfile(arg, **kwargs)
|
||||
self._fh = arg
|
||||
else:
|
||||
self._fromdata(arg, **kwargs)
|
||||
|
||||
def asarray(self, copy=True, cache=False, **kwargs):
|
||||
"""Return image data from file as numpy array."""
|
||||
data = self._data
|
||||
if data is None:
|
||||
data = self._read_data(self._fh, **kwargs)
|
||||
if cache:
|
||||
self._data = data
|
||||
else:
|
||||
return data
|
||||
return deepcopy(data) if copy else data
|
||||
|
||||
def write(self, arg, **kwargs):
|
||||
"""Write instance to file."""
|
||||
if hasattr(arg, 'seek'):
|
||||
self._tofile(arg, **kwargs)
|
||||
else:
|
||||
with open(arg, 'wb') as fid:
|
||||
self._tofile(fid, **kwargs)
|
||||
|
||||
def close(self):
|
||||
"""Close open file. Future asarray calls might fail."""
|
||||
if self._filename and self._fh:
|
||||
self._fh.close()
|
||||
self._fh = None
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def _fromfile(self, fh):
|
||||
"""Initialize instance from open file."""
|
||||
fh.seek(0)
|
||||
data = fh.read(4096)
|
||||
if (len(data) < 7) or not (b'0' < data[1:2] < b'8'):
|
||||
raise ValueError("Not a Netpbm file:\n%s" % data[:32])
|
||||
try:
|
||||
self._read_pam_header(data)
|
||||
except Exception:
|
||||
try:
|
||||
self._read_pnm_header(data)
|
||||
except Exception:
|
||||
raise ValueError("Not a Netpbm file:\n%s" % data[:32])
|
||||
|
||||
def _read_pam_header(self, data):
|
||||
"""Read PAM header and initialize instance."""
|
||||
regroups = re.search(
|
||||
b"(^P7[\n\r]+(?:(?:[\n\r]+)|(?:#.*)|"
|
||||
b"(HEIGHT\s+\d+)|(WIDTH\s+\d+)|(DEPTH\s+\d+)|(MAXVAL\s+\d+)|"
|
||||
b"(?:TUPLTYPE\s+\w+))*ENDHDR\n)", data).groups()
|
||||
self.header = regroups[0]
|
||||
self.magicnum = b'P7'
|
||||
for group in regroups[1:]:
|
||||
key, value = group.split()
|
||||
setattr(self, unicode(key).lower(), int(value))
|
||||
matches = re.findall(b"(TUPLTYPE\s+\w+)", self.header)
|
||||
self.tupltypes = [s.split(None, 1)[1] for s in matches]
|
||||
|
||||
def _read_pnm_header(self, data):
|
||||
"""Read PNM header and initialize instance."""
|
||||
bpm = data[1:2] in b"14"
|
||||
regroups = re.search(b"".join((
|
||||
b"(^(P[123456]|P7 332)\s+(?:#.*[\r\n])*",
|
||||
b"\s*(\d+)\s+(?:#.*[\r\n])*",
|
||||
b"\s*(\d+)\s+(?:#.*[\r\n])*" * (not bpm),
|
||||
b"\s*(\d+)\s(?:\s*#.*[\r\n]\s)*)")), data).groups() + (1, ) * bpm
|
||||
self.header = regroups[0]
|
||||
self.magicnum = regroups[1]
|
||||
self.width = int(regroups[2])
|
||||
self.height = int(regroups[3])
|
||||
self.maxval = int(regroups[4])
|
||||
self.depth = 3 if self.magicnum in b"P3P6P7 332" else 1
|
||||
self.tupltypes = [self._types[self.magicnum]]
|
||||
|
||||
def _read_data(self, fh, byteorder='>'):
|
||||
"""Return image data from open file as numpy array."""
|
||||
fh.seek(len(self.header))
|
||||
data = fh.read()
|
||||
dtype = 'u1' if self.maxval < 256 else byteorder + 'u2'
|
||||
depth = 1 if self.magicnum == b"P7 332" else self.depth
|
||||
shape = [-1, self.height, self.width, depth]
|
||||
size = numpy.prod(shape[1:])
|
||||
if self.magicnum in b"P1P2P3":
|
||||
data = numpy.array(data.split(None, size)[:size], dtype)
|
||||
data = data.reshape(shape)
|
||||
elif self.maxval == 1:
|
||||
shape[2] = int(math.ceil(self.width / 8))
|
||||
data = numpy.frombuffer(data, dtype).reshape(shape)
|
||||
data = numpy.unpackbits(data, axis=-2)[:, :, :self.width, :]
|
||||
else:
|
||||
data = numpy.frombuffer(data, dtype)
|
||||
data = data[:size * (data.size // size)].reshape(shape)
|
||||
if data.shape[0] < 2:
|
||||
data = data.reshape(data.shape[1:])
|
||||
if data.shape[-1] < 2:
|
||||
data = data.reshape(data.shape[:-1])
|
||||
if self.magicnum == b"P7 332":
|
||||
rgb332 = numpy.array(list(numpy.ndindex(8, 8, 4)), numpy.uint8)
|
||||
rgb332 *= [36, 36, 85]
|
||||
data = numpy.take(rgb332, data, axis=0)
|
||||
return data
|
||||
|
||||
def _fromdata(self, data, maxval=None):
|
||||
"""Initialize instance from numpy array."""
|
||||
data = numpy.array(data, ndmin=2, copy=True)
|
||||
if data.dtype.kind not in "uib":
|
||||
raise ValueError("not an integer type: %s" % data.dtype)
|
||||
if data.dtype.kind == 'i' and numpy.min(data) < 0:
|
||||
raise ValueError("data out of range: %i" % numpy.min(data))
|
||||
if maxval is None:
|
||||
maxval = numpy.max(data)
|
||||
maxval = 255 if maxval < 256 else 65535
|
||||
if maxval < 0 or maxval > 65535:
|
||||
raise ValueError("data out of range: %i" % maxval)
|
||||
data = data.astype('u1' if maxval < 256 else '>u2')
|
||||
self._data = data
|
||||
if data.ndim > 2 and data.shape[-1] in (3, 4):
|
||||
self.depth = data.shape[-1]
|
||||
self.width = data.shape[-2]
|
||||
self.height = data.shape[-3]
|
||||
self.magicnum = b'P7' if self.depth == 4 else b'P6'
|
||||
else:
|
||||
self.depth = 1
|
||||
self.width = data.shape[-1]
|
||||
self.height = data.shape[-2]
|
||||
self.magicnum = b'P5' if maxval > 1 else b'P4'
|
||||
self.maxval = maxval
|
||||
self.tupltypes = [self._types[self.magicnum]]
|
||||
self.header = self._header()
|
||||
|
||||
def _tofile(self, fh, pam=False):
|
||||
"""Write Netbm file."""
|
||||
fh.seek(0)
|
||||
fh.write(self._header(pam))
|
||||
data = self.asarray(copy=False)
|
||||
if self.maxval == 1:
|
||||
data = numpy.packbits(data, axis=-1)
|
||||
data.tofile(fh)
|
||||
|
||||
def _header(self, pam=False):
|
||||
"""Return file header as byte string."""
|
||||
if pam or self.magicnum == b'P7':
|
||||
header = "\n".join((
|
||||
"P7",
|
||||
"HEIGHT %i" % self.height,
|
||||
"WIDTH %i" % self.width,
|
||||
"DEPTH %i" % self.depth,
|
||||
"MAXVAL %i" % self.maxval,
|
||||
"\n".join("TUPLTYPE %s" % unicode(i) for i in self.tupltypes),
|
||||
"ENDHDR\n"))
|
||||
elif self.maxval == 1:
|
||||
header = "P4 %i %i\n" % (self.width, self.height)
|
||||
elif self.depth == 1:
|
||||
header = "P5 %i %i %i\n" % (self.width, self.height, self.maxval)
|
||||
else:
|
||||
header = "P6 %i %i %i\n" % (self.width, self.height, self.maxval)
|
||||
if sys.version_info[0] > 2:
|
||||
header = bytes(header, 'ascii')
|
||||
return header
|
||||
|
||||
def __str__(self):
|
||||
"""Return information about instance."""
|
||||
return unicode(self.header)
|
||||
|
||||
|
||||
if sys.version_info[0] > 2:
|
||||
basestring = str
|
||||
unicode = lambda x: str(x, 'ascii')
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Show images specified on command line or all images in current directory
|
||||
from glob import glob
|
||||
from matplotlib import pyplot
|
||||
files = sys.argv[1:] if len(sys.argv) > 1 else glob('*.p*m')
|
||||
for fname in files:
|
||||
try:
|
||||
pam = NetpbmFile(fname)
|
||||
img = pam.asarray(copy=False)
|
||||
if False:
|
||||
pam.write('_tmp.pgm.out', pam=True)
|
||||
img2 = imread('_tmp.pgm.out')
|
||||
assert numpy.all(img == img2)
|
||||
imsave('_tmp.pgm.out', img)
|
||||
img2 = imread('_tmp.pgm.out')
|
||||
assert numpy.all(img == img2)
|
||||
pam.close()
|
||||
except ValueError as e:
|
||||
print(fname, e)
|
||||
continue
|
||||
_shape = img.shape
|
||||
if img.ndim > 3 or (img.ndim > 2 and img.shape[-1] not in (3, 4)):
|
||||
img = img[0]
|
||||
cmap = 'gray' if pam.maxval > 1 else 'binary'
|
||||
pyplot.imshow(img, cmap, interpolation='nearest')
|
||||
pyplot.title("%s %s %s %s" % (fname, unicode(pam.magicnum),
|
||||
_shape, img.dtype))
|
||||
pyplot.show()
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
import pylab as pb
|
||||
import numpy as np
|
||||
from .. import util
|
||||
from GPy.util.latent_space_visualizations.controllers.imshow_controller import ImshowController
|
||||
from misc import param_to_array
|
||||
import itertools
|
||||
|
||||
def most_significant_input_dimensions(model, which_indices):
|
||||
if which_indices is None:
|
||||
if model.input_dim == 1:
|
||||
input_1 = 0
|
||||
input_2 = None
|
||||
if model.input_dim == 2:
|
||||
input_1, input_2 = 0, 1
|
||||
else:
|
||||
try:
|
||||
input_1, input_2 = np.argsort(model.input_sensitivity())[::-1][:2]
|
||||
except:
|
||||
raise ValueError, "cannot automatically determine which dimensions to plot, please pass 'which_indices'"
|
||||
else:
|
||||
input_1, input_2 = which_indices
|
||||
return input_1, input_2
|
||||
|
||||
def plot_latent(model, labels=None, which_indices=None,
|
||||
resolution=50, ax=None, marker='o', s=40,
|
||||
fignum=None, plot_inducing=False, legend=True,
|
||||
aspect='auto', updates=False):
|
||||
"""
|
||||
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
|
||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||
"""
|
||||
if ax is None:
|
||||
fig = pb.figure(num=fignum)
|
||||
ax = fig.add_subplot(111)
|
||||
util.plot.Tango.reset()
|
||||
|
||||
if labels is None:
|
||||
labels = np.ones(model.num_data)
|
||||
|
||||
input_1, input_2 = most_significant_input_dimensions(model, which_indices)
|
||||
X = param_to_array(model.X)
|
||||
|
||||
# first, plot the output variance as a function of the latent space
|
||||
Xtest, xx, yy, xmin, xmax = util.plot.x_frame2D(X[:, [input_1, input_2]], resolution=resolution)
|
||||
Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
|
||||
|
||||
def plot_function(x):
|
||||
Xtest_full[:, [input_1, input_2]] = x
|
||||
mu, var, low, up = model.predict(Xtest_full)
|
||||
var = var[:, :1]
|
||||
return np.log(var)
|
||||
view = ImshowController(ax, plot_function,
|
||||
tuple(X[:, [input_1, input_2]].min(0)) + tuple(X[:, [input_1, input_2]].max(0)),
|
||||
resolution, aspect=aspect, interpolation='bilinear',
|
||||
cmap=pb.cm.binary)
|
||||
|
||||
# ax.imshow(var.reshape(resolution, resolution).T,
|
||||
# extent=[xmin[0], xmax[0], xmin[1], xmax[1]], cmap=pb.cm.binary, interpolation='bilinear', origin='lower')
|
||||
|
||||
# make sure labels are in order of input:
|
||||
ulabels = []
|
||||
for lab in labels:
|
||||
if not lab in ulabels:
|
||||
ulabels.append(lab)
|
||||
|
||||
marker = itertools.cycle(list(marker))
|
||||
|
||||
for i, ul in enumerate(ulabels):
|
||||
if type(ul) is np.string_:
|
||||
this_label = ul
|
||||
elif type(ul) is np.int64:
|
||||
this_label = 'class %i' % ul
|
||||
else:
|
||||
this_label = 'class %i' % i
|
||||
m = marker.next()
|
||||
|
||||
index = np.nonzero(labels == ul)[0]
|
||||
if model.input_dim == 1:
|
||||
x = X[index, input_1]
|
||||
y = np.zeros(index.size)
|
||||
else:
|
||||
x = X[index, input_1]
|
||||
y = X[index, input_2]
|
||||
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
|
||||
|
||||
ax.set_xlabel('latent dimension %i' % input_1)
|
||||
ax.set_ylabel('latent dimension %i' % input_2)
|
||||
|
||||
if not np.all(labels == 1.) and legend:
|
||||
ax.legend(loc=0, numpoints=1)
|
||||
|
||||
ax.set_xlim(xmin[0], xmax[0])
|
||||
ax.set_ylim(xmin[1], xmax[1])
|
||||
ax.grid(b=False) # remove the grid if present, it doesn't look good
|
||||
ax.set_aspect('auto') # set a nice aspect ratio
|
||||
|
||||
if plot_inducing:
|
||||
Z = param_to_array(model.Z)
|
||||
ax.plot(Z[:, input_1], Z[:, input_2], '^w')
|
||||
|
||||
if updates:
|
||||
ax.figure.canvas.show()
|
||||
raw_input('Enter to continue')
|
||||
return ax
|
||||
|
||||
def plot_magnification(model, labels=None, which_indices=None,
|
||||
resolution=60, ax=None, marker='o', s=40,
|
||||
fignum=None, plot_inducing=False, legend=True,
|
||||
aspect='auto', updates=False):
|
||||
"""
|
||||
:param labels: a np.array of size model.num_data containing labels for the points (can be number, strings, etc)
|
||||
:param resolution: the resolution of the grid on which to evaluate the predictive variance
|
||||
"""
|
||||
if ax is None:
|
||||
fig = pb.figure(num=fignum)
|
||||
ax = fig.add_subplot(111)
|
||||
util.plot.Tango.reset()
|
||||
|
||||
if labels is None:
|
||||
labels = np.ones(model.num_data)
|
||||
|
||||
input_1, input_2 = most_significant_input_dimensions(model, which_indices)
|
||||
|
||||
# first, plot the output variance as a function of the latent space
|
||||
Xtest, xx, yy, xmin, xmax = util.plot.x_frame2D(model.X[:, [input_1, input_2]], resolution=resolution)
|
||||
Xtest_full = np.zeros((Xtest.shape[0], model.X.shape[1]))
|
||||
def plot_function(x):
|
||||
Xtest_full[:, [input_1, input_2]] = x
|
||||
mf=model.magnification(Xtest_full)
|
||||
return mf
|
||||
view = ImshowController(ax, plot_function,
|
||||
tuple(model.X.min(0)[:, [input_1, input_2]]) + tuple(model.X.max(0)[:, [input_1, input_2]]),
|
||||
resolution, aspect=aspect, interpolation='bilinear',
|
||||
cmap=pb.cm.gray)
|
||||
|
||||
# make sure labels are in order of input:
|
||||
ulabels = []
|
||||
for lab in labels:
|
||||
if not lab in ulabels:
|
||||
ulabels.append(lab)
|
||||
|
||||
marker = itertools.cycle(list(marker))
|
||||
|
||||
for i, ul in enumerate(ulabels):
|
||||
if type(ul) is np.string_:
|
||||
this_label = ul
|
||||
elif type(ul) is np.int64:
|
||||
this_label = 'class %i' % ul
|
||||
else:
|
||||
this_label = 'class %i' % i
|
||||
m = marker.next()
|
||||
|
||||
index = np.nonzero(labels == ul)[0]
|
||||
if model.input_dim == 1:
|
||||
x = model.X[index, input_1]
|
||||
y = np.zeros(index.size)
|
||||
else:
|
||||
x = model.X[index, input_1]
|
||||
y = model.X[index, input_2]
|
||||
ax.scatter(x, y, marker=m, s=s, color=util.plot.Tango.nextMedium(), label=this_label)
|
||||
|
||||
ax.set_xlabel('latent dimension %i' % input_1)
|
||||
ax.set_ylabel('latent dimension %i' % input_2)
|
||||
|
||||
if not np.all(labels == 1.) and legend:
|
||||
ax.legend(loc=0, numpoints=1)
|
||||
|
||||
ax.set_xlim(xmin[0], xmax[0])
|
||||
ax.set_ylim(xmin[1], xmax[1])
|
||||
ax.grid(b=False) # remove the grid if present, it doesn't look good
|
||||
ax.set_aspect('auto') # set a nice aspect ratio
|
||||
|
||||
if plot_inducing:
|
||||
ax.plot(model.Z[:, input_1], model.Z[:, input_2], '^w')
|
||||
|
||||
if updates:
|
||||
ax.figure.canvas.show()
|
||||
raw_input('Enter to continue')
|
||||
|
||||
pb.title('Magnification Factor')
|
||||
return ax
|
||||
|
|
@ -1,538 +0,0 @@
|
|||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
import GPy
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
import time
|
||||
import Image
|
||||
try:
|
||||
import visual
|
||||
visual_available = True
|
||||
|
||||
except ImportError:
|
||||
visual_available = False
|
||||
|
||||
|
||||
class data_show:
|
||||
"""
|
||||
The data_show class is a base class which describes how to visualize a
|
||||
particular data set. For example, motion capture data can be plotted as a
|
||||
stick figure, or images are shown using imshow. This class enables latent
|
||||
to data visualizations for the GP-LVM.
|
||||
"""
|
||||
def __init__(self, vals):
|
||||
self.vals = vals.copy()
|
||||
# If no axes are defined, create some.
|
||||
|
||||
def modify(self, vals):
|
||||
raise NotImplementedError, "this needs to be implemented to use the data_show class"
|
||||
|
||||
def close(self):
|
||||
raise NotImplementedError, "this needs to be implemented to use the data_show class"
|
||||
|
||||
class vpython_show(data_show):
|
||||
"""
|
||||
the vpython_show class is a base class for all visualization methods that use vpython to display. It is initialized with a scene. If the scene is set to None it creates a scene window.
|
||||
"""
|
||||
|
||||
def __init__(self, vals, scene=None):
|
||||
data_show.__init__(self, vals)
|
||||
# If no axes are defined, create some.
|
||||
|
||||
if scene==None:
|
||||
self.scene = visual.display(title='Data Visualization')
|
||||
else:
|
||||
self.scene = scene
|
||||
|
||||
def close(self):
|
||||
self.scene.exit()
|
||||
|
||||
|
||||
|
||||
class matplotlib_show(data_show):
|
||||
"""
|
||||
the matplotlib_show class is a base class for all visualization methods that use matplotlib. It is initialized with an axis. If the axis is set to None it creates a figure window.
|
||||
"""
|
||||
def __init__(self, vals, axes=None):
|
||||
data_show.__init__(self, vals)
|
||||
# If no axes are defined, create some.
|
||||
|
||||
if axes==None:
|
||||
fig = plt.figure()
|
||||
self.axes = fig.add_subplot(111)
|
||||
else:
|
||||
self.axes = axes
|
||||
|
||||
def close(self):
|
||||
plt.close(self.axes.get_figure())
|
||||
|
||||
class vector_show(matplotlib_show):
|
||||
"""
|
||||
A base visualization class that just shows a data vector as a plot of
|
||||
vector elements alongside their indices.
|
||||
"""
|
||||
def __init__(self, vals, axes=None):
|
||||
matplotlib_show.__init__(self, vals, axes)
|
||||
self.handle = self.axes.plot(np.arange(0, len(vals))[:, None], self.vals.T)[0]
|
||||
|
||||
def modify(self, vals):
|
||||
self.vals = vals.copy()
|
||||
xdata, ydata = self.handle.get_data()
|
||||
self.handle.set_data(xdata, self.vals.T)
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
|
||||
class lvm(matplotlib_show):
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0,1]):
|
||||
"""Visualize a latent variable model
|
||||
|
||||
:param model: the latent variable model to visualize.
|
||||
:param data_visualize: the object used to visualize the data which has been modelled.
|
||||
:type data_visualize: visualize.data_show type.
|
||||
:param latent_axes: the axes where the latent visualization should be plotted.
|
||||
"""
|
||||
if vals == None:
|
||||
vals = model.X[0]
|
||||
|
||||
matplotlib_show.__init__(self, vals, axes=latent_axes)
|
||||
|
||||
if isinstance(latent_axes,mpl.axes.Axes):
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
|
||||
self.cid = latent_axes.figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
||||
else:
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('motion_notify_event', self.on_move)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_leave_event', self.on_leave)
|
||||
self.cid = latent_axes[0].figure.canvas.mpl_connect('axes_enter_event', self.on_enter)
|
||||
|
||||
self.data_visualize = data_visualize
|
||||
self.model = model
|
||||
self.latent_axes = latent_axes
|
||||
self.sense_axes = sense_axes
|
||||
self.called = False
|
||||
self.move_on = False
|
||||
self.latent_index = latent_index
|
||||
self.latent_dim = model.input_dim
|
||||
|
||||
# The red cross which shows current latent point.
|
||||
self.latent_values = vals
|
||||
self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
|
||||
self.modify(vals)
|
||||
self.show_sensitivities()
|
||||
|
||||
def modify(self, vals):
|
||||
"""When latent values are modified update the latent representation and ulso update the output visualization."""
|
||||
self.vals = vals.copy()
|
||||
y = self.model.predict(self.vals)[0]
|
||||
self.data_visualize.modify(y)
|
||||
self.latent_handle.set_data(self.vals[self.latent_index[0]], self.vals[self.latent_index[1]])
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
|
||||
def on_enter(self,event):
|
||||
pass
|
||||
def on_leave(self,event):
|
||||
pass
|
||||
|
||||
def on_click(self, event):
|
||||
print 'click!'
|
||||
if event.inaxes!=self.latent_axes: return
|
||||
self.move_on = not self.move_on
|
||||
self.called = True
|
||||
|
||||
def on_move(self, event):
|
||||
if event.inaxes!=self.latent_axes: return
|
||||
if self.called and self.move_on:
|
||||
# Call modify code on move
|
||||
self.latent_values[self.latent_index[0]]=event.xdata
|
||||
self.latent_values[self.latent_index[1]]=event.ydata
|
||||
self.modify(self.latent_values)
|
||||
|
||||
def show_sensitivities(self):
|
||||
# A click in the bar chart axis for selection a dimension.
|
||||
if self.sense_axes != None:
|
||||
self.sense_axes.cla()
|
||||
self.sense_axes.bar(np.arange(self.model.input_dim), self.model.input_sensitivity(), color='b')
|
||||
|
||||
if self.latent_index[1] == self.latent_index[0]:
|
||||
self.sense_axes.bar(np.array(self.latent_index[0]), self.model.input_sensitivity()[self.latent_index[0]], color='y')
|
||||
self.sense_axes.bar(np.array(self.latent_index[1]), self.model.input_sensitivity()[self.latent_index[1]], color='y')
|
||||
|
||||
else:
|
||||
self.sense_axes.bar(np.array(self.latent_index[0]), self.model.input_sensitivity()[self.latent_index[0]], color='g')
|
||||
self.sense_axes.bar(np.array(self.latent_index[1]), self.model.input_sensitivity()[self.latent_index[1]], color='r')
|
||||
|
||||
self.sense_axes.figure.canvas.draw()
|
||||
|
||||
|
||||
class lvm_subplots(lvm):
|
||||
"""
|
||||
latent_axes is a np array of dimension np.ceil(input_dim/2),
|
||||
one for each pair of the latent dimensions.
|
||||
"""
|
||||
def __init__(self, vals, Model, data_visualize, latent_axes=None, sense_axes=None):
|
||||
self.nplots = int(np.ceil(Model.input_dim/2.))+1
|
||||
assert len(latent_axes)==self.nplots
|
||||
if vals==None:
|
||||
vals = Model.X[0, :]
|
||||
self.latent_values = vals
|
||||
|
||||
for i, axis in enumerate(latent_axes):
|
||||
if i == self.nplots-1:
|
||||
if self.nplots*2!=Model.input_dim:
|
||||
latent_index = [i*2, i*2]
|
||||
lvm.__init__(self, self.latent_vals, Model, data_visualize, axis, sense_axes, latent_index=latent_index)
|
||||
else:
|
||||
latent_index = [i*2, i*2+1]
|
||||
lvm.__init__(self, self.latent_vals, Model, data_visualize, axis, latent_index=latent_index)
|
||||
|
||||
|
||||
|
||||
class lvm_dimselect(lvm):
|
||||
"""
|
||||
A visualizer for latent variable models which allows selection of the latent dimensions to use by clicking on a bar chart of their length scales.
|
||||
|
||||
For an example of the visualizer's use try:
|
||||
|
||||
GPy.examples.dimensionality_reduction.BGPVLM_oil()
|
||||
|
||||
"""
|
||||
def __init__(self, vals, model, data_visualize, latent_axes=None, sense_axes=None, latent_index=[0, 1], labels=None):
|
||||
if latent_axes==None and sense_axes==None:
|
||||
self.fig,(latent_axes,self.sense_axes) = plt.subplots(1,2)
|
||||
elif sense_axes==None:
|
||||
fig=plt.figure()
|
||||
self.sense_axes = fig.add_subplot(111)
|
||||
else:
|
||||
self.sense_axes = sense_axes
|
||||
self.labels = labels
|
||||
lvm.__init__(self,vals,model,data_visualize,latent_axes,sense_axes,latent_index)
|
||||
self.show_sensitivities()
|
||||
print "use left and right mouse butons to select dimensions"
|
||||
|
||||
|
||||
def on_click(self, event):
|
||||
|
||||
if event.inaxes==self.sense_axes:
|
||||
new_index = max(0,min(int(np.round(event.xdata-0.5)),self.model.input_dim-1))
|
||||
if event.button == 1:
|
||||
# Make it red if and y-axis (red=port=left) if it is a left button click
|
||||
self.latent_index[1] = new_index
|
||||
else:
|
||||
# Make it green and x-axis (green=starboard=right) if it is a right button click
|
||||
self.latent_index[0] = new_index
|
||||
|
||||
self.show_sensitivities()
|
||||
|
||||
self.latent_axes.cla()
|
||||
self.model.plot_latent(which_indices=self.latent_index,
|
||||
ax=self.latent_axes, labels=self.labels)
|
||||
self.latent_handle = self.latent_axes.plot([0],[0],'rx',mew=2)[0]
|
||||
self.modify(self.latent_values)
|
||||
|
||||
elif event.inaxes==self.latent_axes:
|
||||
self.move_on = not self.move_on
|
||||
|
||||
self.called = True
|
||||
|
||||
|
||||
|
||||
def on_leave(self,event):
|
||||
latent_values = self.latent_values.copy()
|
||||
y = self.model.predict(latent_values[None,:])[0]
|
||||
self.data_visualize.modify(y)
|
||||
|
||||
|
||||
|
||||
class image_show(matplotlib_show):
|
||||
"""Show a data vector as an image. This visualizer rehapes the output vector and displays it as an image.
|
||||
|
||||
:param vals: the values of the output to display.
|
||||
:type vals: ndarray
|
||||
:param axes: the axes to show the output on.
|
||||
:type vals: axes handle
|
||||
:param dimensions: the dimensions that the image needs to be transposed to for display.
|
||||
:type dimensions: tuple
|
||||
:param transpose: whether to transpose the image before display.
|
||||
:type bool: default is False.
|
||||
:param order: whether array is in Fortan ordering ('F') or Python ordering ('C'). Default is python ('C').
|
||||
:type order: string
|
||||
:param invert: whether to invert the pixels or not (default False).
|
||||
:type invert: bool
|
||||
:param palette: a palette to use for the image.
|
||||
:param preset_mean: the preset mean of a scaled image.
|
||||
:type preset_mean: double
|
||||
:param preset_std: the preset standard deviation of a scaled image.
|
||||
:type preset_std: double"""
|
||||
def __init__(self, vals, axes=None, dimensions=(16,16), transpose=False, order='C', invert=False, scale=False, palette=[], preset_mean = 0., preset_std = -1., select_image=0):
|
||||
matplotlib_show.__init__(self, vals, axes)
|
||||
self.dimensions = dimensions
|
||||
self.transpose = transpose
|
||||
self.order = order
|
||||
self.invert = invert
|
||||
self.scale = scale
|
||||
self.palette = palette
|
||||
self.preset_mean = preset_mean
|
||||
self.preset_std = preset_std
|
||||
self.select_image = select_image # This is used when the y vector contains multiple images concatenated.
|
||||
|
||||
self.set_image(self.vals)
|
||||
if not self.palette == []: # Can just show the image (self.set_image() took care of setting the palette)
|
||||
self.handle = self.axes.imshow(self.vals, interpolation='nearest')
|
||||
else: # Use a boring gray map.
|
||||
self.handle = self.axes.imshow(self.vals, cmap=plt.cm.gray, interpolation='nearest') # @UndefinedVariable
|
||||
plt.show()
|
||||
|
||||
def modify(self, vals):
|
||||
self.set_image(vals.copy())
|
||||
self.handle.set_array(self.vals)
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
def set_image(self, vals):
|
||||
dim = self.dimensions[0] * self.dimensions[1]
|
||||
num_images = np.sqrt(vals[0,].size/dim)
|
||||
if num_images > 1 and num_images.is_integer(): # Show a mosaic of images
|
||||
num_images = np.int(num_images)
|
||||
self.vals = np.zeros((self.dimensions[0]*num_images, self.dimensions[1]*num_images))
|
||||
for iR in range(num_images):
|
||||
for iC in range(num_images):
|
||||
cur_img_id = iR*num_images + iC
|
||||
cur_img = np.reshape(vals[0,dim*cur_img_id+np.array(range(dim))], self.dimensions, order=self.order)
|
||||
first_row = iR*self.dimensions[0]
|
||||
last_row = (iR+1)*self.dimensions[0]
|
||||
first_col = iC*self.dimensions[1]
|
||||
last_col = (iC+1)*self.dimensions[1]
|
||||
self.vals[first_row:last_row, first_col:last_col] = cur_img
|
||||
|
||||
else:
|
||||
self.vals = np.reshape(vals[0,dim*self.select_image+np.array(range(dim))], self.dimensions, order=self.order)
|
||||
if self.transpose:
|
||||
self.vals = self.vals.T
|
||||
# if not self.scale:
|
||||
# self.vals = self.vals
|
||||
if self.invert:
|
||||
self.vals = -self.vals
|
||||
|
||||
# un-normalizing, for visualisation purposes:
|
||||
if self.preset_std >= 0: # The Mean is assumed to be in the range (0,255)
|
||||
self.vals = self.vals*self.preset_std + self.preset_mean
|
||||
# Clipping the values:
|
||||
self.vals[self.vals < 0] = 0
|
||||
self.vals[self.vals > 255] = 255
|
||||
else:
|
||||
self.vals = 255*(self.vals - self.vals.min())/(self.vals.max() - self.vals.min())
|
||||
if not self.palette == []: # applying using an image palette (e.g. if the image has been quantized)
|
||||
self.vals = Image.fromarray(self.vals.astype('uint8'))
|
||||
self.vals.putpalette(self.palette) # palette is a list, must be loaded before calling this function
|
||||
|
||||
class mocap_data_show_vpython(vpython_show):
|
||||
"""Base class for visualizing motion capture data using visual module."""
|
||||
|
||||
def __init__(self, vals, scene=None, connect=None, radius=0.1):
|
||||
vpython_show.__init__(self, vals, scene)
|
||||
self.radius = radius
|
||||
self.connect = connect
|
||||
self.process_values()
|
||||
self.draw_edges()
|
||||
self.draw_vertices()
|
||||
|
||||
def draw_vertices(self):
|
||||
self.spheres = []
|
||||
for i in range(self.vals.shape[0]):
|
||||
self.spheres.append(visual.sphere(pos=(self.vals[i, 0], self.vals[i, 2], self.vals[i, 1]), radius=self.radius))
|
||||
self.scene.visible=True
|
||||
|
||||
def draw_edges(self):
|
||||
self.rods = []
|
||||
self.line_handle = []
|
||||
if not self.connect==None:
|
||||
self.I, self.J = np.nonzero(self.connect)
|
||||
for i, j in zip(self.I, self.J):
|
||||
pos, axis = self.pos_axis(i, j)
|
||||
self.rods.append(visual.cylinder(pos=pos, axis=axis, radius=self.radius))
|
||||
|
||||
def modify_vertices(self):
|
||||
for i in range(self.vals.shape[0]):
|
||||
self.spheres[i].pos = (self.vals[i, 0], self.vals[i, 2], self.vals[i, 1])
|
||||
|
||||
def modify_edges(self):
|
||||
self.line_handle = []
|
||||
if not self.connect==None:
|
||||
self.I, self.J = np.nonzero(self.connect)
|
||||
for rod, i, j in zip(self.rods, self.I, self.J):
|
||||
rod.pos, rod.axis = self.pos_axis(i, j)
|
||||
|
||||
def pos_axis(self, i, j):
|
||||
pos = []
|
||||
axis = []
|
||||
pos.append(self.vals[i, 0])
|
||||
axis.append(self.vals[j, 0]-self.vals[i,0])
|
||||
pos.append(self.vals[i, 2])
|
||||
axis.append(self.vals[j, 2]-self.vals[i,2])
|
||||
pos.append(self.vals[i, 1])
|
||||
axis.append(self.vals[j, 1]-self.vals[i,1])
|
||||
return pos, axis
|
||||
|
||||
def modify(self, vals):
|
||||
self.vals = vals.copy()
|
||||
self.process_values()
|
||||
self.modify_edges()
|
||||
self.modify_vertices()
|
||||
|
||||
def process_values(self):
|
||||
raise NotImplementedError, "this needs to be implemented to use the data_show class"
|
||||
|
||||
|
||||
class mocap_data_show(matplotlib_show):
|
||||
"""Base class for visualizing motion capture data."""
|
||||
|
||||
def __init__(self, vals, axes=None, connect=None):
|
||||
if axes==None:
|
||||
fig = plt.figure()
|
||||
axes = fig.add_subplot(111, projection='3d')
|
||||
matplotlib_show.__init__(self, vals, axes)
|
||||
|
||||
self.connect = connect
|
||||
self.process_values()
|
||||
self.initialize_axes()
|
||||
self.draw_vertices()
|
||||
self.finalize_axes()
|
||||
self.draw_edges()
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
def draw_vertices(self):
|
||||
self.points_handle = self.axes.scatter(self.vals[:, 0], self.vals[:, 1], self.vals[:, 2])
|
||||
|
||||
def draw_edges(self):
|
||||
self.line_handle = []
|
||||
if not self.connect==None:
|
||||
x = []
|
||||
y = []
|
||||
z = []
|
||||
self.I, self.J = np.nonzero(self.connect)
|
||||
for i, j in zip(self.I, self.J):
|
||||
x.append(self.vals[i, 0])
|
||||
x.append(self.vals[j, 0])
|
||||
x.append(np.NaN)
|
||||
y.append(self.vals[i, 1])
|
||||
y.append(self.vals[j, 1])
|
||||
y.append(np.NaN)
|
||||
z.append(self.vals[i, 2])
|
||||
z.append(self.vals[j, 2])
|
||||
z.append(np.NaN)
|
||||
self.line_handle = self.axes.plot(np.array(x), np.array(y), np.array(z), 'b-')
|
||||
|
||||
def modify(self, vals):
|
||||
self.vals = vals.copy()
|
||||
self.process_values()
|
||||
self.initialize_axes_modify()
|
||||
self.draw_vertices()
|
||||
self.finalize_axes_modify()
|
||||
self.draw_edges()
|
||||
self.axes.figure.canvas.draw()
|
||||
|
||||
def process_values(self):
|
||||
raise NotImplementedError, "this needs to be implemented to use the data_show class"
|
||||
|
||||
def initialize_axes(self):
|
||||
"""Set up the axes with the right limits and scaling."""
|
||||
self.x_lim = np.array([self.vals[:, 0].min(), self.vals[:, 0].max()])
|
||||
self.y_lim = np.array([self.vals[:, 1].min(), self.vals[:, 1].max()])
|
||||
self.z_lim = np.array([self.vals[:, 2].min(), self.vals[:, 2].max()])
|
||||
|
||||
def initialize_axes_modify(self):
|
||||
self.points_handle.remove()
|
||||
self.line_handle[0].remove()
|
||||
|
||||
def finalize_axes(self):
|
||||
self.axes.set_xlim(self.x_lim)
|
||||
self.axes.set_ylim(self.y_lim)
|
||||
self.axes.set_zlim(self.z_lim)
|
||||
self.axes.auto_scale_xyz([-1., 1.], [-1., 1.], [-1.5, 1.5])
|
||||
|
||||
#self.axes.set_aspect('equal')
|
||||
self.axes.autoscale(enable=False)
|
||||
|
||||
def finalize_axes_modify(self):
|
||||
self.axes.set_xlim(self.x_lim)
|
||||
self.axes.set_ylim(self.y_lim)
|
||||
self.axes.set_zlim(self.z_lim)
|
||||
|
||||
class stick_show(mocap_data_show):
|
||||
"""Show a three dimensional point cloud as a figure. Connect elements of the figure together using the matrix connect."""
|
||||
def __init__(self, vals, connect=None, axes=None):
|
||||
mocap_data_show.__init__(self, vals, axes=axes, connect=connect)
|
||||
|
||||
def process_values(self):
|
||||
self.vals = self.vals.reshape((3, self.vals.shape[1]/3)).T
|
||||
|
||||
class skeleton_show(mocap_data_show):
|
||||
"""data_show class for visualizing motion capture data encoded as a skeleton with angles."""
|
||||
def __init__(self, vals, skel, axes=None, padding=0):
|
||||
"""data_show class for visualizing motion capture data encoded as a skeleton with angles.
|
||||
:param vals: set of modeled angles to use for printing in the axis when it's first created.
|
||||
:type vals: np.array
|
||||
:param skel: skeleton object that has the parameters of the motion capture skeleton associated with it.
|
||||
:type skel: mocap.skeleton object
|
||||
:param padding:
|
||||
:type int
|
||||
"""
|
||||
self.skel = skel
|
||||
self.padding = padding
|
||||
connect = skel.connection_matrix()
|
||||
mocap_data_show.__init__(self, vals, axes=axes, connect=connect)
|
||||
def process_values(self):
|
||||
"""Takes a set of angles and converts them to the x,y,z coordinates in the internal prepresentation of the class, ready for plotting.
|
||||
|
||||
:param vals: the values that are being modelled."""
|
||||
|
||||
if self.padding>0:
|
||||
channels = np.zeros((self.vals.shape[0], self.vals.shape[1]+self.padding))
|
||||
channels[:, 0:self.vals.shape[0]] = self.vals
|
||||
else:
|
||||
channels = self.vals
|
||||
vals_mat = self.skel.to_xyz(channels.flatten())
|
||||
self.vals = np.zeros_like(vals_mat)
|
||||
# Flip the Y and Z axes
|
||||
self.vals[:, 0] = vals_mat[:, 0].copy()
|
||||
self.vals[:, 1] = vals_mat[:, 2].copy()
|
||||
self.vals[:, 2] = vals_mat[:, 1].copy()
|
||||
|
||||
def wrap_around(self, lim, connect):
|
||||
quot = lim[1] - lim[0]
|
||||
self.vals = rem(self.vals, quot)+lim[0]
|
||||
nVals = floor(self.vals/quot)
|
||||
for i in range(connect.shape[0]):
|
||||
for j in find(connect[i, :]):
|
||||
if nVals[i] != nVals[j]:
|
||||
connect[i, j] = False
|
||||
return connect
|
||||
|
||||
|
||||
def data_play(Y, visualizer, frame_rate=30):
|
||||
"""Play a data set using the data_show object given.
|
||||
|
||||
:Y: the data set to be visualized.
|
||||
:param visualizer: the data show objectwhether to display during optimisation
|
||||
:type visualizer: data_show
|
||||
|
||||
Example usage:
|
||||
|
||||
This example loads in the CMU mocap database (http://mocap.cs.cmu.edu) subject number 35 motion number 01. It then plays it using the mocap_show visualize object.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
data = GPy.util.datasets.cmu_mocap(subject='35', train_motions=['01'])
|
||||
Y = data['Y']
|
||||
Y[:, 0:3] = 0. # Make figure walk in place
|
||||
visualize = GPy.util.visualize.skeleton_show(Y[0, :], data['skel'])
|
||||
GPy.util.visualize.data_play(Y, visualize)
|
||||
|
||||
"""
|
||||
|
||||
|
||||
for y in Y:
|
||||
visualizer.modify(y[None, :])
|
||||
time.sleep(1./float(frame_rate))
|
||||
Loading…
Add table
Add a link
Reference in a new issue