diff --git a/GPy/kern/kern.py b/GPy/kern/kern.py index 89def0e5..a00a20e5 100644 --- a/GPy/kern/kern.py +++ b/GPy/kern/kern.py @@ -3,10 +3,11 @@ import numpy as np +import pylab as pb from ..core.parameterised import parameterised from kernpart import kernpart import itertools -from product_orthogonal import product_orthogonal +from product_orthogonal import product_orthogonal class kern(parameterised): def __init__(self,D,parts=[], input_slices=None): @@ -372,3 +373,50 @@ class kern(parameterised): #TODO: there are some extra terms to compute here! return target_mu, target_S + + def plot(self, x = None, plot_limits=None,which_functions='all',resolution=None): + if which_functions=='all': + which_functions = [True]*self.Nparts + if self.D == 1: + if x is None: + x = np.zeros((1,1)) + else: + assert x.shape == (1,1), "The shape fo the fixed variable x is not (1,D)" + + if plot_limits == None: + xmin, xmax = (x-5).flatten(), (x+5).flatten() + elif len(plot_limits) == 2: + xmin, xmax = plot_limits + else: + raise ValueError, "Bad limits for plotting" + + Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None] + Kx = self.K(Xnew,x) + pb.plot(Xnew,Kx) + pb.xlim(xmin,xmax) + pb.ylim(Kx.min() - (Kx.max()-Kx.min())*0.15,Kx.max() + (Kx.max()-Kx.min())*0.15) + + elif self.D == 2: + if x is None: + x = np.zeros((1,2)) + else: + assert x.shape == (1,2), "The shape fo the fixed variable x is not (1,D)" + + if plot_limits == None: + xmin, xmax = (x-5).flatten(), (x+5).flatten() + elif len(plot_limits) == 2: + xmin, xmax = plot_limits + else: + raise ValueError, "Bad limits for plotting" + + resolution = resolution or 50 + xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution] + Xnew = np.vstack((xx.flatten(),yy.flatten())).T + Kx = self.K(Xnew,x) + Kx = Kx.reshape(resolution,resolution) + pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet) + pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max()) + pb.xlim(xmin[0],xmax[0]) + pb.ylim(xmin[1],xmax[1]) + else: + raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions" diff --git a/GPy/models/GP.py b/GPy/models/GP.py index 2afa4252..b482ab89 100644 --- a/GPy/models/GP.py +++ b/GPy/models/GP.py @@ -206,13 +206,13 @@ class GP(model): gpplot(Xnew,m,m-np.sqrt(v),m+np.sqrt(v)) pb.plot(self.X[which_data],self.likelihood.Y[which_data],'kx',mew=1.5) pb.xlim(xmin,xmax) - elif X.shape[1]==2: + elif self.X.shape[1]==2: resolution = resolution or 50 - Xnew, xmin, xmax,xx,yy = x_frame2D(self.X, plot_limits=plot_limits) + Xnew, xmin, xmax,xx,yy = x_frame2D(self.X, plot_limits,resolution) m,v = self._raw_predict(Xnew, slices=which_functions) - m = m.reshape(resolution,resolution) - pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet) - pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max()) + m = m.reshape(resolution,resolution).T + pb.contour(xx,yy,m,vmin=m.min(),vmax=m.max(),cmap=pb.cm.jet) + pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=m.min(), vmax=m.max()) pb.xlim(xmin[0],xmax[0]) pb.ylim(xmin[1],xmax[1]) else: @@ -232,17 +232,16 @@ class GP(model): ymin,ymax = self.likelihood.data.min()*1.2,self.likelihood.data.max()*1.2 pb.xlim(xmin,xmax) pb.ylim(ymin,ymax) - elif X.shape[1]==2: + elif self.X.shape[1]==2: resolution = resolution or 50 - Xnew, xmin, xmax,xx,yy = x_frame2D(self.X, plot_limits=plot_limits) - m,v = self.predict(Xnew, slices=which_functions) - m = m.reshape(resolution,resolution) - pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet) - pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max()) + Xnew, xx, yy, xmin, xmax = x_frame2D(self.X, plot_limits,resolution) + x, y = np.linspace(xmin[0],xmax[0],resolution), np.linspace(xmin[1],xmax[1],resolution) + m,lower,upper = self.predict(Xnew, slices=which_functions) + m = m.reshape(resolution,resolution).T + pb.contour(x,y,m,vmin=m.min(),vmax=m.max(),cmap=pb.cm.jet) + Yf = self.likelihood.Y.flatten() + pb.scatter(self.X[:,0], self.X[:,1], 40, Yf, cmap=pb.cm.jet,vmin=m.min(),vmax=m.max(), linewidth=0.) pb.xlim(xmin[0],xmax[0]) pb.ylim(xmin[1],xmax[1]) else: raise NotImplementedError, "Cannot define a frame with more than two input dimensions" - - - diff --git a/GPy/util/plot.py b/GPy/util/plot.py index 7b346330..8e71764d 100644 --- a/GPy/util/plot.py +++ b/GPy/util/plot.py @@ -102,5 +102,4 @@ def x_frame2D(X,plot_limits=None,resolution=None): resolution = resolution or 50 xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution] Xnew = np.vstack((xx.flatten(),yy.flatten())).T - return Xnew, xx,yy,xmin, xmax - + return Xnew, xx, yy, xmin, xmax