mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 11:32:39 +02:00
added an ARD option to the sympy RBF kern
This commit is contained in:
parent
cc3ccf1ddc
commit
e0384902ff
1 changed files with 12 additions and 6 deletions
|
|
@ -176,18 +176,24 @@ import sympy as sp
|
||||||
from sympykern import spkern
|
from sympykern import spkern
|
||||||
from sympy.parsing.sympy_parser import parse_expr
|
from sympy.parsing.sympy_parser import parse_expr
|
||||||
|
|
||||||
def rbf_sympy(D,variance=1., lengthscale=1.):
|
def rbf_sympy(D,ARD=False,variance=1., lengthscale=1.):
|
||||||
"""
|
"""
|
||||||
Radial Basis Function covariance.
|
Radial Basis Function covariance.
|
||||||
"""
|
"""
|
||||||
X = [sp.var('x%i'%i) for i in range(D)]
|
X = [sp.var('x%i'%i) for i in range(D)]
|
||||||
Z = [sp.var('z%i'%i) for i in range(D)]
|
Z = [sp.var('z%i'%i) for i in range(D)]
|
||||||
rbf_variance = sp.var('rbf_variance',positive=True)
|
rbf_variance = sp.var('rbf_variance',positive=True)
|
||||||
rbf_lengthscale = sp.var('rbf_lengthscale',positive=True)
|
if ARD:
|
||||||
dist_string = ' + '.join(['(x%i-z%i)**2'%(i,i) for i in range(D)])
|
rbf_lengthscales = [sp.var('rbf_lengthscale_%i'%i,positive=True) for i in range(D)]
|
||||||
dist = parse_expr(dist_string)
|
dist_string = ' + '.join(['(x%i-z%i)**2/rbf_lengthscale_%i**2'%(i,i,i) for i in range(D)])
|
||||||
f = rbf_variance*sp.exp(-dist/(2*rbf_lengthscale**2))
|
dist = parse_expr(dist_string)
|
||||||
return kern(D,[spkern(D,f,np.array([variance,lengthscale]))])
|
f = rbf_variance*sp.exp(-dist/2.)
|
||||||
|
else:
|
||||||
|
rbf_lengthscale = sp.var('rbf_lengthscale',positive=True)
|
||||||
|
dist_string = ' + '.join(['(x%i-z%i)**2'%(i,i) for i in range(D)])
|
||||||
|
dist = parse_expr(dist_string)
|
||||||
|
f = rbf_variance*sp.exp(-dist/(2*rbf_lengthscale**2))
|
||||||
|
return kern(D,[spkern(D,f)])
|
||||||
|
|
||||||
def sympykern(D,k):
|
def sympykern(D,k):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue