2012-11-29 16:39:20 +00:00
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
2012-11-29 16:31:48 +00:00
from kernpart import kernpart
import numpy as np
import hashlib
2013-04-10 15:50:31 +01:00
from scipy import weave
2013-04-26 19:32:33 +01:00
from . . util . linalg import tdot
2012-11-29 16:31:48 +00:00
2012-11-30 10:31:02 +00:00
class rbf ( kernpart ) :
2012-11-29 16:31:48 +00:00
"""
2013-01-18 17:43:32 +00:00
Radial Basis Function kernel , aka squared - exponential , exponentiated quadratic or Gaussian kernel :
2012-12-05 19:19:15 -08:00
. . math : :
2013-03-11 18:45:04 +00:00
k ( r ) = \sigma ^ 2 \exp \\bigg ( - \\frac { 1 } { 2 } r ^ 2 \\bigg ) \ \ \ \ \ \\text { where } r ^ 2 = \sum_ { i = 1 } ^ d \\frac { ( x_i - x ^ \prime_i ) ^ 2 } { \ell_i ^ 2 }
2012-12-05 19:19:15 -08:00
2013-01-18 17:43:32 +00:00
where \ell_i is the lengthscale , \sigma ^ 2 the variance and d the dimensionality of the input .
2012-11-29 16:31:48 +00:00
: param D : the number of input dimensions
: type D : int
: param variance : the variance of the kernel
: type variance : float
2013-01-18 13:58:41 +00:00
: param lengthscale : the vector of lengthscale of the kernel
2013-01-31 17:19:15 +00:00
: type lengthscale : array or list of the appropriate size ( or float if there is only one lengthscale parameter )
2013-01-18 13:58:41 +00:00
: param ARD : Auto Relevance Determination . If equal to " False " , the kernel is isotropic ( ie . one single lengthscale parameter \ell ) , otherwise there is one lengthscale parameter per dimension .
: type ARD : Boolean
2013-01-18 16:14:13 +00:00
: rtype : kernel object
2012-11-29 16:31:48 +00:00
2013-01-31 10:57:43 +00:00
. . Note : this object implements both the ARD and ' spherical ' version of the function
2012-11-29 16:31:48 +00:00
"""
2013-01-18 13:58:41 +00:00
def __init__(self, D, variance=1., lengthscale=None, ARD=False):
    """
    Set up the kernel parameters, the computation caches and the weave options.

    :param D: the number of input dimensions
    :param variance: the variance of the kernel
    :param lengthscale: lengthscale(s); defaults to ones when None
    :param ARD: if True one lengthscale per input dimension is used,
                otherwise a single shared lengthscale
    """
    self.D = D
    self.name = 'rbf'
    self.ARD = ARD

    # number of lengthscale parameters and the message used when a
    # user-supplied lengthscale has the wrong size
    if ARD:
        n_lengthscales = self.D
        size_message = "bad number of lengthscales"
    else:
        n_lengthscales = 1
        size_message = "Only one lengthscale needed for non-ARD kernel"
    self.Nparam = n_lengthscales + 1

    if lengthscale is None:
        lengthscale = np.ones(n_lengthscales)
    else:
        lengthscale = np.asarray(lengthscale)
        assert lengthscale.size == n_lengthscales, size_message

    self._set_params(np.hstack((variance, lengthscale.flatten())))

    # initialize cache (placeholders that compare unequal to real inputs)
    self._Z, self._mu, self._S = np.empty(shape=(3, 1))
    self._X, self._X2, self._params = np.empty(shape=(3, 1))

    # a set of optional args to pass to weave
    self.weave_options = {
        'headers': ['<omp.h>'],
        'extra_compile_args': ['-fopenmp -O3'],  # -march=native'],
        'extra_link_args': ['-lgomp'],
    }
2013-01-18 12:31:37 +00:00
def _get_params(self):
    """Return the parameters as one flat array: variance first, then the lengthscale(s)."""
    pieces = (self.variance, self.lengthscale)
    return np.hstack(pieces)
