diff --git a/GPy/kern/_src/mlp.py b/GPy/kern/_src/mlp.py
index c495b77b..40a8d2d7 100644
--- a/GPy/kern/_src/mlp.py
+++ b/GPy/kern/_src/mlp.py
@@ -33,9 +33,14 @@ class MLP(Kern):
     """
-    def __init__(self, input_dim, variance=1., weight_variance=1., bias_variance=1., active_dims=None, name='mlp'):
+    def __init__(self, input_dim, variance=1., weight_variance=1., bias_variance=1., ARD=False, active_dims=None, name='mlp'):
         super(MLP, self).__init__(input_dim, active_dims, name)
         self.variance = Param('variance', variance, Logexp())
+        self.ARD = ARD
+        if ARD:
+            wv = np.empty((input_dim,))
+            wv[:] = weight_variance
+            weight_variance = wv
         self.weight_variance = Param('weight_variance', weight_variance, Logexp())
         self.bias_variance = Param('bias_variance', bias_variance, Logexp())
         self.link_parameters(self.variance, self.weight_variance, self.bias_variance)
 
 
@@ -100,7 +105,22 @@ class MLP(Kern):
         X2_prod = self._comp_prod(X2) if X2 is not None else X_prod
         XTX = self._comp_prod(X,X2) if X2 is not None else self._comp_prod(X, X)
         common = var*four_over_tau/np.sqrt((X_prod[:,None]+1.)*(X2_prod[None,:]+1.)-np.square(XTX))*dL_dK
-        dw = (common*((XTX-b)/w-XTX*(((X_prod-b)/(w*(X_prod+1.)))[:,None]+((X2_prod-b)/(w*(X2_prod+1.)))[None,:])/2.)).sum()
+        if self.ARD:
+            if X2 is not None:
+                XX2 = X[:,None,:]*X2[None,:,:]
+                XX = np.square(X)
+                X2X2 = np.square(X2)
+                Q = self.weight_variance.shape[0]
+                common_XTX = common*XTX
+                dw = np.dot(common.flat,XX2.reshape(-1,Q)) -( (common_XTX.sum(1)/(X_prod+1.)).T.dot(XX)+(common_XTX.sum(0)/(X2_prod+1.)).dot(X2X2))/2
+            else:
+                XX2 = X[:,None,:]*X[None,:,:]
+                XX = np.square(X)
+                Q = self.weight_variance.shape[0]
+                common_XTX = common*XTX
+                dw = np.dot(common.flat,XX2.reshape(-1,Q)) - ((common_XTX.sum(0)+common_XTX.sum(1))/(X_prod+1.)).dot(XX)/2
+        else:
+            dw = (common*((XTX-b)/w-XTX*(((X_prod-b)/(w*(X_prod+1.)))[:,None]+((X2_prod-b)/(w*(X2_prod+1.)))[None,:])/2.)).sum()
         db = (common*(1.-XTX*(1./(X_prod[:,None]+1.)+1./(X2_prod[None,:]+1.))/2.)).sum()
         if X2 is None:
             common = common+common.T
@@ -118,7 +138,11 @@ class MLP(Kern):
         dvar = (dL_dKdiag*K).sum()/var
         X_prod = self._comp_prod(X)
         common = var*four_over_tau/(np.sqrt(1-np.square(X_prod/(X_prod+1)))*np.square(X_prod+1))*dL_dKdiag
-        dw = (common*(X_prod-b)).sum()/w
+        if self.ARD:
+            XX = np.square(X)
+            dw = np.dot(common,XX)
+        else:
+            dw = (common*(X_prod-b)).sum()/w
         db = common.sum()
         dX = common[:,None]*X*w*2
         return dvar, dw, db, dX
diff --git a/GPy/kern/_src/psi_comp/gaussherm.py b/GPy/kern/_src/psi_comp/gaussherm.py
index 8e54e6a0..923b3eb0 100644
--- a/GPy/kern/_src/psi_comp/gaussherm.py
+++ b/GPy/kern/_src/psi_comp/gaussherm.py
@@ -7,11 +7,15 @@ An approximated psi-statistics implementation based on Gauss-Hermite Quadrature
 
 import numpy as np
 
+from ....core.parameterization import Param
 from GPy.util.caching import Cache_this
 from ....util.linalg import tdot
 from . import PSICOMP
 
 class PSICOMP_GH(PSICOMP):
+    """
+    TODO: support Psi2 with shape NxMxM
+    """
 
     def __init__(self, degree=5, cache_K=True):
         self.degree = degree
@@ -64,7 +68,10 @@
 
         dtheta_old = kern.gradient.copy()
         dtheta = np.zeros_like(kern.gradient)
-        dZ = np.zeros_like(Z.values)
+        if isinstance(Z, Param):
+            dZ = np.zeros_like(Z.values)
+        else:
+            dZ = np.zeros_like(Z)
         dmu = np.zeros_like(mu)
         dS = np.zeros_like(S)
         for i in xrange(self.degree):
diff --git a/GPy/testing/kernel_tests.py b/GPy/testing/kernel_tests.py
index 50a5aed8..cf102f99 100644
--- a/GPy/testing/kernel_tests.py
+++ b/GPy/testing/kernel_tests.py
@@ -252,6 +252,11 @@ class KernelGradientTestsContinuous(unittest.TestCase):
         continuous_kerns = ['RBF', 'Linear']
         self.kernclasses = [getattr(GPy.kern, s) for s in continuous_kerns]
 
+    def test_MLP(self):
+        k = GPy.kern.MLP(self.D,ARD=True)
+        k.randomize()
+        self.assertTrue(check_kernel_gradient_functions(k, X=self.X, X2=self.X2, verbose=verbose))
+
     def test_Matern32(self):
         k = GPy.kern.Matern32(self.D)
         k.randomize()
@@ -464,7 +469,7 @@ class Kernel_Psi_statistics_GradientTests(unittest.TestCase):
         self.w3n = self.w3n+np.swapaxes(self.w3n, 1,2)
 
     def test_kernels(self):
-        from GPy.kern import RBF,Linear
+        from GPy.kern import RBF,Linear,MLP
         Q = self.Z.shape[1]
         kernels = [RBF(Q,ARD=True), Linear(Q,ARD=True)]
 