Fix int16 overflow in l2_sqr_int8_neon SIMD distance

vmulq_s16(diff, diff) produced int16 results, but diff can be up to
255 for int8 vectors (-128 vs 127), and 255^2 = 65025 overflows
int16 (max 32767). This caused NaN/wrong results for int8 vectors
with large differences.

Fix: use vmull_s16 (widening multiply) to produce int32 results
directly, avoiding the intermediate int16 overflow.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Alex Garcia 2026-03-31 14:55:37 -07:00
parent 4bee88384b
commit 7de925be70
2 changed files with 14 additions and 5 deletions

View file

@ -258,13 +258,16 @@ static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v,
pVect1 += 8;
pVect2 += 8;
// widen to protect against overflow
// widen i8 to i16 for subtraction
int16x8_t v1_wide = vmovl_s8(v1);
int16x8_t v2_wide = vmovl_s8(v2);
int16x8_t diff = vsubq_s16(v1_wide, v2_wide);
int16x8_t squared_diff = vmulq_s16(diff, diff);
int32x4_t sum = vpaddlq_s16(squared_diff);
// widening multiply: i16*i16 -> i32 to avoid i16 overflow
// (diff can be up to 255, so diff*diff can be up to 65025 > INT16_MAX)
int32x4_t sq_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
int32x4_t sq_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
int32x4_t sum = vaddq_s32(sq_lo, sq_hi);
sum_scalar += vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) +
vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3);