2013-01-18 12:31:37 +00:00
def _set_params(self, x):
    """
    Load the parameters from a flat array and invalidate every cache.

    :param x: array of size Nparam; x[0] is the variance, x[1:] the lengthscale(s)
    """
    assert x.size == (self.Nparam)
    self.variance, self.lengthscale = x[0], x[1:]
    self.lengthscale2 = np.square(self.lengthscale)
    # reset cached results so the next computation starts from scratch
    self._X, self._X2, self._params = np.empty(shape=(3, 1))
    self._Z, self._mu, self._S = np.empty(shape=(3, 1))  # cached versions of Z,mu,S
2013-01-18 12:31:37 +00:00
def _get_param_names(self):
    """Return the parameter names, in the same order as the values from _get_params."""
    if self.Nparam != 2:
        n_lengthscales = self.lengthscale.size
        return ['variance'] + ['lengthscale_%i' % i for i in range(n_lengthscales)]
    return ['variance', 'lengthscale']
2012-11-29 16:31:48 +00:00
def K(self, X, X2, target):
    """
    Add the covariance matrix between X and X2 to target, in place.

    The unit-variance kernel matrix is taken from the _K_computations cache
    and scaled by the variance.
    """
    self._K_computations(X, X2)
    scaled = self.variance * self._K_dvar
    target += scaled
2012-11-29 16:31:48 +00:00
def Kdiag(self, X, target):
    """Add the diagonal of the covariance matrix (the variance, for every point) to target in place."""
    target += self.variance
2013-03-11 12:15:59 +00:00
def dK_dtheta(self, dL_dK, X, X2, target):
    """
    Accumulate into target the gradient of the objective wrt the kernel parameters.

    :param dL_dK: derivative of the objective wrt each entry of K(X, X2)
    :param X: first set of inputs
    :param X2: second set of inputs, or None for the symmetric case K(X, X)
    :param target: array updated in place; target[0] receives the variance
                   gradient, target[1:] the lengthscale gradient(s)
    """
    self._K_computations(X, X2)
    target[0] += np.sum(self._K_dvar * dL_dK)
    if self.ARD:
        dvardLdK = self._K_dvar * dL_dK
        var_len3 = self.variance / np.power(self.lengthscale, 3)
        if X2 is None:
            # save computation for the symmetrical case: fold the upper
            # triangle onto the lower one and only loop over j < i.
            # The sum is done out of place: an in-place add of an array and
            # its own transposed view reads elements that have already been
            # overwritten on NumPy versions without overlap detection.
            dvardLdK = dvardLdK + dvardLdK.T
            code = """
            int q,i,j;
            double tmp;
            for(q=0; q<D; q++){
              tmp = 0;
              for(i=0; i<N; i++){
                for(j=0; j<i; j++){
                  tmp += (X(i,q)-X(j,q))*(X(i,q)-X(j,q))*dvardLdK(i,j);
                }
              }
              target(q+1) += var_len3(q)*tmp;
            }
            """
            N, M, D = X.shape[0], X.shape[0], self.D
        else:
            code = """
            int q,i,j;
            double tmp;
            for(q=0; q<D; q++){
              tmp = 0;
              for(i=0; i<N; i++){
                for(j=0; j<M; j++){
                  tmp += (X(i,q)-X2(j,q))*(X(i,q)-X2(j,q))*dvardLdK(i,j);
                }
              }
              target(q+1) += var_len3(q)*tmp;
            }
            """
            N, M, D = X.shape[0], X2.shape[0], self.D
        # numpy equivalent of the compiled loop (non-symmetric case):
        # [np.add(target[1+q:2+q],var_len3[q]*np.sum(dvardLdK*np.square(X[:,q][:,None]-X2[:,q][None,:])),target[1+q:2+q]) for q in range(self.D)]
        weave.inline(code, arg_names=['N', 'M', 'D', 'X', 'X2', 'target', 'dvardLdK', 'var_len3'],
                     type_converters=weave.converters.blitz, **self.weave_options)
    else:
        # _K_dist2 is already divided by lengthscale**2, hence a single 1/lengthscale here
        target[1] += (self.variance / self.lengthscale) * np.sum(self._K_dvar * self._K_dist2 * dL_dK)
