Many modifications in GP plots to make it work

This commit is contained in:
Nicolas 2013-02-05 12:27:12 +00:00
parent 4dce1a428f
commit 5b19d8609a
3 changed files with 63 additions and 17 deletions

View file

@ -3,10 +3,11 @@
import numpy as np
import pylab as pb
from ..core.parameterised import parameterised
from kernpart import kernpart
import itertools
from product_orthogonal import product_orthogonal
from product_orthogonal import product_orthogonal
class kern(parameterised):
def __init__(self,D,parts=[], input_slices=None):
@ -372,3 +373,50 @@ class kern(parameterised):
#TODO: there are some extra terms to compute here!
return target_mu, target_S
def plot(self, x = None, plot_limits=None,which_functions='all',resolution=None):
if which_functions=='all':
which_functions = [True]*self.Nparts
if self.D == 1:
if x is None:
x = np.zeros((1,1))
else:
assert x.shape == (1,1), "The shape fo the fixed variable x is not (1,D)"
if plot_limits == None:
xmin, xmax = (x-5).flatten(), (x+5).flatten()
elif len(plot_limits) == 2:
xmin, xmax = plot_limits
else:
raise ValueError, "Bad limits for plotting"
Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None]
Kx = self.K(Xnew,x)
pb.plot(Xnew,Kx)
pb.xlim(xmin,xmax)
pb.ylim(Kx.min() - (Kx.max()-Kx.min())*0.15,Kx.max() + (Kx.max()-Kx.min())*0.15)
elif self.D == 2:
if x is None:
x = np.zeros((1,2))
else:
assert x.shape == (1,2), "The shape fo the fixed variable x is not (1,D)"
if plot_limits == None:
xmin, xmax = (x-5).flatten(), (x+5).flatten()
elif len(plot_limits) == 2:
xmin, xmax = plot_limits
else:
raise ValueError, "Bad limits for plotting"
resolution = resolution or 50
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
Kx = self.K(Xnew,x)
Kx = Kx.reshape(resolution,resolution)
pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet)
pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max())
pb.xlim(xmin[0],xmax[0])
pb.ylim(xmin[1],xmax[1])
else:
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"

View file

@ -206,13 +206,13 @@ class GP(model):
gpplot(Xnew,m,m-np.sqrt(v),m+np.sqrt(v))
pb.plot(self.X[which_data],self.likelihood.Y[which_data],'kx',mew=1.5)
pb.xlim(xmin,xmax)
elif X.shape[1]==2:
elif self.X.shape[1]==2:
resolution = resolution or 50
Xnew, xmin, xmax,xx,yy = x_frame2D(self.X, plot_limits=plot_limits)
Xnew, xmin, xmax,xx,yy = x_frame2D(self.X, plot_limits,resolution)
m,v = self._raw_predict(Xnew, slices=which_functions)
m = m.reshape(resolution,resolution)
pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet)
pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max())
m = m.reshape(resolution,resolution).T
pb.contour(xx,yy,m,vmin=m.min(),vmax=m.max(),cmap=pb.cm.jet)
pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=m.min(), vmax=m.max())
pb.xlim(xmin[0],xmax[0])
pb.ylim(xmin[1],xmax[1])
else:
@ -232,17 +232,16 @@ class GP(model):
ymin,ymax = self.likelihood.data.min()*1.2,self.likelihood.data.max()*1.2
pb.xlim(xmin,xmax)
pb.ylim(ymin,ymax)
elif X.shape[1]==2:
elif self.X.shape[1]==2:
resolution = resolution or 50
Xnew, xmin, xmax,xx,yy = x_frame2D(self.X, plot_limits=plot_limits)
m,v = self.predict(Xnew, slices=which_functions)
m = m.reshape(resolution,resolution)
pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet)
pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max())
Xnew, xx, yy, xmin, xmax = x_frame2D(self.X, plot_limits,resolution)
x, y = np.linspace(xmin[0],xmax[0],resolution), np.linspace(xmin[1],xmax[1],resolution)
m,lower,upper = self.predict(Xnew, slices=which_functions)
m = m.reshape(resolution,resolution).T
pb.contour(x,y,m,vmin=m.min(),vmax=m.max(),cmap=pb.cm.jet)
Yf = self.likelihood.Y.flatten()
pb.scatter(self.X[:,0], self.X[:,1], 40, Yf, cmap=pb.cm.jet,vmin=m.min(),vmax=m.max(), linewidth=0.)
pb.xlim(xmin[0],xmax[0])
pb.ylim(xmin[1],xmax[1])
else:
raise NotImplementedError, "Cannot define a frame with more than two input dimensions"

View file

@ -102,5 +102,4 @@ def x_frame2D(X,plot_limits=None,resolution=None):
resolution = resolution or 50
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
return Xnew, xx,yy,xmin, xmax
return Xnew, xx, yy, xmin, xmax