new implementation for mlp kernel (speed improved)

Zhenwen Dai 2015-09-07 00:01:04 +01:00
parent 368f975aa9
commit e906da0309


@@ -5,6 +5,7 @@ from .kern import Kern
 from ...core.parameterization import Param
 from ...core.parameterization.transformations import Logexp
 import numpy as np
+from ...util.caching import Cache_this
 
 four_over_tau = 2./np.pi
 
 class MLP(Kern):
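The new import brings in GPy's Cache_this decorator, which memoizes recent return values of a method keyed on its arguments; the rewrite below leans on it so that quantities such as _comp_prod(X), computed while evaluating K, are reused by the gradient routines instead of being recomputed. A minimal sketch of the idea, assuming the convention (suggested by the decorations below) that argument positions count self as 0, so ignore_args=(1,) skips dL_dK; GPy's real decorator is also integrated with its Observable framework so stale entries are invalidated when inputs or parameters change, which this toy version does not attempt:

import functools

def cache_this(limit=5, ignore_args=()):
    # Toy stand-in for GPy's Cache_this, for illustration only.
    def wrap(method):
        cache = {}  # shared per method; GPy caches per instance
        @functools.wraps(method)
        def wrapper(self, *args):
            # Key on object identity of the non-ignored arguments.
            key = tuple(id(a) for i, a in enumerate(args)
                        if i + 1 not in ignore_args)
            if key not in cache:
                if len(cache) >= limit:
                    cache.pop(next(iter(cache)))  # evict oldest entry
                cache[key] = method(self, *args)
            return cache[key]
        return wrapper
    return wrap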
@@ -31,7 +32,7 @@ class MLP(Kern):
     """
-    def __init__(self, input_dim, variance=1., weight_variance=1., bias_variance=100., active_dims=None, name='mlp'):
+    def __init__(self, input_dim, variance=1., weight_variance=1., bias_variance=1., active_dims=None, name='mlp'):
         super(MLP, self).__init__(input_dim, active_dims, name)
         self.variance = Param('variance', variance, Logexp())
         self.weight_variance = Param('weight_variance', weight_variance, Logexp())
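The hunk below is the core of the change: the stateful _K_computations/_K_diag_computations pre-computation pair is replaced by a cached product helper (_comp_prod) plus two routines that return all gradients in one pass. Both old and new code evaluate the same MLP (arcsine) covariance, which in the notation of the code is variance * (2/pi) * arcsin((w x^T x' + b) / sqrt((w x^T x + b + 1)(w x'^T x' + b + 1))), with w the weight variance and b the bias variance. For reference, a standalone NumPy transcription, assuming a scalar weight variance (the Param may also be per-dimension):

import numpy as np

def mlp_k(X, X2=None, var=1., w=1., b=1.):
    # Reference arcsine/MLP kernel; mirrors _comp_prod + K in the hunk below.
    # var, w, b stand in for variance, weight_variance, bias_variance.
    if X2 is None:
        X2 = X
    numer = w*X.dot(X2.T) + b                  # _comp_prod(X, X2)
    X_prod = w*np.square(X).sum(axis=1) + b    # _comp_prod(X)
    X2_prod = w*np.square(X2).sum(axis=1) + b  # _comp_prod(X2)
    denom = np.sqrt(np.outer(X_prod + 1., X2_prod + 1.))
    return var*(2./np.pi)*np.arcsin(numer/denom)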
@@ -40,96 +41,77 @@ class MLP(Kern):
     def K(self, X, X2=None):
-        self._K_computations(X, X2)
-        return self.variance*self._K_dvar
+        if X2 is None:
+            X_denom = np.sqrt(self._comp_prod(X)+1.)
+            X2_denom = X_denom
+            X2 = X
+        else:
+            X_denom = np.sqrt(self._comp_prod(X)+1.)
+            X2_denom = np.sqrt(self._comp_prod(X2)+1.)
+        XTX = self._comp_prod(X,X2)/X_denom[:,None]/X2_denom[None,:]
+        return self.variance*four_over_tau*np.arcsin(XTX)
 
     def Kdiag(self, X):
         """Compute the diagonal of the covariance matrix for X."""
-        self._K_diag_computations(X)
-        return self.variance*self._K_diag_dvar
+        X_prod = self._comp_prod(X)
+        return self.variance*four_over_tau*np.arcsin(X_prod/(X_prod+1.))
 
     def update_gradients_full(self, dL_dK, X, X2=None):
         """Derivative of the covariance with respect to the parameters."""
-        self._K_computations(X, X2)
-        self.variance.gradient = np.sum(self._K_dvar*dL_dK)
-
-        denom3 = self._K_denom**3
-        base = four_over_tau*self.variance/np.sqrt(1-self._K_asin_arg*self._K_asin_arg)
-        base_cov_grad = base*dL_dK
-        if X2 is None:
-            vec = np.diag(self._K_inner_prod)
-            self.weight_variance.gradient = ((self._K_inner_prod/self._K_denom
-                -.5*self._K_numer/denom3
-                *(np.outer((self.weight_variance*vec+self.bias_variance+1.), vec)
-                +np.outer(vec, (self.weight_variance*vec+self.bias_variance+1.))))*base_cov_grad).sum()
-            self.bias_variance.gradient = ((1./self._K_denom
-                -.5*self._K_numer/denom3
-                *((vec[None, :]+vec[:, None])*self.weight_variance
-                +2.*self.bias_variance + 2.))*base_cov_grad).sum()
-        else:
-            vec1 = (X*X).sum(1)
-            vec2 = (X2*X2).sum(1)
-            self.weight_variance.gradient = ((self._K_inner_prod/self._K_denom
-                -.5*self._K_numer/denom3
-                *(np.outer((self.weight_variance*vec1+self.bias_variance+1.), vec2)
-                +np.outer(vec1, self.weight_variance*vec2+self.bias_variance+1.)))*base_cov_grad).sum()
-            self.bias_variance.gradient = ((1./self._K_denom
-                -.5*self._K_numer/denom3
-                *((vec1[:, None]+vec2[None, :])*self.weight_variance
-                +2*self.bias_variance + 2.))*base_cov_grad).sum()
+        dvar, dw, db = self._comp_grads(dL_dK, X, X2)[:3]
+        self.variance.gradient = dvar
+        self.weight_variance.gradient = dw
+        self.bias_variance.gradient = db
 
     def update_gradients_diag(self, dL_dKdiag, X):
-        self._K_diag_computations(X)
-        self.variance.gradient = np.sum(self._K_diag_dvar*dL_dKdiag)
-
-        base = four_over_tau*self.variance/np.sqrt(1-self._K_diag_asin_arg*self._K_diag_asin_arg)
-        base_cov_grad = base*dL_dKdiag/np.square(self._K_diag_denom)
-        self.weight_variance.gradient = (base_cov_grad*np.square(X).sum(axis=1)).sum()
-        self.bias_variance.gradient = base_cov_grad.sum()
+        dvar, dw, db = self._comp_grads_diag(dL_dKdiag, X)[:3]
+        self.variance.gradient = dvar
+        self.weight_variance.gradient = dw
+        self.bias_variance.gradient = db
 
     def gradients_X(self, dL_dK, X, X2):
         """Derivative of the covariance matrix with respect to X"""
-        self._K_computations(X, X2)
-        arg = self._K_asin_arg
-        numer = self._K_numer
-        denom = self._K_denom
-        denom3 = denom*denom*denom
-        if X2 is not None:
-            vec2 = (X2*X2).sum(1)*self.weight_variance+self.bias_variance + 1.
-            return four_over_tau*self.weight_variance*self.variance*((X2[None, :, :]/denom[:, :, None] - vec2[None, :, None]*X[:, None, :]*(numer/denom3)[:, :, None])*(dL_dK/np.sqrt(1-arg*arg))[:, :, None]).sum(1)
-        else:
-            vec = (X*X).sum(1)*self.weight_variance+self.bias_variance + 1.
-            return 2*four_over_tau*self.weight_variance*self.variance*((X[None, :, :]/denom[:, :, None] - vec[None, :, None]*X[:, None, :]*(numer/denom3)[:, :, None])*(dL_dK/np.sqrt(1-arg*arg))[:, :, None]).sum(1)
+        return self._comp_grads(dL_dK, X, X2)[3]
 
     def gradients_X_diag(self, dL_dKdiag, X):
         """Gradient of diagonal of covariance with respect to X"""
-        self._K_diag_computations(X)
-        arg = self._K_diag_asin_arg
-        denom = self._K_diag_denom
-        #numer = self._K_diag_numer
-        return four_over_tau*2.*self.weight_variance*self.variance*X*(1./denom*(1. - arg)*dL_dKdiag/(np.sqrt(1-arg*arg)))[:, None]
+        return self._comp_grads_diag(dL_dKdiag, X)[3]
 
-    def _K_computations(self, X, X2):
-        """Pre-computations for the covariance matrix (used for computing the covariance and its gradients)."""
-        if X2 is None:
-            self._K_inner_prod = np.dot(X,X.T)
-            self._K_numer = self._K_inner_prod*self.weight_variance + self.bias_variance
-            vec = np.diag(self._K_numer) + 1.
-            self._K_denom = np.sqrt(np.outer(vec,vec))
-        else:
-            self._K_inner_prod = np.dot(X,X2.T)
-            self._K_numer = self._K_inner_prod*self.weight_variance + self.bias_variance
-            vec1 = (X*X).sum(1)*self.weight_variance + self.bias_variance + 1.
-            vec2 = (X2*X2).sum(1)*self.weight_variance + self.bias_variance + 1.
-            self._K_denom = np.sqrt(np.outer(vec1,vec2))
-        self._K_asin_arg = self._K_numer/self._K_denom
-        self._K_dvar = four_over_tau*np.arcsin(self._K_asin_arg)
-
-    def _K_diag_computations(self, X):
-        """Pre-computations concerning the diagonal terms (used for computation of diagonal and its gradients)."""
-        self._K_diag_numer = (X*X).sum(1)*self.weight_variance + self.bias_variance
-        self._K_diag_denom = self._K_diag_numer+1.
-        self._K_diag_asin_arg = self._K_diag_numer/self._K_diag_denom
-        self._K_diag_dvar = four_over_tau*np.arcsin(self._K_diag_asin_arg)
+    @Cache_this(limit=50, ignore_args=())
+    def _comp_prod(self, X, X2=None):
+        if X2 is None:
+            return (np.square(X)*self.weight_variance).sum(axis=1)+self.bias_variance
+        else:
+            return (X*self.weight_variance).dot(X2.T)+self.bias_variance
+
+    @Cache_this(limit=20, ignore_args=(1,))
+    def _comp_grads(self, dL_dK, X, X2=None):
+        var, w, b = self.variance, self.weight_variance, self.bias_variance
+        K = self.K(X, X2)
+        dvar = (dL_dK*K).sum()/var
+        X_prod = self._comp_prod(X)
+        X2_prod = self._comp_prod(X2) if X2 is not None else X_prod
+        XTX = self._comp_prod(X, X2) if X2 is not None else self._comp_prod(X, X)
+        common = var*four_over_tau/np.sqrt((X_prod[:,None]+1.)*(X2_prod[None,:]+1.)-np.square(XTX))*dL_dK
+        dw = (common*((XTX-b)/w-XTX*(((X_prod-b)/(w*(X_prod+1.)))[:,None]+((X2_prod-b)/(w*(X2_prod+1.)))[None,:])/2.)).sum()
+        db = (common*(1.-XTX*(1./(X_prod[:,None]+1.)+1./(X2_prod[None,:]+1.))/2.)).sum()
+        if X2 is None:
+            common = common+common.T
+            dX = common.dot(X)*w-((common*XTX).sum(axis=1)/(X_prod+1.))[:,None]*X*w
+            dX2 = dX
+        else:
+            dX = common.dot(X2)*w-((common*XTX).sum(axis=1)/(X_prod+1.))[:,None]*X*w
+            dX2 = common.T.dot(X)*w-((common*XTX).sum(axis=0)/(X2_prod+1.))[:,None]*X2*w
+        return dvar, dw, db, dX, dX2
+
+    @Cache_this(limit=20, ignore_args=(1,))
+    def _comp_grads_diag(self, dL_dKdiag, X):
+        var, w, b = self.variance, self.weight_variance, self.bias_variance
+        K = self.Kdiag(X)
+        dvar = (dL_dKdiag*K).sum()/var
+        X_prod = self._comp_prod(X)
+        common = var*four_over_tau/(np.sqrt(1-np.square(X_prod/(X_prod+1)))*np.square(X_prod+1))*dL_dKdiag
+        dw = (common*(X_prod-b)).sum()/w
+        db = common.sum()
+        dX = common[:,None]*X*w*2
+        return dvar, dw, db, dX
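A quick sanity check of the rewrite needs nothing beyond the methods above: the diagonal of K must agree with Kdiag, and the variance gradient written by update_gradients_full must match a finite difference of the objective (dL_dK*K).sum(). The snippet assumes the class is reachable as GPy.kern.MLP, as in released GPy:

import numpy as np
import GPy

np.random.seed(0)
X = np.random.randn(10, 3)
dL_dK = np.random.randn(10, 10)
k = GPy.kern.MLP(3)

# Kdiag must equal the diagonal of the full covariance matrix.
assert np.allclose(np.diag(k.K(X)), k.Kdiag(X))

# Finite-difference check of the variance gradient set by update_gradients_full.
k.update_gradients_full(dL_dK, X)
analytic = k.variance.gradient.sum()
eps = 1e-6
obj = lambda v: (dL_dK*GPy.kern.MLP(3, variance=v).K(X)).sum()
assert np.allclose(analytic, (obj(1.+eps)-obj(1.-eps))/(2.*eps), rtol=1e-4)

Building a GPy.models.GPRegression model with this kernel and calling m.checkgrad() should exercise the remaining weight_variance, bias_variance and X gradients in one go.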