[Poly] added bias and scale

This commit is contained in:
Max Zwiessele 2016-02-25 08:39:21 +00:00
parent c1ed2fe8da
commit 01a7f85c9e
2 changed files with 31 additions and 9 deletions

View file

@ -5,32 +5,49 @@ import numpy as np
from .kern import Kern
from ...core.parameterization import Param
from paramz.transformations import Logexp
from paramz.caching import Cache_this
class Poly(Kern):
"""
Polynomial kernel
"""
def __init__(self, input_dim, variance=1., order=3., active_dims=None, name='poly'):
def __init__(self, input_dim, variance=1., scale=1., bias=1., order=3., active_dims=None, name='poly'):
super(Poly, self).__init__(input_dim, active_dims, name)
self.variance = Param('variance', variance, Logexp())
self.link_parameter(self.variance)
self.scale = Param('scale', scale, Logexp())
self.bias = Param('bias', bias, Logexp())
self.link_parameters(self.variance, self.scale, self.bias)
assert order >= 1, 'The order of the polynomial has to be at least 1.'
self.order=order
def K(self, X, X2=None):
return (self._dot_product(X, X2) + 1.)**self.order * self.variance
def _dot_product(self, X, X2=None):
def K(self, X, X2=None):
_, _, B = self._AB(X, X2)
return B * self.variance
#@Cache_this(limit=2)
def _AB(self, X, X2=None):
if X2 is None:
return np.dot(X, X.T)
dot_prod = np.dot(X, X.T)
else:
return np.dot(X, X2.T)
dot_prod = np.dot(X, X2.T)
A = (self.scale * dot_prod) + self.bias
B = A ** self.order
return dot_prod, A, B
def Kdiag(self, X):
return self.variance*(np.square(X).sum(1) + 1.)**self.order
return self.K(X).diagonal()#self.variance*(np.square(X).sum(1) + 1.)**self.order
def update_gradients_full(self, dL_dK, X, X2=None):
self.variance.gradient = np.sum(dL_dK * (self._dot_product(X, X2) + 1.)**self.order)
dot_prod, A, B = self._AB(X, X2)
dK_dA = self.variance * self.order * A ** (self.order-1.)
dL_dA = dL_dK * (dK_dA)
self.scale.gradient = (dL_dA * dot_prod).sum()
self.bias.gradient = dL_dA.sum()
self.variance.gradient = np.sum(dL_dK * B)
#import ipdb;ipdb.set_trace()
def update_gradients_diag(self, dL_dKdiag, X):
raise NotImplementedError

View file

@ -324,6 +324,11 @@ class KernelGradientTestsContinuous(unittest.TestCase):
k.randomize()
self.assertTrue(check_kernel_gradient_functions(k, X=self.X, X2=self.X2, verbose=verbose))
def test_Poly(self):
k = GPy.kern.Poly(self.D, order=5)
k.randomize()
self.assertTrue(check_kernel_gradient_functions(k, X=self.X, X2=self.X2, verbose=verbose))
def test_standard_periodic(self):
k = GPy.kern.StdPeriodic(self.D, self.D-1)
k.randomize()