2012-11-29 16:31:48 +00:00
2013-03-11 12:15:59 +00:00
def dKdiag_dtheta(self, dL_dKdiag, X, target):
    """
    Accumulate the gradient wrt the parameters through the diagonal of K.

    Only target[0] (the variance entry) is updated.
    """
    # NB: derivative of diagonal elements wrt lengthscale is 0
    variance_grad = np.sum(dL_dKdiag)
    target[0] += variance_grad
2012-11-29 16:31:48 +00:00
2013-03-11 12:15:59 +00:00
def dK_dX(self, dL_dK, X, X2, target):
    """
    Accumulate into target the gradient wrt X through K(X, X2).

    X2 is indexed directly, so it must be an array here (None is not handled).
    target is updated in place and has one row per row of X.
    """
    self._K_computations(X, X2)
    # pairwise differences are not cached in _K_computations because they are
    # high memory. If this function is being called, chances are we're not in
    # the high memory arena.
    pairwise_diff = X[:, None, :] - X2[None, :, :]
    scale = -self.variance / self.lengthscale2
    weighted_diff = self._K_dvar[:, :, np.newaxis] * pairwise_diff
    grad = scale * np.transpose(weighted_diff, (1, 0, 2))
    target += np.sum(grad * dL_dK.T[:, :, None], 0)
2012-11-29 16:31:48 +00:00
2013-03-11 12:15:59 +00:00
def dKdiag_dX(self, dL_dKdiag, X, target):
    """Do nothing: target is left unchanged (the diagonal of K is the constant variance)."""
    return None
2013-01-30 16:27:45 +00:00
#---------------------------------------#
# PSI statistics #
#---------------------------------------#
2012-11-29 16:31:48 +00:00
2012-11-30 15:49:20 +00:00
def psi0(self, Z, mu, S, target):
    """Add the psi0 statistic (the variance) to target in place."""
    np.add(target, self.variance, out=target)
2013-03-11 12:15:59 +00:00
def dpsi0_dtheta(self, dL_dpsi0, Z, mu, S, target):
    """Accumulate the gradient through psi0; only target[0] (the variance entry) is updated."""
    variance_grad = np.sum(dL_dpsi0)
    target[0] += variance_grad
2012-11-30 15:49:20 +00:00
2013-03-11 12:15:59 +00:00
def dpsi0_dmuS(self, dL_dpsi0, Z, mu, S, target_mu, target_S):
    """Do nothing: target_mu and target_S are left unchanged (psi0 is the constant variance)."""
    return None
def psi1(self, Z, mu, S, target):
    """Add the psi1 statistic to target in place, refreshing the psi cache first."""
    self._psi_computations(Z, mu, S)
    np.add(target, self._psi1, out=target)
2013-03-11 12:15:59 +00:00
def dpsi1_dtheta(self, dL_dpsi1, Z, mu, S, target):
    """
    Accumulate into target the gradient wrt the kernel parameters through psi1.

    target[0] receives the variance gradient; the lengthscale gradient goes
    to target[1] (shared lengthscale) or target[1:] (ARD).
    """
    self._psi_computations(Z, mu, S)
    S_wide = S[:, None, :]
    # contribution of the log-denominator term of the psi1 exponent
    denom_deriv = S_wide / (self.lengthscale ** 3 + self.lengthscale * S_wide)
    # contribution of the squared-distance term of the psi1 exponent
    dist_deriv = self.lengthscale * np.square(self._psi1_dist / (self.lengthscale2 + S_wide))
    d_length = self._psi1[:, :, None] * (dist_deriv + denom_deriv)

    target[0] += np.sum(dL_dpsi1 * self._psi1 / self.variance)
    dpsi1_dlength = d_length * dL_dpsi1[:, :, None]
    if self.ARD:
        target[1:] += dpsi1_dlength.sum(0).sum(0)
    else:
        target[1] += dpsi1_dlength.sum()
