diff --git a/GPy/kern/kern.py b/GPy/kern/kern.py index c01ba815..d12febbb 100644 --- a/GPy/kern/kern.py +++ b/GPy/kern/kern.py @@ -3,10 +3,11 @@ import numpy as np +import pylab as pb from ..core.parameterised import parameterised from kernpart import kernpart import itertools -from product_orthogonal import product_orthogonal +from product_orthogonal import product_orthogonal class kern(parameterised): def __init__(self,D,parts=[], input_slices=None): @@ -386,3 +387,59 @@ class kern(parameterised): #TODO: there are some extra terms to compute here! return target_mu, target_S + + def plot(self, x = None, plot_limits=None,which_functions='all',resolution=None,*args,**kwargs): + if which_functions=='all': + which_functions = [True]*self.Nparts + if self.D == 1: + if x is None: + x = np.zeros((1,1)) + else: + x = np.asarray(x) + assert x.size == 1, "The size of the fixed variable x is not 1" + x = x.reshape((1,1)) + + if plot_limits == None: + xmin, xmax = (x-5).flatten(), (x+5).flatten() + elif len(plot_limits) == 2: + xmin, xmax = plot_limits + else: + raise ValueError, "Bad limits for plotting" + + Xnew = np.linspace(xmin,xmax,resolution or 201)[:,None] + Kx = self.K(Xnew,x,slices2=which_functions) + pb.plot(Xnew,Kx,*args,**kwargs) + pb.xlim(xmin,xmax) + pb.xlabel("x") + pb.ylabel("k(x,%0.1f)" %x) + + elif self.D == 2: + if x is None: + x = np.zeros((1,2)) + else: + x = np.asarray(x) + assert x.size == 2, "The size of the fixed variable x is not 2" + x = x.reshape((1,2)) + + if plot_limits == None: + xmin, xmax = (x-5).flatten(), (x+5).flatten() + elif len(plot_limits) == 2: + xmin, xmax = plot_limits + else: + raise ValueError, "Bad limits for plotting" + + resolution = resolution or 51 + xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution] + xg = np.linspace(xmin[0],xmax[0],resolution) + yg = np.linspace(xmin[1],xmax[1],resolution) + Xnew = np.vstack((xx.flatten(),yy.flatten())).T + Kx = self.K(Xnew,x,slices2=which_functions) + Kx = Kx.reshape(resolution,resolution).T + pb.contour(xg,yg,Kx,vmin=Kx.min(),vmax=Kx.max(),cmap=pb.cm.jet) + pb.xlim(xmin[0],xmax[0]) + pb.ylim(xmin[1],xmax[1]) + pb.xlabel("x1") + pb.ylabel("x2") + pb.title("k(x1,x2 ; %0.1f,%0.1f)" %(x[0,0],x[0,1]) ) + else: + raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"