diff --git a/GPy/kern/parts/rbf_inv.py b/GPy/kern/parts/rbf_inv.py new file mode 100644 index 00000000..35b6e1a9 --- /dev/null +++ b/GPy/kern/parts/rbf_inv.py @@ -0,0 +1,334 @@ +# Copyright (c) 2012, GPy authors (see AUTHORS.txt). +# Licensed under the BSD 3-clause license (see LICENSE.txt) + + +from kernpart import Kernpart +import numpy as np +import hashlib +from scipy import weave +from ...util.linalg import tdot + +class RBFInv(Kernpart): + """ + Radial Basis Function kernel, aka squared-exponential, exponentiated quadratic or Gaussian kernel: + + .. math:: + + k(r) = \sigma^2 \exp \\bigg(- \\frac{1}{2} r^2 \\bigg) \ \ \ \ \ \\text{ where } r^2 = \sum_{i=1}^d \\frac{ (x_i-x^\prime_i)^2}{\ell_i^2} + + where \ell_i is the lengthscale, \sigma^2 the variance and d the dimensionality of the input. + + :param input_dim: the number of input dimensions + :type input_dim: int + :param variance: the variance of the kernel + :type variance: float + :param lengthscale: the vector of lengthscale of the kernel + :type lengthscale: array or list of the appropriate size (or float if there is only one lengthscale parameter) + :param ARD: Auto Relevance Determination. If equal to "False", the kernel is isotropic (ie. one single lengthscale parameter \ell), otherwise there is one lengthscale parameter per dimension. + :type ARD: Boolean + :rtype: kernel object + + .. Note: this object implements both the ARD and 'spherical' version of the function + """ + + def __init__(self, input_dim, variance=1., inv_lengthscale=None, ARD=False): + self.input_dim = input_dim + self.name = 'rbf' + self.ARD = ARD + if not ARD: + self.num_params = 2 + if inv_lengthscale is not None: + inv_lengthscale = np.asarray(inv_lengthscale) + assert inv_lengthscale.size == 1, "Only one lengthscale needed for non-ARD kernel" + else: + inv_lengthscale = np.ones(1) + else: + self.num_params = self.input_dim + 1 + if inv_lengthscale is not None: + inv_lengthscale = np.asarray(inv_lengthscale) + assert inv_lengthscale.size == self.input_dim, "bad number of lengthscales" + else: + inv_lengthscale = np.ones(self.input_dim) + + self._set_params(np.hstack((variance, inv_lengthscale.flatten()))) + + # initialize cache + self._Z, self._mu, self._S = np.empty(shape=(3, 1)) + self._X, self._X2, self._params = np.empty(shape=(3, 1)) + + # a set of optional args to pass to weave + self.weave_options = {'headers' : [''], + 'extra_compile_args': ['-fopenmp -O3'], # -march=native'], + 'extra_link_args' : ['-lgomp']} + + + + def _get_params(self): + return np.hstack((self.variance, self.inv_lengthscale)) + + def _set_params(self, x): + assert x.size == (self.num_params) + self.variance = x[0] + self.inv_lengthscale = x[1:] + self.lengthscale = 1./self.inv_lengthscale + self.lengthscale2 = np.square(self.lengthscale) + # reset cached results + self._X, self._X2, self._params = np.empty(shape=(3, 1)) + self._Z, self._mu, self._S = np.empty(shape=(3, 1)) # cached versions of Z,mu,S + + def _get_param_names(self): + if self.num_params == 2: + return ['variance', 'inv_lengthscale'] + else: + return ['variance'] + ['inv_lengthscale_%i' % i for i in range(self.inv_lengthscale.size)] + + def K(self, X, X2, target): + self._K_computations(X, X2) + target += self.variance * self._K_dvar + + def Kdiag(self, X, target): + np.add(target, self.variance, target) + + def dK_dtheta(self, dL_dK, X, X2, target): + self._K_computations(X, X2) + target[0] += np.sum(self._K_dvar * dL_dK) + if self.ARD: + dvardLdK = self._K_dvar * dL_dK + var_len3 = self.variance / np.power(self.lengthscale, 3) + if X2 is None: + # save computation for the symmetrical case + dvardLdK = dvardLdK + dvardLdK.T + code = """ + int q,i,j; + double tmp; + for(q=0; q + #include + """ + weave.inline(code, support_code=support_code, libraries=['gomp'], + arg_names=['N','num_inducing','input_dim','mu','Zhat','mudist_sq','mudist','lengthscale2','_psi2_denom','psi2_Zdist_sq','psi2_exponent','half_log_psi2_denom','psi2','variance_sq'], + type_converters=weave.converters.blitz, **self.weave_options) + + return mudist, mudist_sq, psi2_exponent, psi2