added an ARD option to the sympy RBF kern

This commit is contained in:
James Hensman 2012-12-21 11:40:22 +00:00
parent cc3ccf1ddc
commit e0384902ff

View file

@ -176,18 +176,24 @@ import sympy as sp
from sympykern import spkern
from sympy.parsing.sympy_parser import parse_expr
def rbf_sympy(D,variance=1., lengthscale=1.):
def rbf_sympy(D,ARD=False,variance=1., lengthscale=1.):
"""
Radial Basis Function covariance.
"""
X = [sp.var('x%i'%i) for i in range(D)]
Z = [sp.var('z%i'%i) for i in range(D)]
rbf_variance = sp.var('rbf_variance',positive=True)
rbf_lengthscale = sp.var('rbf_lengthscale',positive=True)
dist_string = ' + '.join(['(x%i-z%i)**2'%(i,i) for i in range(D)])
dist = parse_expr(dist_string)
f = rbf_variance*sp.exp(-dist/(2*rbf_lengthscale**2))
return kern(D,[spkern(D,f,np.array([variance,lengthscale]))])
if ARD:
rbf_lengthscales = [sp.var('rbf_lengthscale_%i'%i,positive=True) for i in range(D)]
dist_string = ' + '.join(['(x%i-z%i)**2/rbf_lengthscale_%i**2'%(i,i,i) for i in range(D)])
dist = parse_expr(dist_string)
f = rbf_variance*sp.exp(-dist/2.)
else:
rbf_lengthscale = sp.var('rbf_lengthscale',positive=True)
dist_string = ' + '.join(['(x%i-z%i)**2'%(i,i) for i in range(D)])
dist = parse_expr(dist_string)
f = rbf_variance*sp.exp(-dist/(2*rbf_lengthscale**2))
return kern(D,[spkern(D,f)])
def sympykern(D,k):
"""