From 8fd79f6eee2212aaac02f7afd351df7ab5066625 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Wed, 6 Feb 2013 15:20:04 +0000 Subject: [PATCH] Added new plotting function for kernels --- GPy/kern/kern.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/GPy/kern/kern.py b/GPy/kern/kern.py index 5e9273ae..be382d11 100644 --- a/GPy/kern/kern.py +++ b/GPy/kern/kern.py @@ -383,7 +383,9 @@ class kern(parameterised): if x is None: x = np.zeros((1,1)) else: - assert x.shape == (1,1), "The shape fo the fixed variable x is not (1,D)" + x = np.asarray(x) + assert x.size == 1, "The size of the fixed variable x is not 1" + x = x.reshape((1,1)) if plot_limits == None: xmin, xmax = (x-5).flatten(), (x+5).flatten() @@ -392,17 +394,20 @@ class kern(parameterised): else: raise ValueError, "Bad limits for plotting" - Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None] - Kx = self.K(Xnew,x) + Xnew = np.linspace(xmin,xmax,resolution or 201)[:,None] + Kx = self.K(Xnew,x,slices2=which_functions) pb.plot(Xnew,Kx) pb.xlim(xmin,xmax) - pb.ylim(Kx.min() - (Kx.max()-Kx.min())*0.15,Kx.max() + (Kx.max()-Kx.min())*0.15) + pb.xlabel("x") + pb.ylabel("k(x,%0.1f)" %x) elif self.D == 2: if x is None: x = np.zeros((1,2)) else: - assert x.shape == (1,2), "The shape fo the fixed variable x is not (1,D)" + x = np.asarray(x) + assert x.size == 2, "The size of the fixed variable x is not 2" + x = x.reshape((1,2)) if plot_limits == None: xmin, xmax = (x-5).flatten(), (x+5).flatten() @@ -411,14 +416,18 @@ class kern(parameterised): else: raise ValueError, "Bad limits for plotting" - resolution = resolution or 50 + resolution = resolution or 51 xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution] + xg = np.linspace(xmin[0],xmax[0],resolution) + yg = np.linspace(xmin[1],xmax[1],resolution) Xnew = np.vstack((xx.flatten(),yy.flatten())).T - Kx = self.K(Xnew,x) - Kx = Kx.reshape(resolution,resolution) - pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet) - pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max()) + Kx = self.K(Xnew,x,slices2=which_functions) + Kx = Kx.reshape(resolution,resolution).T + pb.contour(xg,yg,Kx,vmin=Kx.min(),vmax=Kx.max(),cmap=pb.cm.jet) pb.xlim(xmin[0],xmax[0]) pb.ylim(xmin[1],xmax[1]) + pb.xlabel("x1") + pb.ylabel("x2") + pb.title("k(x1,x2 ; %0.1f,%0.1f)" %(x[0,0],x[0,1]) ) else: raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"