L2 int8 neon implementation (#18)

* initial work for l2 neon implementation

* remove comment
This commit is contained in:
Daniel Levi-Minzi 2024-06-08 14:52:24 -04:00 committed by GitHub
parent e9bf355a70
commit 0c75fd292f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -64,6 +64,7 @@ typedef u_int64_t uint64_t;
// Short aliases for fixed-width integer types.
typedef int8_t i8;
typedef uint8_t u8;
typedef int16_t i16;
typedef int32_t i32;
// i64 aliases SQLite's sqlite3_int64 rather than a <stdint.h> type.
typedef sqlite3_int64 i64;
typedef uint32_t u32;
@ -187,6 +188,45 @@ static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v,
return sqrt(
vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3))));
}
/**
 * L2 (Euclidean) distance between two int8 vectors, using NEON intrinsics.
 *
 * Note: despite the "_sqr" in the name, the value returned is the square
 * root of the sum of squared differences.
 *
 * @param pVect1v  Pointer to the first vector, read as an array of int8.
 * @param pVect2v  Pointer to the second vector, read as an array of int8.
 * @param qty_ptr  Pointer to a size_t holding the number of elements in
 *                 each vector.
 * @return sqrt(sum_i (a[i] - b[i])^2) as a float.
 */
static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v,
                            const void *qty_ptr) {
  const i8 *pVect1 = (const i8 *)pVect1v;
  const i8 *pVect2 = (const i8 *)pVect2v;
  // Count down elements instead of comparing against `end - 7`, which would
  // form an out-of-bounds pointer for vectors shorter than 7 elements.
  size_t remaining = *((const size_t *)qty_ptr);

  // 64-bit accumulator: a single squared difference can be as large as
  // 255 * 255 = 65025, so a 32-bit sum would overflow on long vectors.
  uint64_t sum_scalar = 0;

  while (remaining >= 8) {
    // loading 8 at a time
    int8x8_t v1 = vld1_s8(pVect1);
    int8x8_t v2 = vld1_s8(pVect2);
    pVect1 += 8;
    pVect2 += 8;
    remaining -= 8;

    // widening subtract: each difference lies in [-255, 255], fits in int16
    int16x8_t diff = vsubl_s8(v1, v2);

    // widening multiply: the square (up to 65025) does NOT fit in int16, so
    // the products must be produced directly in 32-bit lanes
    int32x4_t sq_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
    int32x4_t sq_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));

    // at most 2 * 65025 per lane, well within int32
    int32x4_t sum = vaddq_s32(sq_lo, sq_hi);

    sum_scalar += (uint64_t)(vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) +
                             vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3));
  }

  // handle leftovers
  while (remaining > 0) {
    i32 diff = (i32)*pVect1 - (i32)*pVect2;
    sum_scalar += (uint64_t)(diff * diff);
    pVect1++;
    pVect2++;
    remaining--;
  }

  return sqrtf((f32)sum_scalar);
}
#endif
static f32 l2_sqr_float(const void *pVect1v, const void *pVect2v,
@ -235,6 +275,11 @@ static f32 distance_l2_sqr_float(const void *a, const void *b, const void *d) {
}
/**
 * Dispatches the int8 L2 distance computation.
 *
 * When SQLITE_VEC_ENABLE_NEON is defined and the vectors hold at least
 * 8 elements, the NEON implementation is used; in every other case the
 * call goes to l2_sqr_int8().
 *
 * @param a  Pointer to the first vector.
 * @param b  Pointer to the second vector.
 * @param d  Pointer to a size_t holding the element count.
 */
static f32 distance_l2_sqr_int8(const void *a, const void *b, const void *d) {
#ifdef SQLITE_VEC_ENABLE_NEON
  const size_t dimensions = *(const size_t *)d;
  if (dimensions >= 8) {
    return l2_sqr_int8_neon(a, b, d);
  }
#endif
  return l2_sqr_int8(a, b, d);
}