Added new plotting function for kernels

This commit is contained in:
Nicolas 2013-02-06 15:20:04 +00:00
parent b0f6495ed4
commit 8fd79f6eee

View file

@ -383,7 +383,9 @@ class kern(parameterised):
if x is None:
x = np.zeros((1,1))
else:
assert x.shape == (1,1), "The shape fo the fixed variable x is not (1,D)"
x = np.asarray(x)
assert x.size == 1, "The size of the fixed variable x is not 1"
x = x.reshape((1,1))
if plot_limits == None:
xmin, xmax = (x-5).flatten(), (x+5).flatten()
@ -392,17 +394,20 @@ class kern(parameterised):
else:
raise ValueError, "Bad limits for plotting"
Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None]
Kx = self.K(Xnew,x)
Xnew = np.linspace(xmin,xmax,resolution or 201)[:,None]
Kx = self.K(Xnew,x,slices2=which_functions)
pb.plot(Xnew,Kx)
pb.xlim(xmin,xmax)
pb.ylim(Kx.min() - (Kx.max()-Kx.min())*0.15,Kx.max() + (Kx.max()-Kx.min())*0.15)
pb.xlabel("x")
pb.ylabel("k(x,%0.1f)" %x)
elif self.D == 2:
if x is None:
x = np.zeros((1,2))
else:
assert x.shape == (1,2), "The shape fo the fixed variable x is not (1,D)"
x = np.asarray(x)
assert x.size == 2, "The size of the fixed variable x is not 2"
x = x.reshape((1,2))
if plot_limits == None:
xmin, xmax = (x-5).flatten(), (x+5).flatten()
@ -411,14 +416,18 @@ class kern(parameterised):
else:
raise ValueError, "Bad limits for plotting"
resolution = resolution or 50
resolution = resolution or 51
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
xg = np.linspace(xmin[0],xmax[0],resolution)
yg = np.linspace(xmin[1],xmax[1],resolution)
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
Kx = self.K(Xnew,x)
Kx = Kx.reshape(resolution,resolution)
pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet)
pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max())
Kx = self.K(Xnew,x,slices2=which_functions)
Kx = Kx.reshape(resolution,resolution).T
pb.contour(xg,yg,Kx,vmin=Kx.min(),vmax=Kx.max(),cmap=pb.cm.jet)
pb.xlim(xmin[0],xmax[0])
pb.ylim(xmin[1],xmax[1])
pb.xlabel("x1")
pb.ylabel("x2")
pb.title("k(x1,x2 ; %0.1f,%0.1f)" %(x[0,0],x[0,1]) )
else:
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"