fast_array_equals now handles 3d matrices

This commit is contained in:
Nicolo Fusi 2013-07-17 17:44:04 +01:00
parent fa523c3fce
commit a4170abceb

View file

@ -59,7 +59,7 @@ def kmm_init(X, m = 10):
return X[inducing]
def fast_array_equal(A, B):
code="""
code2="""
int i, j;
return_val = 1;
@ -74,6 +74,23 @@ def fast_array_equal(A, B):
}
"""
code3="""
int i, j, z;
return_val = 1;
#pragma omp parallel for private(i, j, z)
for(i=0;i<N;i++){
for(j=0;j<D;j++){
for(z=0;z<Q;z++){
if(A(i, j, z) != B(i, j, z)){
return_val = 0;
break;
}
}
}
}
"""
support_code = """
#include <omp.h>
#include <math.h>
@ -93,9 +110,14 @@ def fast_array_equal(A, B):
elif A.shape == B.shape:
if len(A.shape) == 2:
N, D = A.shape
value = weave.inline(code, support_code=support_code, libraries=['gomp'],
value = weave.inline(code2, support_code=support_code, libraries=['gomp'],
arg_names=['A', 'B', 'N', 'D'],
type_converters=weave.converters.blitz,**weave_options)
elif len(A.shape) == 3:
N, D, Q = A.shape
value = weave.inline(code3, support_code=support_code, libraries=['gomp'],
arg_names=['A', 'B', 'N', 'D', 'Q'],
type_converters=weave.converters.blitz,**weave_options)
else:
value = np.array_equal(A,B)