2012-11-30 15:49:20 +00:00
2013-03-11 12:15:59 +00:00
def dpsi1_dZ(self, dL_dpsi1, Z, mu, S, target):
    """
    Accumulate into target the gradient wrt Z through psi1.

    NOTE(review): dL_dpsi1 is transposed here, whereas dpsi1_dtheta uses it
    untransposed -- confirm the orientation the caller passes.
    """
    self._psi_computations(Z, mu, S)
    scaled_denom = self.lengthscale2 * self._psi1_denom
    grad = -self._psi1[:, :, None] * (self._psi1_dist / scaled_denom)
    weighted = dL_dpsi1.T[:, :, None] * grad
    target += np.sum(weighted, 0)
2012-11-30 15:49:20 +00:00
2013-03-11 12:15:59 +00:00
def dpsi1_dmuS(self, dL_dpsi1, Z, mu, S, target_mu, target_S):
    """
    Accumulate into target_mu and target_S the gradients wrt mu and S through psi1.

    NOTE(review): dL_dpsi1 is transposed here, whereas dpsi1_dtheta uses it
    untransposed -- confirm the orientation the caller passes.
    """
    self._psi_computations(Z, mu, S)
    dL_wide = dL_dpsi1.T[:, :, None]
    common = self._psi1[:, :, None] / self.lengthscale2 / self._psi1_denom
    target_mu += np.sum(dL_wide * common * self._psi1_dist, 1)
    target_S += np.sum(dL_wide * 0.5 * common * (self._psi1_dist_sq - 1), 1)
2012-11-30 15:49:20 +00:00
def psi2(self, Z, mu, S, target):
    """Add the psi2 statistic to target in place, refreshing the psi cache first."""
    self._psi_computations(Z, mu, S)
    np.add(target, self._psi2, out=target)
2012-11-30 15:49:20 +00:00
2013-03-11 12:15:59 +00:00
def dpsi2_dtheta(self, dL_dpsi2, Z, mu, S, target):
    """
    Accumulate into target the gradient wrt the kernel parameters through psi2.

    Intermediate arrays have shape N,M,M,Ntheta. target[0] receives the
    variance gradient; the lengthscale gradient goes to target[1] (shared
    lengthscale) or target[1:] (ARD).
    """
    self._psi_computations(Z, mu, S)
    d_var = 2. * self._psi2 / self.variance
    exponent_terms = (self._psi2_Zdist_sq * self._psi2_denom
                      + self._psi2_mudist_sq
                      + S[:, None, None, :] / self.lengthscale2)
    d_length = 2. * self._psi2[:, :, :, None] * exponent_terms / (self.lengthscale * self._psi2_denom)

    target[0] += np.sum(dL_dpsi2 * d_var)
    dpsi2_dlength = d_length * dL_dpsi2[:, :, :, None]
    if self.ARD:
        target[1:] += dpsi2_dlength.sum(0).sum(0).sum(0)
    else:
        target[1] += dpsi2_dlength.sum()
2013-03-11 12:15:59 +00:00
def dpsi2_dZ(self, dL_dpsi2, Z, mu, S, target):
    """Accumulate into target the gradient wrt Z through psi2."""
    self._psi_computations(Z, mu, S)
    z_term = self._psi2_Zdist / self.lengthscale2  # M, M, Q
    mu_term = self._psi2_mudist / self._psi2_denom / self.lengthscale2  # N, M, M, Q
    grad = self._psi2[:, :, :, None] * (z_term[None] + mu_term)
    weighted = dL_dpsi2[:, :, :, None] * grad
    target += weighted.sum(0).sum(0)
