diff --git a/sqlite-vec.c b/sqlite-vec.c index 4d91c9e..ecdd875 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -64,6 +64,7 @@ typedef u_int64_t uint64_t; typedef int8_t i8; typedef uint8_t u8; +typedef int16_t i16; typedef int32_t i32; typedef sqlite3_int64 i64; typedef uint32_t u32; @@ -187,6 +188,45 @@ static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v, return sqrt( vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)))); } + +static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v, + const void *qty_ptr) { + i8 *pVect1 = (i8 *)pVect1v; + i8 *pVect2 = (i8 *)pVect2v; + size_t qty = *((size_t *)qty_ptr); + + const i8 *pEnd1 = pVect1 + qty; + i32 sum_scalar = 0; + + while (pVect1 < pEnd1 - 7) { + // loading 8 at a time + int8x8_t v1 = vld1_s8(pVect1); + int8x8_t v2 = vld1_s8(pVect2); + pVect1 += 8; + pVect2 += 8; + + // widen to protect against overflow + int16x8_t v1_wide = vmovl_s8(v1); + int16x8_t v2_wide = vmovl_s8(v2); + + int16x8_t diff = vsubq_s16(v1_wide, v2_wide); + int16x8_t squared_diff = vmulq_s16(diff, diff); + int32x4_t sum = vpaddlq_s16(squared_diff); + + sum_scalar += vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) + + vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3); + } + + // handle leftovers + while (pVect1 < pEnd1) { + i16 diff = (i16)*pVect1 - (i16)*pVect2; + sum_scalar += diff * diff; + pVect1++; + pVect2++; + } + + return sqrtf(sum_scalar); +} #endif static f32 l2_sqr_float(const void *pVect1v, const void *pVect2v, @@ -235,6 +275,11 @@ static f32 distance_l2_sqr_float(const void *a, const void *b, const void *d) { } static f32 distance_l2_sqr_int8(const void *a, const void *b, const void *d) { + #ifdef SQLITE_VEC_ENABLE_NEON + if ((*(const size_t *)d) > 7) { + return l2_sqr_int8_neon(a, b, d); + } + #endif return l2_sqr_int8(a, b, d); }