diff --git a/GPy/kern/src/poly.py b/GPy/kern/src/poly.py index 216e3a00..a3dbfed5 100644 --- a/GPy/kern/src/poly.py +++ b/GPy/kern/src/poly.py @@ -5,32 +5,49 @@ import numpy as np from .kern import Kern from ...core.parameterization import Param from paramz.transformations import Logexp +from paramz.caching import Cache_this class Poly(Kern): """ Polynomial kernel """ - def __init__(self, input_dim, variance=1., order=3., active_dims=None, name='poly'): + def __init__(self, input_dim, variance=1., scale=1., bias=1., order=3., active_dims=None, name='poly'): super(Poly, self).__init__(input_dim, active_dims, name) self.variance = Param('variance', variance, Logexp()) - self.link_parameter(self.variance) + self.scale = Param('scale', scale, Logexp()) + self.bias = Param('bias', bias, Logexp()) + + self.link_parameters(self.variance, self.scale, self.bias) + assert order >= 1, 'The order of the polynomial has to be at least 1.' self.order=order - def K(self, X, X2=None): - return (self._dot_product(X, X2) + 1.)**self.order * self.variance - def _dot_product(self, X, X2=None): + def K(self, X, X2=None): + _, _, B = self._AB(X, X2) + return B * self.variance + + #@Cache_this(limit=2) + def _AB(self, X, X2=None): if X2 is None: - return np.dot(X, X.T) + dot_prod = np.dot(X, X.T) else: - return np.dot(X, X2.T) + dot_prod = np.dot(X, X2.T) + A = (self.scale * dot_prod) + self.bias + B = A ** self.order + return dot_prod, A, B def Kdiag(self, X): - return self.variance*(np.square(X).sum(1) + 1.)**self.order + return self.K(X).diagonal()#self.variance*(np.square(X).sum(1) + 1.)**self.order def update_gradients_full(self, dL_dK, X, X2=None): - self.variance.gradient = np.sum(dL_dK * (self._dot_product(X, X2) + 1.)**self.order) + dot_prod, A, B = self._AB(X, X2) + dK_dA = self.variance * self.order * A ** (self.order-1.) + dL_dA = dL_dK * (dK_dA) + self.scale.gradient = (dL_dA * dot_prod).sum() + self.bias.gradient = dL_dA.sum() + self.variance.gradient = np.sum(dL_dK * B) + #import ipdb;ipdb.set_trace() def update_gradients_diag(self, dL_dKdiag, X): raise NotImplementedError diff --git a/GPy/testing/kernel_tests.py b/GPy/testing/kernel_tests.py index bae7b2e4..5278c8b2 100644 --- a/GPy/testing/kernel_tests.py +++ b/GPy/testing/kernel_tests.py @@ -324,6 +324,11 @@ class KernelGradientTestsContinuous(unittest.TestCase): k.randomize() self.assertTrue(check_kernel_gradient_functions(k, X=self.X, X2=self.X2, verbose=verbose)) + def test_Poly(self): + k = GPy.kern.Poly(self.D, order=5) + k.randomize() + self.assertTrue(check_kernel_gradient_functions(k, X=self.X, X2=self.X2, verbose=verbose)) + def test_standard_periodic(self): k = GPy.kern.StdPeriodic(self.D, self.D-1) k.randomize()