omp for dX

This commit is contained in:
James Hensman 2014-10-29 20:28:23 +00:00
parent f6f36234c6
commit 85ed44c9da

View file

@ -249,25 +249,34 @@ class Stationary(Kern):
X2 = X X2 = X
code = """ code = """
int n,q,d; int n,m,d;
double retnd; double retnd;
for(n=0;n<N;n++){ #pragma omp parallel for private(n,d, retnd, m)
for(d=0;d<D;d++){ for(d=0;d<D;d++){
retnd = 0; for(n=0;n<N;n++){
for(q=0;q<Q;q++){ retnd = 0.0;
retnd += tmp(n,q)*(X(n,d)-X2(q,d)); for(m=0;m<M;m++){
retnd += tmp(n,m)*(X(n,d)-X2(m,d));
} }
ret(n,d) = retnd; ret(n,d) = retnd;
} }
} }
""" """
if hasattr(X, 'values'):X = X.values #remove the GPy wrapping to make passing into weave safe if hasattr(X, 'values'):X = X.values #remove the GPy wrapping to make passing into weave safe
if hasattr(X2, 'values'):X2 = X2.values if hasattr(X2, 'values'):X2 = X2.values
ret = np.zeros(X.shape) ret = np.zeros(X.shape)
N,D = X.shape N,D = X.shape
Q = tmp.shape[1] N,M = tmp.shape
from scipy import weave from scipy import weave
weave.inline(code, ['ret', 'N', 'D', 'Q', 'tmp', 'X', 'X2'], type_converters=weave.converters.blitz) support_code = """
#include <omp.h>
#include <stdio.h>
"""
weave_options = {'headers' : ['<omp.h>'],
'extra_compile_args': ['-fopenmp -O3'], # -march=native'],
'extra_link_args' : ['-lgomp']}
weave.inline(code, ['ret', 'N', 'D', 'M', 'tmp', 'X', 'X2'], type_converters=weave.converters.blitz, support_code=support_code, **weave_options)
return ret/self.lengthscale**2 return ret/self.lengthscale**2
def gradients_X_diag(self, dL_dKdiag, X): def gradients_X_diag(self, dL_dKdiag, X):