mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 18:42:39 +02:00
Added new plotting function for kernels
This commit is contained in:
parent
b0f6495ed4
commit
8fd79f6eee
1 changed files with 19 additions and 10 deletions
|
|
@ -383,7 +383,9 @@ class kern(parameterised):
|
|||
if x is None:
|
||||
x = np.zeros((1,1))
|
||||
else:
|
||||
assert x.shape == (1,1), "The shape fo the fixed variable x is not (1,D)"
|
||||
x = np.asarray(x)
|
||||
assert x.size == 1, "The size of the fixed variable x is not 1"
|
||||
x = x.reshape((1,1))
|
||||
|
||||
if plot_limits == None:
|
||||
xmin, xmax = (x-5).flatten(), (x+5).flatten()
|
||||
|
|
@ -392,17 +394,20 @@ class kern(parameterised):
|
|||
else:
|
||||
raise ValueError, "Bad limits for plotting"
|
||||
|
||||
Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None]
|
||||
Kx = self.K(Xnew,x)
|
||||
Xnew = np.linspace(xmin,xmax,resolution or 201)[:,None]
|
||||
Kx = self.K(Xnew,x,slices2=which_functions)
|
||||
pb.plot(Xnew,Kx)
|
||||
pb.xlim(xmin,xmax)
|
||||
pb.ylim(Kx.min() - (Kx.max()-Kx.min())*0.15,Kx.max() + (Kx.max()-Kx.min())*0.15)
|
||||
pb.xlabel("x")
|
||||
pb.ylabel("k(x,%0.1f)" %x)
|
||||
|
||||
elif self.D == 2:
|
||||
if x is None:
|
||||
x = np.zeros((1,2))
|
||||
else:
|
||||
assert x.shape == (1,2), "The shape fo the fixed variable x is not (1,D)"
|
||||
x = np.asarray(x)
|
||||
assert x.size == 2, "The size of the fixed variable x is not 2"
|
||||
x = x.reshape((1,2))
|
||||
|
||||
if plot_limits == None:
|
||||
xmin, xmax = (x-5).flatten(), (x+5).flatten()
|
||||
|
|
@ -411,14 +416,18 @@ class kern(parameterised):
|
|||
else:
|
||||
raise ValueError, "Bad limits for plotting"
|
||||
|
||||
resolution = resolution or 50
|
||||
resolution = resolution or 51
|
||||
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
|
||||
xg = np.linspace(xmin[0],xmax[0],resolution)
|
||||
yg = np.linspace(xmin[1],xmax[1],resolution)
|
||||
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
|
||||
Kx = self.K(Xnew,x)
|
||||
Kx = Kx.reshape(resolution,resolution)
|
||||
pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet)
|
||||
pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max())
|
||||
Kx = self.K(Xnew,x,slices2=which_functions)
|
||||
Kx = Kx.reshape(resolution,resolution).T
|
||||
pb.contour(xg,yg,Kx,vmin=Kx.min(),vmax=Kx.max(),cmap=pb.cm.jet)
|
||||
pb.xlim(xmin[0],xmax[0])
|
||||
pb.ylim(xmin[1],xmax[1])
|
||||
pb.xlabel("x1")
|
||||
pb.ylabel("x2")
|
||||
pb.title("k(x1,x2 ; %0.1f,%0.1f)" %(x[0,0],x[0,1]) )
|
||||
else:
|
||||
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue