mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
mrd example added
This commit is contained in:
commit
89a50e260a
7 changed files with 193 additions and 92 deletions
|
|
@ -5,6 +5,7 @@
|
|||
from kernpart import kernpart
|
||||
import numpy as np
|
||||
import hashlib
|
||||
from scipy import weave
|
||||
|
||||
class rbf(kernpart):
|
||||
"""
|
||||
|
|
@ -172,7 +173,7 @@ class rbf(kernpart):
|
|||
"""Think N,M,M,Q """
|
||||
self._psi_computations(Z,mu,S)
|
||||
tmp = self._psi2[:,:,:,None]/self.lengthscale2/self._psi2_denom
|
||||
target_mu += (dL_dpsi2[:,:,:,None]*-tmp*2.*self._psi2_mudist).sum(1).sum(1)
|
||||
target_mu += -2.*(dL_dpsi2[:,:,:,None]*tmp*self._psi2_mudist).sum(1).sum(1)
|
||||
target_S += (dL_dpsi2[:,:,:,None]*tmp*(2.*self._psi2_mudist_sq-1)).sum(1).sum(1)
|
||||
|
||||
|
||||
|
|
@ -206,7 +207,6 @@ class rbf(kernpart):
|
|||
if not (np.all(Z==self._Z) and np.all(mu==self._mu) and np.all(S==self._S)):
|
||||
#something's changed. recompute EVERYTHING
|
||||
|
||||
#TODO: make more efficient for large Q (using NDL's dot product trick)
|
||||
#psi1
|
||||
self._psi1_denom = S[:,None,:]/self.lengthscale2 + 1.
|
||||
self._psi1_dist = Z[None,:,:]-mu[:,None,:]
|
||||
|
|
@ -216,9 +216,78 @@ class rbf(kernpart):
|
|||
|
||||
#psi2
|
||||
self._psi2_denom = 2.*S[:,None,None,:]/self.lengthscale2+1. # N,M,M,Q
|
||||
self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q
|
||||
self._psi2_mudist_sq = np.square(self._psi2_mudist)/(self.lengthscale2*self._psi2_denom)
|
||||
self._psi2_exponent = np.sum(-self._psi2_Zdist_sq/4. -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M
|
||||
self._psi2_mudist, self._psi2_mudist_sq, self._psi2_exponent, _ = self.weave_psi2(mu,self._psi2_Zhat)
|
||||
#self._psi2_mudist = mu[:,None,None,:]-self._psi2_Zhat #N,M,M,Q
|
||||
#self._psi2_mudist_sq = np.square(self._psi2_mudist)/(self.lengthscale2*self._psi2_denom)
|
||||
#self._psi2_exponent = np.sum(-self._psi2_Zdist_sq/4. -self._psi2_mudist_sq -0.5*np.log(self._psi2_denom),-1) #N,M,M
|
||||
self._psi2 = np.square(self.variance)*np.exp(self._psi2_exponent) # N,M,M
|
||||
|
||||
#store matrices for caching
|
||||
self._Z, self._mu, self._S = Z, mu,S
|
||||
|
||||
def weave_psi2(self,mu,Zhat):
|
||||
weave_options = {'headers' : ['<omp.h>'],
|
||||
'extra_compile_args': ['-fopenmp -march=native'],
|
||||
'extra_link_args' : ['-lgomp'],
|
||||
'compiler' : 'gcc'}
|
||||
|
||||
N,Q = mu.shape
|
||||
M = Zhat.shape[0]
|
||||
|
||||
mudist = np.empty((N,M,M,Q))
|
||||
mudist_sq = np.empty((N,M,M,Q))
|
||||
psi2_exponent = np.zeros((N,M,M))
|
||||
psi2 = np.empty((N,M,M))
|
||||
|
||||
psi2_Zdist_sq = self._psi2_Zdist_sq
|
||||
half_log_psi2_denom = 0.5*np.log(self._psi2_denom).squeeze()
|
||||
variance_sq = float(np.square(self.variance))
|
||||
if self.ARD:
|
||||
lengthscale2 = self.lengthscale2
|
||||
else:
|
||||
lengthscale2 = np.ones(Q)*self.lengthscale2
|
||||
_psi2_denom = self._psi2_denom.squeeze()
|
||||
code = """
|
||||
double tmp;
|
||||
|
||||
#pragma omp parallel for private(tmp)
|
||||
for (int n=0; n<N; n++){
|
||||
for (int m=0; m<M; m++){
|
||||
for (int mm=0; mm<(m+1); mm++){
|
||||
for (int q=0; q<Q; q++){
|
||||
//compute mudist
|
||||
tmp = mu(n,q) - Zhat(m,mm,q);
|
||||
mudist(n,m,mm,q) = tmp;
|
||||
mudist(n,mm,m,q) = tmp;
|
||||
|
||||
//now mudist_sq
|
||||
tmp = tmp*tmp/lengthscale2(q)/_psi2_denom(n,q);
|
||||
mudist_sq(n,m,mm,q) = tmp;
|
||||
mudist_sq(n,mm,m,q) = tmp;
|
||||
|
||||
//now psi2_exponent
|
||||
tmp = -psi2_Zdist_sq(m,mm,q)/4.0 - tmp - half_log_psi2_denom(n,q);
|
||||
psi2_exponent(n,mm,m) += tmp;
|
||||
if (m !=mm){
|
||||
psi2_exponent(n,m,mm) += tmp;
|
||||
}
|
||||
//psi2 would be computed like this, but np is faster
|
||||
//tmp = variance_sq*exp(psi2_exponent(n,m,mm));
|
||||
//psi2(n,m,mm) = tmp;
|
||||
//psi2(n,mm,m) = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
support_code = """
|
||||
#include <omp.h>
|
||||
#include <math.h>
|
||||
"""
|
||||
weave.inline(code, support_code=support_code, libraries=['gomp'],
|
||||
arg_names=['N','M','Q','mu','Zhat','mudist_sq','mudist','lengthscale2','_psi2_denom','psi2_Zdist_sq','psi2_exponent','half_log_psi2_denom','psi2','variance_sq'],
|
||||
type_converters=weave.converters.blitz,**weave_options)
|
||||
|
||||
return mudist,mudist_sq, psi2_exponent, psi2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue