diff --git a/GPy/kern/__init__.py b/GPy/kern/__init__.py
index 0e265a64..55b69bd7 100644
--- a/GPy/kern/__init__.py
+++ b/GPy/kern/__init__.py
@@ -1,6 +1,6 @@
 from _src.kern import Kern
 from _src.rbf import RBF
-from _src.linear import Linear
+from _src.linear import Linear, LinearFull
 from _src.static import Bias, White
 from _src.brownian import Brownian
 from _src.sympykern import Sympykern
diff --git a/GPy/kern/_src/kern.py b/GPy/kern/_src/kern.py
index 31fa8690..9d8d3f7b 100644
--- a/GPy/kern/_src/kern.py
+++ b/GPy/kern/_src/kern.py
@@ -11,7 +11,7 @@ from ...util.caching import Cache_this
 
 class Kern(Parameterized):
     #===========================================================================
-    # This adds input slice support. The rather ugly code for slicing can be 
+    # This adds input slice support. The rather ugly code for slicing can be
     # found in kernel_slice_operations
     __metaclass__ = KernCallsViaSlicerMeta
     #===========================================================================
diff --git a/GPy/kern/_src/linear.py b/GPy/kern/_src/linear.py
index 7d9eeac2..b6b1ec1b 100644
--- a/GPy/kern/_src/linear.py
+++ b/GPy/kern/_src/linear.py
@@ -313,3 +313,47 @@
 
     def input_sensitivity(self):
         return np.ones(self.input_dim) * self.variances
+
+class LinearFull(Kern):
+    def __init__(self, input_dim, rank, W=None, kappa=None, active_dims=None, name='linear_full'):
+        super(LinearFull, self).__init__(input_dim, active_dims, name)
+        if W is None:
+            W = np.ones((input_dim, rank))
+        if kappa is None:
+            kappa = np.ones(input_dim)
+        assert W.shape == (input_dim, rank)
+        assert kappa.shape == (input_dim,)
+
+        self.W = Param('W', W)
+        self.kappa = Param('kappa', kappa, Logexp())
+        self.add_parameters(self.W, self.kappa)
+
+    def K(self, X, X2=None):
+        P = np.dot(self.W, self.W.T) + np.diag(self.kappa)
+        return np.einsum('ij,jk,lk->il', X, P, X if X2 is None else X2)
+
+    def update_gradients_full(self, dL_dK, X, X2=None):
+        self.kappa.gradient = np.einsum('ij,ik,kj->j', X, dL_dK, X if X2 is None else X2)
+        self.W.gradient = np.einsum('ij,kl,ik,lm->jm', X, X if X2 is None else X2, dL_dK, self.W)
+        self.W.gradient += np.einsum('ij,kl,ik,jm->lm', X, X if X2 is None else X2, dL_dK, self.W)
+
+    def Kdiag(self, X):
+        P = np.dot(self.W, self.W.T) + np.diag(self.kappa)
+        return np.einsum('ij,jk,ik->i', X, P, X)
+
+    def update_gradients_diag(self, dL_dKdiag, X):
+        self.kappa.gradient = np.einsum('ij,i->j', np.square(X), dL_dKdiag)
+        self.W.gradient = 2.*np.einsum('ij,ik,jl,i->kl', X, X, self.W, dL_dKdiag)
+
+    def gradients_X(self, dL_dK, X, X2=None):
+        P = np.dot(self.W, self.W.T) + np.diag(self.kappa)
+        if X2 is None:
+            return 2.*np.einsum('ij,jk,kl->il', dL_dK, X, P)
+        else:
+            return np.einsum('ij,jk,kl->il', dL_dK, X2, P)
+
+    def gradients_X_diag(self, dL_dKdiag, X):
+        P = np.dot(self.W, self.W.T) + np.diag(self.kappa)
+        return 2.*np.einsum('jk,i,ij->ik', P, dL_dKdiag, X)
+
+
diff --git a/GPy/kern/_src/rbf.py b/GPy/kern/_src/rbf.py
index c2877d06..0f19dbd1 100644
--- a/GPy/kern/_src/rbf.py
+++ b/GPy/kern/_src/rbf.py
@@ -64,7 +64,7 @@
         if self.ARD:
             self.lengthscale.gradient = (dL_dpsi1[:,:,None]*_dpsi1_dlengthscale).reshape(-1,self.input_dim).sum(axis=0)
         else:
-            self.lengthscale.gradient = (dL_dpsi1[:,:,None]*_dpsi1_dlengthscale).sum() 
+            self.lengthscale.gradient = (dL_dpsi1[:,:,None]*_dpsi1_dlengthscale).sum()
 
         #from psi2
         self.variance.gradient += (dL_dpsi2 * _dpsi2_dvariance).sum()
diff --git a/GPy/testing/kernel_tests.py b/GPy/testing/kernel_tests.py
index 9ed218d8..0a74143c 100644
--- a/GPy/testing/kernel_tests.py
+++ b/GPy/testing/kernel_tests.py
@@ -276,6 +276,11 @@ class KernelGradientTestsContinuous(unittest.TestCase):
         k.randomize()
         self.assertTrue(check_kernel_gradient_functions(k, X=self.X, X2=self.X2, verbose=verbose))
 
+    def test_LinearFull(self):
+        k = GPy.kern.LinearFull(self.D, self.D-1)
+        k.randomize()
+        self.assertTrue(check_kernel_gradient_functions(k, X=self.X, X2=self.X2, verbose=verbose))
+
 #TODO: turn off grad checkingwrt X for indexed kernels like coregionalize
 # class KernelGradientTestsContinuous1D(unittest.TestCase):
 #     def setUp(self):