2012-11-30 15:49:20 +00:00
2013-03-11 12:15:59 +00:00
def dpsi2_dmuS(self, dL_dpsi2, Z, mu, S, target_mu, target_S):
    """
    Accumulate into target_mu and target_S the gradients wrt mu and S through psi2.

    Think N,M,M,Q for the intermediate arrays; the two M axes are summed out.
    """
    self._psi_computations(Z, mu, S)
    dL_wide = dL_dpsi2[:, :, :, None]
    common = self._psi2[:, :, :, None] / self.lengthscale2 / self._psi2_denom
    target_mu += -2. * (dL_wide * common * self._psi2_mudist).sum(1).sum(1)
    target_S += (dL_wide * common * (2. * self._psi2_mudist_sq - 1)).sum(1).sum(1)
2013-01-30 16:27:45 +00:00
#---------------------------------------#
# Precomputations #
#---------------------------------------#
def _K_computations(self, X, X2):
    """
    Compute and cache the quantities shared by K and its derivatives.

    Sets self._K_dist2 (squared distances between lengthscale-scaled inputs)
    and self._K_dvar (exp(-0.5 * _K_dist2), the unit-variance kernel matrix).
    Nothing is recomputed when X, X2 and the parameters all match the
    cached copies.

    :param X: first set of inputs
    :param X2: second set of inputs, or None for the symmetric case
    """
    cache_is_valid = (np.array_equal(X, self._X)
                      and np.array_equal(X2, self._X2)
                      and np.array_equal(self._params, self._get_params()))
    if cache_is_valid:
        return

    self._X = X.copy()
    # remember the parameters the cache was built with (this must be an
    # assignment: a comparison here leaves the cache permanently invalid)
    self._params = self._get_params().copy()
    if X2 is None:
        self._X2 = None
        X = X / self.lengthscale
        Xsquare = np.sum(np.square(X), 1)
        self._K_dist2 = -2. * tdot(X) + (Xsquare[:, None] + Xsquare[None, :])
    else:
        self._X2 = X2.copy()
        X = X / self.lengthscale
        X2 = X2 / self.lengthscale
        self._K_dist2 = -2. * np.dot(X, X2.T) + (np.sum(np.square(X), 1)[:, None] + np.sum(np.square(X2), 1)[None, :])
    self._K_dvar = np.exp(-0.5 * self._K_dist2)
2012-11-30 15:49:20 +00:00
def _psi_computations(self, Z, mu, S):
    """
    Compute and cache the "statistics" used by psi1, psi2 and their derivatives.

    Quantities that depend only on Z are refreshed when Z differs from the
    cached copy; everything else is refreshed when any of Z, mu or S differs.
    Copies of the inputs are cached so that later in-place modification of
    the caller's arrays is detected as a change.

    :param Z: array indexed as Z[m, q]
    :param mu: array indexed as mu[n, q]
    :param S: array indexed as S[n, q]
    """
    # decide this once, up front: the psi1/psi2 block below must also run
    # when Z is the only input that changed
    Z_changed = not np.array_equal(Z, self._Z)
    if Z_changed:
        # Z has changed, compute Z specific stuff
        self._psi2_Zhat = 0.5 * (Z[:, None, :] + Z[None, :, :])  # M,M,Q
        self._psi2_Zdist = 0.5 * (Z[:, None, :] - Z[None, :, :])  # M,M,Q
        self._psi2_Zdist_sq = np.square(self._psi2_Zdist / self.lengthscale)  # M,M,Q

    if Z_changed or not (np.array_equal(mu, self._mu) and np.array_equal(S, self._S)):
        # something's changed. recompute EVERYTHING
        # psi1
        self._psi1_denom = S[:, None, :] / self.lengthscale2 + 1.
        self._psi1_dist = Z[None, :, :] - mu[:, None, :]
        self._psi1_dist_sq = np.square(self._psi1_dist) / self.lengthscale2 / self._psi1_denom
        self._psi1_exponent = -0.5 * np.sum(self._psi1_dist_sq + np.log(self._psi1_denom), -1)
        self._psi1 = self.variance * np.exp(self._psi1_exponent)

        # psi2
        self._psi2_denom = 2. * S[:, None, None, :] / self.lengthscale2 + 1.  # N,M,M,Q
        self._psi2_mudist, self._psi2_mudist_sq, self._psi2_exponent, _ = self.weave_psi2(mu, self._psi2_Zhat)
        # numpy equivalent of the weave call above:
        # self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q
        # self._psi2_mudist_sq = np.square(self._psi2_mudist)/(self.lengthscale2*self._psi2_denom)
        # self._psi2_exponent = np.sum(-self._psi2_Zdist_sq -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M
        self._psi2 = np.square(self.variance) * np.exp(self._psi2_exponent)  # N,M,M

        # store copies of the inputs for caching
        self._Z, self._mu, self._S = Z.copy(), mu.copy(), S.copy()
