mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 21:42:39 +02:00
reinsert the plot function for kernel that diserpered at one point...
This commit is contained in:
parent
7cd1c426f4
commit
336b15cb79
1 changed files with 58 additions and 1 deletions
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pylab as pb
|
||||||
from ..core.parameterised import parameterised
|
from ..core.parameterised import parameterised
|
||||||
from kernpart import kernpart
|
from kernpart import kernpart
|
||||||
import itertools
|
import itertools
|
||||||
|
|
@ -386,3 +387,59 @@ class kern(parameterised):
|
||||||
|
|
||||||
#TODO: there are some extra terms to compute here!
|
#TODO: there are some extra terms to compute here!
|
||||||
return target_mu, target_S
|
return target_mu, target_S
|
||||||
|
|
||||||
|
def plot(self, x = None, plot_limits=None,which_functions='all',resolution=None,*args,**kwargs):
|
||||||
|
if which_functions=='all':
|
||||||
|
which_functions = [True]*self.Nparts
|
||||||
|
if self.D == 1:
|
||||||
|
if x is None:
|
||||||
|
x = np.zeros((1,1))
|
||||||
|
else:
|
||||||
|
x = np.asarray(x)
|
||||||
|
assert x.size == 1, "The size of the fixed variable x is not 1"
|
||||||
|
x = x.reshape((1,1))
|
||||||
|
|
||||||
|
if plot_limits == None:
|
||||||
|
xmin, xmax = (x-5).flatten(), (x+5).flatten()
|
||||||
|
elif len(plot_limits) == 2:
|
||||||
|
xmin, xmax = plot_limits
|
||||||
|
else:
|
||||||
|
raise ValueError, "Bad limits for plotting"
|
||||||
|
|
||||||
|
Xnew = np.linspace(xmin,xmax,resolution or 201)[:,None]
|
||||||
|
Kx = self.K(Xnew,x,slices2=which_functions)
|
||||||
|
pb.plot(Xnew,Kx,*args,**kwargs)
|
||||||
|
pb.xlim(xmin,xmax)
|
||||||
|
pb.xlabel("x")
|
||||||
|
pb.ylabel("k(x,%0.1f)" %x)
|
||||||
|
|
||||||
|
elif self.D == 2:
|
||||||
|
if x is None:
|
||||||
|
x = np.zeros((1,2))
|
||||||
|
else:
|
||||||
|
x = np.asarray(x)
|
||||||
|
assert x.size == 2, "The size of the fixed variable x is not 2"
|
||||||
|
x = x.reshape((1,2))
|
||||||
|
|
||||||
|
if plot_limits == None:
|
||||||
|
xmin, xmax = (x-5).flatten(), (x+5).flatten()
|
||||||
|
elif len(plot_limits) == 2:
|
||||||
|
xmin, xmax = plot_limits
|
||||||
|
else:
|
||||||
|
raise ValueError, "Bad limits for plotting"
|
||||||
|
|
||||||
|
resolution = resolution or 51
|
||||||
|
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
|
||||||
|
xg = np.linspace(xmin[0],xmax[0],resolution)
|
||||||
|
yg = np.linspace(xmin[1],xmax[1],resolution)
|
||||||
|
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
|
||||||
|
Kx = self.K(Xnew,x,slices2=which_functions)
|
||||||
|
Kx = Kx.reshape(resolution,resolution).T
|
||||||
|
pb.contour(xg,yg,Kx,vmin=Kx.min(),vmax=Kx.max(),cmap=pb.cm.jet)
|
||||||
|
pb.xlim(xmin[0],xmax[0])
|
||||||
|
pb.ylim(xmin[1],xmax[1])
|
||||||
|
pb.xlabel("x1")
|
||||||
|
pb.ylabel("x2")
|
||||||
|
pb.title("k(x1,x2 ; %0.1f,%0.1f)" %(x[0,0],x[0,1]) )
|
||||||
|
else:
|
||||||
|
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue