mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 02:24:17 +02:00
[SSGPLVM] add plotting class
This commit is contained in:
parent
01860455af
commit
5cfc250ad1
9 changed files with 96 additions and 10 deletions
|
|
@ -15,3 +15,5 @@ import latent_space_visualizations
|
|||
import netpbmfile
|
||||
import inference_plots
|
||||
import maps
|
||||
import img_plots
|
||||
from ssgplvm import SSGPLVM_plot
|
||||
|
|
|
|||
56
GPy/plotting/matplot_dep/img_plots.py
Normal file
56
GPy/plotting/matplot_dep/img_plots.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""
|
||||
The module contains the tools for ploting 2D image visualizations
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from matplotlib.cm import jet
|
||||
|
||||
width_max = 15
|
||||
height_max = 12
|
||||
|
||||
def _calculateFigureSize(x_size, y_size, fig_ncols, fig_nrows, pad):
|
||||
width = (x_size*fig_ncols+pad*(fig_ncols-1))
|
||||
height = (y_size*fig_nrows+pad*(fig_nrows-1))
|
||||
if width > float(height)/height_max*width_max:
|
||||
return (width_max, float(width_max)/width*height)
|
||||
else:
|
||||
return (float(height_max)/height*width, height_max)
|
||||
|
||||
def plot_2D_images(figure, arr, symmetric=False, pad=None, zoom=None, mode=None, interpolation='nearest'):
|
||||
ax = figure.add_subplot(111)
|
||||
if len(arr.shape)==2:
|
||||
arr = arr.reshape(*((1,)+arr.shape))
|
||||
fig_num = arr.shape[0]
|
||||
y_size = arr.shape[1]
|
||||
x_size = arr.shape[2]
|
||||
fig_ncols = int(np.ceil(np.sqrt(fig_num)))
|
||||
fig_nrows = int(np.ceil((float)(fig_num)/fig_ncols))
|
||||
if pad==None:
|
||||
pad = max(int(min(y_size,x_size)/10),1)
|
||||
|
||||
figsize = _calculateFigureSize(x_size, y_size, fig_ncols, fig_nrows, pad)
|
||||
figure.set_size_inches(figsize,forward=True)
|
||||
#figure.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95)
|
||||
|
||||
if symmetric:
|
||||
# symmetric around zero: fix zero as the middle color
|
||||
mval = max(abs(arr.max()),abs(arr.min()))
|
||||
arr = arr/(2.*mval)+0.5
|
||||
else:
|
||||
minval,maxval = arr.max(),arr.min()
|
||||
arr = (arr-minval)/(maxval-minval)
|
||||
|
||||
if mode=='L':
|
||||
arr_color = np.empty(arr.shape+(3,))
|
||||
arr_color[:] = arr.reshape(*(arr.shape+(1,)))
|
||||
elif mode==None or mode=='jet':
|
||||
arr_color = jet(arr)
|
||||
|
||||
buf = np.ones((y_size*fig_nrows+pad*(fig_nrows-1), x_size*fig_ncols+pad*(fig_ncols-1), 3),dtype=arr.dtype)
|
||||
|
||||
for y in xrange(fig_nrows):
|
||||
for x in xrange(fig_ncols):
|
||||
if y*fig_ncols+x<fig_num:
|
||||
buf[y*y_size+y*pad:(y+1)*y_size+y*pad, x*x_size+x*pad:(x+1)*x_size+x*pad] = arr_color[y*fig_ncols+x,:,:,:3]
|
||||
img_plot = ax.imshow(buf, interpolation=interpolation)
|
||||
ax.axis('off')
|
||||
29
GPy/plotting/matplot_dep/ssgplvm.py
Normal file
29
GPy/plotting/matplot_dep/ssgplvm.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""
|
||||
The module plotting results for SSGPLVM
|
||||
"""
|
||||
|
||||
import pylab
|
||||
|
||||
from ...models import SSGPLVM
|
||||
from img_plots import plot_2D_images
|
||||
from ...util.misc import param_to_array
|
||||
|
||||
class SSGPLVM_plot(object):
|
||||
def __init__(self,model, imgsize):
|
||||
assert isinstance(model,SSGPLVM)
|
||||
self.model = model
|
||||
self.imgsize= imgsize
|
||||
assert model.Y.shape[1] == imgsize[0]*imgsize[1]
|
||||
|
||||
def plot_inducing(self):
|
||||
fig1 = pylab.figure()
|
||||
mean = self.model.posterior.mean
|
||||
arr = mean.reshape(*(mean.shape[0],self.imgsize[1],self.imgsize[0]))
|
||||
plot_2D_images(fig1, arr)
|
||||
fig1.gca().set_title('The mean of inducing points')
|
||||
|
||||
fig2 = pylab.figure()
|
||||
covar = self.model.posterior.covariance
|
||||
plot_2D_images(fig2, covar)
|
||||
fig2.gca().set_title('The variance of inducing points')
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue