mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 11:32:39 +02:00
added an ARD option to the sympy RBF kern
This commit is contained in:
parent
cc3ccf1ddc
commit
e0384902ff
1 changed files with 12 additions and 6 deletions
|
|
@ -176,18 +176,24 @@ import sympy as sp
|
|||
from sympykern import spkern
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
def rbf_sympy(D,variance=1., lengthscale=1.):
|
||||
def rbf_sympy(D,ARD=False,variance=1., lengthscale=1.):
|
||||
"""
|
||||
Radial Basis Function covariance.
|
||||
"""
|
||||
X = [sp.var('x%i'%i) for i in range(D)]
|
||||
Z = [sp.var('z%i'%i) for i in range(D)]
|
||||
rbf_variance = sp.var('rbf_variance',positive=True)
|
||||
rbf_lengthscale = sp.var('rbf_lengthscale',positive=True)
|
||||
dist_string = ' + '.join(['(x%i-z%i)**2'%(i,i) for i in range(D)])
|
||||
dist = parse_expr(dist_string)
|
||||
f = rbf_variance*sp.exp(-dist/(2*rbf_lengthscale**2))
|
||||
return kern(D,[spkern(D,f,np.array([variance,lengthscale]))])
|
||||
if ARD:
|
||||
rbf_lengthscales = [sp.var('rbf_lengthscale_%i'%i,positive=True) for i in range(D)]
|
||||
dist_string = ' + '.join(['(x%i-z%i)**2/rbf_lengthscale_%i**2'%(i,i,i) for i in range(D)])
|
||||
dist = parse_expr(dist_string)
|
||||
f = rbf_variance*sp.exp(-dist/2.)
|
||||
else:
|
||||
rbf_lengthscale = sp.var('rbf_lengthscale',positive=True)
|
||||
dist_string = ' + '.join(['(x%i-z%i)**2'%(i,i) for i in range(D)])
|
||||
dist = parse_expr(dist_string)
|
||||
f = rbf_variance*sp.exp(-dist/(2*rbf_lengthscale**2))
|
||||
return kern(D,[spkern(D,f)])
|
||||
|
||||
def sympykern(D,k):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue