Merge branch 'newGP'

Conflicts:
	GPy/models/GP_regression.py
This commit is contained in:
James Hensman 2013-02-04 12:41:22 +00:00
commit 687631f719
23 changed files with 1622 additions and 1138 deletions

View file

@ -6,30 +6,26 @@ import Tango
import pylab as pb
import numpy as np
def gpplot(x,mu,var,edgecol=Tango.coloursHex['darkBlue'],fillcol=Tango.coloursHex['lightBlue'],axes=None,**kwargs):
def gpplot(x,mu,lower,upper,edgecol=Tango.coloursHex['darkBlue'],fillcol=Tango.coloursHex['lightBlue'],axes=None,**kwargs):
if axes is None:
axes = pb.gca()
mu = mu.flatten()
x = x.flatten()
lower = lower.flatten()
upper = upper.flatten()
#here's the mean
axes.plot(x,mu,color=edgecol,linewidth=2)
#ensure variance is a vector
if len(var.shape)>1:
err = 2*np.sqrt(np.diag(var))
else:
err = 2*np.sqrt(var)
#here's the 2*std box
#here's the box
kwargs['linewidth']=0.5
if not 'alpha' in kwargs.keys():
kwargs['alpha'] = 0.3
axes.fill(np.hstack((x,x[::-1])),np.hstack((mu+err,mu[::-1]-err[::-1])),color=fillcol,**kwargs)
axes.fill(np.hstack((x,x[::-1])),np.hstack((upper,lower[::-1])),color=fillcol,**kwargs)
#this is the edge:
axes.plot(x,mu+err,color=edgecol,linewidth=0.2)
axes.plot(x,mu-err,color=edgecol,linewidth=0.2)
axes.plot(x,upper,color=edgecol,linewidth=0.2)
axes.plot(x,lower,color=edgecol,linewidth=0.2)
def removeRightTicks(ax=None):
ax = ax or pb.gca()
@ -74,4 +70,37 @@ def align_subplots(N,M,xlim=None, ylim=None):
else:
removeUpperTicks()
def x_frame1D(X,plot_limits=None,resolution=None):
"""
Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits
"""
assert X.shape[1] ==1, "x_frame1D is defined for one-dimensional inputs"
if plot_limits is None:
xmin,xmax = X.min(0),X.max(0)
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
elif len(plot_limits)==2:
xmin, xmax = plot_limits
else:
raise ValueError, "Bad limits for plotting"
Xnew = np.linspace(xmin,xmax,resolution or 200)[:,None]
return Xnew, xmin, xmax
def x_frame2D(X,plot_limits=None,resolution=None):
"""
Internal helper function for making plots, returns a set of input values to plot as well as lower and upper limits
"""
assert X.shape[1] ==2, "x_frame2D is defined for two-dimensional inputs"
if plot_limits is None:
xmin,xmax = X.min(0),X.max(0)
xmin, xmax = xmin-0.2*(xmax-xmin), xmax+0.2*(xmax-xmin)
elif len(plot_limits)==2:
xmin, xmax = plot_limits
else:
raise ValueError, "Bad limits for plotting"
resolution = resolution or 50
xx,yy = np.mgrid[xmin[0]:xmax[0]:1j*resolution,xmin[1]:xmax[1]:1j*resolution]
Xnew = np.vstack((xx.flatten(),yy.flatten())).T
return Xnew, xx,yy,xmin, xmax