mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-12 05:22:38 +02:00
Many modifications in GP plots to make it work
This commit is contained in:
parent
4dce1a428f
commit
5b19d8609a
3 changed files with 63 additions and 17 deletions
|
|
@ -3,10 +3,11 @@
|
|||
|
||||
|
||||
import numpy as np
|
||||
import pylab as pb
|
||||
from ..core.parameterised import parameterised
|
||||
from kernpart import kernpart
|
||||
import itertools
|
||||
from product_orthogonal import product_orthogonal
|
||||
from product_orthogonal import product_orthogonal
|
||||
|
||||
class kern(parameterised):
|
||||
def __init__(self,D,parts=[], input_slices=None):
|
||||
|
|
@ -372,3 +373,50 @@ class kern(parameterised):
|
|||
|
||||
#TODO: there are some extra terms to compute here!
|
||||
return target_mu, target_S
|
||||
|
||||
def plot(self, x = None, plot_limits=None,which_functions='all',resolution=None):
|
||||
if which_functions=='all':
|
||||
which_functions = [True]*self.Nparts
|
||||
if self.D == 1:
|
||||
if x is None:
|
||||
x = np.zeros((1,1))
|
||||
else:
|
||||
assert x.shape == (1,1), "The shape fo the fixed variable x is not (1,D)"
|
||||
|
||||
if plot_limits == None:
|
||||
xmin, xmax = (x-5).flatten(), (x+5).flatten()
|
||||
elif len(plot_limits) == 2:
|
||||
xmin, xmax = plot_limits
|
||||
else:
|
||||
raise ValueError, "Bad limits for plotting"
|
||||
|
||||
Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None]
|
||||
Kx = self.K(Xnew,x)
|
||||
pb.plot(Xnew,Kx)
|
||||
pb.xlim(xmin,xmax)
|
||||
pb.ylim(Kx.min() - (Kx.max()-Kx.min())*0.15,Kx.max() + (Kx.max()-Kx.min())*0.15)
|
||||
|
||||
elif self.D == 2:
|
||||
if x is None:
|
||||
x = np.zeros((1,2))
|
||||
else:
|
||||
assert x.shape == (1,2), "The shape fo the fixed variable x is not (1,D)"
|
||||
|
||||
if plot_limits == None:
|
||||
xmin, xmax = (x-5).flatten(), (x+5).flatten()
|
||||
elif len(plot_limits) == 2:
|
||||
xmin, xmax = plot_limits
|
||||
else:
|
||||
raise ValueError, "Bad limits for plotting"
|
||||
|
||||
resolution = resolution or 50
|
||||
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
|
||||
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
|
||||
Kx = self.K(Xnew,x)
|
||||
Kx = Kx.reshape(resolution,resolution)
|
||||
pb.contour(xx,yy,zz,vmin=zz.min(),vmax=zz.max(),cmap=pb.cm.jet)
|
||||
pb.scatter(Xorig[:,0],Xorig[:,1],40,Yorig,linewidth=0,cmap=pb.cm.jet,vmin=zz.min(),vmax=zz.max())
|
||||
pb.xlim(xmin[0],xmax[0])
|
||||
pb.ylim(xmin[1],xmax[1])
|
||||
else:
|
||||
raise NotImplementedError, "Cannot plot a kernel with more than two input dimensions"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue