mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-06 10:32:39 +02:00
add ARD to MLP kernel
This commit is contained in:
parent
07c2536719
commit
5608dc8c67
3 changed files with 41 additions and 5 deletions
|
|
@ -33,9 +33,14 @@ class MLP(Kern):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, variance=1., weight_variance=1., bias_variance=1., active_dims=None, name='mlp'):
|
||||
def __init__(self, input_dim, variance=1., weight_variance=1., bias_variance=1., ARD=False, active_dims=None, name='mlp'):
|
||||
super(MLP, self).__init__(input_dim, active_dims, name)
|
||||
self.variance = Param('variance', variance, Logexp())
|
||||
self.ARD= ARD
|
||||
if ARD:
|
||||
wv = np.empty((input_dim,))
|
||||
wv[:] = weight_variance
|
||||
weight_variance = wv
|
||||
self.weight_variance = Param('weight_variance', weight_variance, Logexp())
|
||||
self.bias_variance = Param('bias_variance', bias_variance, Logexp())
|
||||
self.link_parameters(self.variance, self.weight_variance, self.bias_variance)
|
||||
|
|
@ -100,7 +105,22 @@ class MLP(Kern):
|
|||
X2_prod = self._comp_prod(X2) if X2 is not None else X_prod
|
||||
XTX = self._comp_prod(X,X2) if X2 is not None else self._comp_prod(X, X)
|
||||
common = var*four_over_tau/np.sqrt((X_prod[:,None]+1.)*(X2_prod[None,:]+1.)-np.square(XTX))*dL_dK
|
||||
dw = (common*((XTX-b)/w-XTX*(((X_prod-b)/(w*(X_prod+1.)))[:,None]+((X2_prod-b)/(w*(X2_prod+1.)))[None,:])/2.)).sum()
|
||||
if self.ARD:
|
||||
if X2 is not None:
|
||||
XX2 = X[:,None,:]*X2[None,:,:] if X2 is not None else X[:,None,:]*X[None,:,:]
|
||||
XX = np.square(X)
|
||||
X2X2 = np.square(X2)
|
||||
Q = self.weight_variance.shape[0]
|
||||
common_XTX = common*XTX
|
||||
dw = np.dot(common.flat,XX2.reshape(-1,Q)) -( (common_XTX.sum(1)/(X_prod+1.)).T.dot(XX)+(common_XTX.sum(0)/(X2_prod+1.)).dot(X2X2))/2
|
||||
else:
|
||||
XX2 = X[:,None,:]*X[None,:,:]
|
||||
XX = np.square(X)
|
||||
Q = self.weight_variance.shape[0]
|
||||
common_XTX = common*XTX
|
||||
dw = np.dot(common.flat,XX2.reshape(-1,Q)) - ((common_XTX.sum(0)+common_XTX.sum(1))/(X_prod+1.)).dot(XX)/2
|
||||
else:
|
||||
dw = (common*((XTX-b)/w-XTX*(((X_prod-b)/(w*(X_prod+1.)))[:,None]+((X2_prod-b)/(w*(X2_prod+1.)))[None,:])/2.)).sum()
|
||||
db = (common*(1.-XTX*(1./(X_prod[:,None]+1.)+1./(X2_prod[None,:]+1.))/2.)).sum()
|
||||
if X2 is None:
|
||||
common = common+common.T
|
||||
|
|
@ -118,7 +138,11 @@ class MLP(Kern):
|
|||
dvar = (dL_dKdiag*K).sum()/var
|
||||
X_prod = self._comp_prod(X)
|
||||
common = var*four_over_tau/(np.sqrt(1-np.square(X_prod/(X_prod+1)))*np.square(X_prod+1))*dL_dKdiag
|
||||
dw = (common*(X_prod-b)).sum()/w
|
||||
if self.ARD:
|
||||
XX = np.square(X)
|
||||
dw = np.dot(common,XX)
|
||||
else:
|
||||
dw = (common*(X_prod-b)).sum()/w
|
||||
db = common.sum()
|
||||
dX = common[:,None]*X*w*2
|
||||
return dvar, dw, db, dX
|
||||
|
|
|
|||
|
|
@ -7,11 +7,15 @@ An approximated psi-statistics implementation based on Gauss-Hermite Quadrature
|
|||
|
||||
import numpy as np
|
||||
|
||||
from ....core.parameterization import Param
|
||||
from GPy.util.caching import Cache_this
|
||||
from ....util.linalg import tdot
|
||||
from . import PSICOMP
|
||||
|
||||
class PSICOMP_GH(PSICOMP):
|
||||
"""
|
||||
TODO: support Psi2 with shape NxMxM
|
||||
"""
|
||||
|
||||
def __init__(self, degree=5, cache_K=True):
|
||||
self.degree = degree
|
||||
|
|
@ -64,7 +68,10 @@ class PSICOMP_GH(PSICOMP):
|
|||
|
||||
dtheta_old = kern.gradient.copy()
|
||||
dtheta = np.zeros_like(kern.gradient)
|
||||
dZ = np.zeros_like(Z.values)
|
||||
if isinstance(Z, Param):
|
||||
dZ = np.zeros_like(Z.values)
|
||||
else:
|
||||
dZ = np.zeros_like(Z)
|
||||
dmu = np.zeros_like(mu)
|
||||
dS = np.zeros_like(S)
|
||||
for i in xrange(self.degree):
|
||||
|
|
|
|||
|
|
@ -252,6 +252,11 @@ class KernelGradientTestsContinuous(unittest.TestCase):
|
|||
continuous_kerns = ['RBF', 'Linear']
|
||||
self.kernclasses = [getattr(GPy.kern, s) for s in continuous_kerns]
|
||||
|
||||
def test_MLP(self):
|
||||
k = GPy.kern.MLP(self.D,ARD=True)
|
||||
k.randomize()
|
||||
self.assertTrue(check_kernel_gradient_functions(k, X=self.X, X2=self.X2, verbose=verbose))
|
||||
|
||||
def test_Matern32(self):
|
||||
k = GPy.kern.Matern32(self.D)
|
||||
k.randomize()
|
||||
|
|
@ -464,7 +469,7 @@ class Kernel_Psi_statistics_GradientTests(unittest.TestCase):
|
|||
self.w3n = self.w3n+np.swapaxes(self.w3n, 1,2)
|
||||
|
||||
def test_kernels(self):
|
||||
from GPy.kern import RBF,Linear
|
||||
from GPy.kern import RBF,Linear,MLP
|
||||
Q = self.Z.shape[1]
|
||||
kernels = [RBF(Q,ARD=True), Linear(Q,ARD=True)]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue