diff --git a/GPy/kern/constructors.py b/GPy/kern/constructors.py index 0dd579cc..0ddc09e3 100644 --- a/GPy/kern/constructors.py +++ b/GPy/kern/constructors.py @@ -176,18 +176,24 @@ import sympy as sp from sympykern import spkern from sympy.parsing.sympy_parser import parse_expr -def rbf_sympy(D,variance=1., lengthscale=1.): +def rbf_sympy(D,ARD=False,variance=1., lengthscale=1.): """ Radial Basis Function covariance. """ X = [sp.var('x%i'%i) for i in range(D)] Z = [sp.var('z%i'%i) for i in range(D)] rbf_variance = sp.var('rbf_variance',positive=True) - rbf_lengthscale = sp.var('rbf_lengthscale',positive=True) - dist_string = ' + '.join(['(x%i-z%i)**2'%(i,i) for i in range(D)]) - dist = parse_expr(dist_string) - f = rbf_variance*sp.exp(-dist/(2*rbf_lengthscale**2)) - return kern(D,[spkern(D,f,np.array([variance,lengthscale]))]) + if ARD: + rbf_lengthscales = [sp.var('rbf_lengthscale_%i'%i,positive=True) for i in range(D)] + dist_string = ' + '.join(['(x%i-z%i)**2/rbf_lengthscale_%i**2'%(i,i,i) for i in range(D)]) + dist = parse_expr(dist_string) + f = rbf_variance*sp.exp(-dist/2.) + else: + rbf_lengthscale = sp.var('rbf_lengthscale',positive=True) + dist_string = ' + '.join(['(x%i-z%i)**2'%(i,i) for i in range(D)]) + dist = parse_expr(dist_string) + f = rbf_variance*sp.exp(-dist/(2*rbf_lengthscale**2)) + return kern(D,[spkern(D,f)]) def sympykern(D,k): """