omp for dX

This commit is contained in:
James Hensman 2014-10-29 20:28:23 +00:00
parent f6f36234c6
commit 85ed44c9da

View file

@ -249,25 +249,34 @@ class Stationary(Kern):
X2 = X
code = """
int n,q,d;
int n,m,d;
double retnd;
for(n=0;n<N;n++){
for(d=0;d<D;d++){
retnd = 0;
for(q=0;q<Q;q++){
retnd += tmp(n,q)*(X(n,d)-X2(q,d));
#pragma omp parallel for private(n,d, retnd, m)
for(d=0;d<D;d++){
for(n=0;n<N;n++){
retnd = 0.0;
for(m=0;m<M;m++){
retnd += tmp(n,m)*(X(n,d)-X2(m,d));
}
ret(n,d) = retnd;
}
}
"""
if hasattr(X, 'values'):X = X.values #remove the GPy wrapping to make passing into weave safe
if hasattr(X2, 'values'):X2 = X2.values
ret = np.zeros(X.shape)
N,D = X.shape
Q = tmp.shape[1]
N,M = tmp.shape
from scipy import weave
weave.inline(code, ['ret', 'N', 'D', 'Q', 'tmp', 'X', 'X2'], type_converters=weave.converters.blitz)
support_code = """
#include <omp.h>
#include <stdio.h>
"""
weave_options = {'headers' : ['<omp.h>'],
'extra_compile_args': ['-fopenmp -O3'], # -march=native'],
'extra_link_args' : ['-lgomp']}
weave.inline(code, ['ret', 'N', 'D', 'M', 'tmp', 'X', 'X2'], type_converters=weave.converters.blitz, support_code=support_code, **weave_options)
return ret/self.lengthscale**2
def gradients_X_diag(self, dL_dKdiag, X):