diff --git a/sqlite-vec.c b/sqlite-vec.c index d12e25d..5379f29 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -258,13 +258,16 @@ static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v, pVect1 += 8; pVect2 += 8; - // widen to protect against overflow + // widen i8 to i16 for subtraction int16x8_t v1_wide = vmovl_s8(v1); int16x8_t v2_wide = vmovl_s8(v2); - int16x8_t diff = vsubq_s16(v1_wide, v2_wide); - int16x8_t squared_diff = vmulq_s16(diff, diff); - int32x4_t sum = vpaddlq_s16(squared_diff); + + // widening multiply: i16*i16 -> i32 to avoid i16 overflow + // (diff can be up to 255, so diff*diff can be up to 65025 > INT16_MAX) + int32x4_t sq_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff)); + int32x4_t sq_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff)); + int32x4_t sum = vaddq_s32(sq_lo, sq_hi); sum_scalar += vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) + vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3); diff --git a/tests/test-loadable.py b/tests/test-loadable.py index 40c6a5e..1ac0cf3 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -381,11 +381,17 @@ def test_vec_distance_l2(): x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform) y = npy_l2(np.array(a), np.array(b)) - assert isclose(x, y, abs_tol=1e-6) + assert isclose(x, y, rel_tol=1e-5, abs_tol=1e-6) check([1.2, 0.1], [0.4, -0.4]) check([-1.2, -0.1], [-0.4, 0.4]) check([1, 2, 3], [-9, -8, -7], dtype=np.int8) + # Extreme int8 values: diff=255, squared=65025 which overflows i16 + # This tests the NEON widening multiply fix (slight float rounding expected) + check([-128] * 8, [127] * 8, dtype=np.int8) + check([-128] * 16, [127] * 16, dtype=np.int8) + check([-128, 127, -128, 127, -128, 127, -128, 127], + [127, -128, 127, -128, 127, -128, 127, -128], dtype=np.int8) def test_vec_length():