2013-04-10 15:50:31 +01:00
2013-04-10 16:12:09 +01:00
def weave_psi2(self, mu, Zhat):
    """
    Compute the mu-dependent psi2 quantities with a compiled (weave/OpenMP) loop.

    Reads self._psi2_Zdist_sq and self._psi2_denom, so these must have been
    set before calling.

    :param mu: array indexed as mu[n, q]
    :param Zhat: array indexed as Zhat[m, mm, q]
    :returns: tuple (mudist, mudist_sq, psi2_exponent, psi2); the first two
              have shape (N, M, M, Q), the last two (N, M, M). psi2 is
              allocated but not filled by the compiled code.
    """
    N, Q = mu.shape
    M = Zhat.shape[0]

    # inputs for the compiled loop; the local names below must match arg_names
    psi2_Zdist_sq = self._psi2_Zdist_sq
    _psi2_denom = self._psi2_denom.squeeze().reshape(N, self.D)
    half_log_psi2_denom = 0.5 * np.log(self._psi2_denom).squeeze().reshape(N, self.D)
    variance_sq = float(np.square(self.variance))
    if self.ARD:
        lengthscale2 = self.lengthscale2
    else:
        # expand the shared lengthscale so the C code can index it per dimension
        lengthscale2 = np.ones(Q) * self.lengthscale2

    # outputs; psi2_exponent is accumulated into, so it must start at zero
    mudist = np.empty((N, M, M, Q))
    mudist_sq = np.empty((N, M, M, Q))
    psi2_exponent = np.zeros((N, M, M))
    psi2 = np.empty((N, M, M))

    # only mm <= m is visited; each value is written to both (m, mm) and (mm, m)
    code = """
    double tmp;

    #pragma omp parallel for private(tmp)
    for (int n=0; n<N; n++){
        for (int m=0; m<M; m++){
            for (int mm=0; mm<(m+1); mm++){
                for (int q=0; q<Q; q++){
                    //compute mudist
                    tmp = mu(n,q) - Zhat(m,mm,q);
                    mudist(n,m,mm,q) = tmp;
                    mudist(n,mm,m,q) = tmp;

                    //now mudist_sq
                    tmp = tmp*tmp/lengthscale2(q)/_psi2_denom(n,q);
                    mudist_sq(n,m,mm,q) = tmp;
                    mudist_sq(n,mm,m,q) = tmp;

                    //now psi2_exponent
                    tmp = -psi2_Zdist_sq(m,mm,q) - tmp - half_log_psi2_denom(n,q);
                    psi2_exponent(n,mm,m) += tmp;
                    if (m !=mm){
                        psi2_exponent(n,m,mm) += tmp;
                    }
                    //psi2 would be computed like this, but np is faster
                    //tmp = variance_sq*exp(psi2_exponent(n,m,mm));
                    //psi2(n,m,mm) = tmp;
                    //psi2(n,mm,m) = tmp;
                }
            }
        }
    }

    """

    support_code = """
    #include <omp.h>
    #include <math.h>
    """
    weave.inline(code, support_code=support_code, libraries=['gomp'],
                 arg_names=['N', 'M', 'Q', 'mu', 'Zhat', 'mudist_sq', 'mudist', 'lengthscale2', '_psi2_denom', 'psi2_Zdist_sq', 'psi2_exponent', 'half_log_psi2_denom', 'psi2', 'variance_sq'],
                 type_converters=weave.converters.blitz, **self.weave_options)

    return mudist, mudist_sq, psi2_exponent, psi2