(much) faster comparison between arrays. Useful for kernel caching

This commit is contained in:
Nicolo Fusi 2013-07-16 17:08:40 +01:00
parent f19a26a006
commit d46b07b18d
3 changed files with 50 additions and 7 deletions

View file

@ -1,8 +1,8 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from scipy import weave
def opt_wrapper(m, **kwargs):
"""
@ -58,6 +58,47 @@ def kmm_init(X, m = 10):
inducing = np.array(inducing)
return X[inducing]
def fast_array_equal(A, B):
code="""
int i, j;
return_val = 1;
#pragma omp parallel for private(i, j)
for(i=0;i<N;i++){
for(j=0;j<D;j++){
if(A(i, j) != B(i, j)){
return_val = 0;
break;
}
}
}
"""
support_code = """
#include <omp.h>
#include <math.h>
"""
weave_options = {'headers' : ['<omp.h>'],
'extra_compile_args': ['-fopenmp -O3'],
'extra_link_args' : ['-lgomp']}
value = False
if A is not None and B is not None and A.shape == B.shape:
if len(A.shape) == 2:
N, D = A.shape
value = weave.inline(code, support_code=support_code, libraries=['gomp'],
arg_names=['A', 'B', 'N', 'D'],
type_converters=weave.converters.blitz,**weave_options)
else:
value = np.array_equal(A,B)
return value
if __name__ == '__main__':
import pylab as plt
X = np.linspace(1,10, 100)[:, None]