Fix int16 overflow in l2_sqr_int8_neon SIMD distance

vmulq_s16(diff, diff) produced int16 results, but diff can be up to
255 for int8 vectors (-128 vs 127), and 255^2 = 65025 overflows
int16 (max 32767). This caused NaN/wrong results for int8 vectors
with large differences.

Fix: use vmull_s16 (widening multiply) to produce int32 results
directly, avoiding the intermediate int16 overflow.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Alex Garcia 2026-03-31 14:55:37 -07:00
parent 4bee88384b
commit 7de925be70
2 changed files with 14 additions and 5 deletions

View file

@ -258,13 +258,16 @@ static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v,
pVect1 += 8;
pVect2 += 8;
// widen to protect against overflow
// widen i8 to i16 for subtraction
int16x8_t v1_wide = vmovl_s8(v1);
int16x8_t v2_wide = vmovl_s8(v2);
int16x8_t diff = vsubq_s16(v1_wide, v2_wide);
int16x8_t squared_diff = vmulq_s16(diff, diff);
int32x4_t sum = vpaddlq_s16(squared_diff);
// widening multiply: i16*i16 -> i32 to avoid i16 overflow
// (diff can be up to 255, so diff*diff can be up to 65025 > INT16_MAX)
int32x4_t sq_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
int32x4_t sq_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
int32x4_t sum = vaddq_s32(sq_lo, sq_hi);
sum_scalar += vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) +
vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3);

View file

@ -381,11 +381,17 @@ def test_vec_distance_l2():
x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform)
y = npy_l2(np.array(a), np.array(b))
assert isclose(x, y, abs_tol=1e-6)
assert isclose(x, y, rel_tol=1e-5, abs_tol=1e-6)
check([1.2, 0.1], [0.4, -0.4])
check([-1.2, -0.1], [-0.4, 0.4])
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
# Extreme int8 values: diff=255, squared=65025 which overflows i16
# This tests the NEON widening multiply fix (slight float rounding expected)
check([-128] * 8, [127] * 8, dtype=np.int8)
check([-128] * 16, [127] * 16, dtype=np.int8)
check([-128, 127, -128, 127, -128, 127, -128, 127],
[127, -128, 127, -128, 127, -128, 127, -128], dtype=np.int8)
def test_vec_length():