OMP for psi2 computations in RBF

This commit is contained in:
Nicolo Fusi 2013-04-10 16:50:02 +01:00
parent 99ca20b77c
commit 7e0e4f63c4

View file

@ -227,7 +227,10 @@ class rbf(kernpart):
self._Z, self._mu, self._S = Z, mu,S
def weave_psi2(self,mu,Zhat):
weave_options = {'extra_compile_args': ['-O3']}
weave_options = {'headers' : ['<omp.h>'],
'extra_compile_args': ['-fopenmp -march=native'],
'extra_link_args' : ['-lgomp'],
'compiler' : 'gcc'}
N,Q = mu.shape
M = Zhat.shape[0]
@ -247,6 +250,8 @@ class rbf(kernpart):
_psi2_denom = self._psi2_denom.squeeze()
code = """
double tmp;
#pragma omp parallel for private(tmp)
for (int n=0; n<N; n++){
for (int m=0; m<M; m++){
for (int mm=0; mm<(m+1); mm++){
@ -275,7 +280,15 @@ class rbf(kernpart):
}
}
}
"""
weave.inline(code,support_code='#include "math.h"',arg_names=['N','M','Q','mu','Zhat','mudist_sq','mudist','lengthscale2','_psi2_denom','psi2_Zdist_sq','psi2_exponent','half_log_psi2_denom','psi2','variance_sq'],type_converters=weave.converters.blitz,**weave_options)
return mudist,mudist_sq, psi2_exponent, psi2
"""
support_code = """
#include <omp.h>
#include <math.h>
"""
weave.inline(code, support_code=support_code, libraries=['gomp'],
arg_names=['N','M','Q','mu','Zhat','mudist_sq','mudist','lengthscale2','_psi2_denom','psi2_Zdist_sq','psi2_exponent','half_log_psi2_denom','psi2','variance_sq'],
type_converters=weave.converters.blitz,**weave_options)
return mudist,mudist_sq, psi2_exponent, psi2