diff --git a/sqlite-vec.c b/sqlite-vec.c index b9b590f..e318907 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -186,8 +186,16 @@ static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v, sum3 = vfmaq_f32(sum3, diff, diff); } - return sqrt( - vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)))); + f32 sum_scalar = vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3))); + const f32 *pEnd2 = pVect1 + (qty - (qty16 << 4)); + while (pVect1 < pEnd2) { + f32 diff = *pVect1 - *pVect2; + sum_scalar += diff * diff; + pVect1++; + pVect2++; + } + + return sqrt(sum_scalar); } static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v, @@ -263,7 +271,7 @@ static f32 l2_sqr_int8(const void *pA, const void *pB, const void *pD) { static f32 distance_l2_sqr_float(const void *a, const void *b, const void *d) { #ifdef SQLITE_VEC_ENABLE_NEON - if (((*(const size_t *)d) % 16 == 0)) { + if ((*(const size_t *)d) > 16) { return l2_sqr_float_neon(a, b, d); } #endif