mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
L2 int8 neon implementation (#18)
* initial work for l2 neon implementation * remove comment
This commit is contained in:
parent
e9bf355a70
commit
0c75fd292f
1 changed files with 45 additions and 0 deletions
45
sqlite-vec.c
45
sqlite-vec.c
|
|
@ -64,6 +64,7 @@ typedef u_int64_t uint64_t;
|
|||
|
||||
typedef int8_t i8;
|
||||
typedef uint8_t u8;
|
||||
typedef int16_t i16;
|
||||
typedef int32_t i32;
|
||||
typedef sqlite3_int64 i64;
|
||||
typedef uint32_t u32;
|
||||
|
|
@ -187,6 +188,45 @@ static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v,
|
|||
return sqrt(
|
||||
vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3))));
|
||||
}
|
||||
|
||||
static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v,
|
||||
const void *qty_ptr) {
|
||||
i8 *pVect1 = (i8 *)pVect1v;
|
||||
i8 *pVect2 = (i8 *)pVect2v;
|
||||
size_t qty = *((size_t *)qty_ptr);
|
||||
|
||||
const i8 *pEnd1 = pVect1 + qty;
|
||||
i32 sum_scalar = 0;
|
||||
|
||||
while (pVect1 < pEnd1 - 7) {
|
||||
// loading 8 at a time
|
||||
int8x8_t v1 = vld1_s8(pVect1);
|
||||
int8x8_t v2 = vld1_s8(pVect2);
|
||||
pVect1 += 8;
|
||||
pVect2 += 8;
|
||||
|
||||
// widen to protect against overflow
|
||||
int16x8_t v1_wide = vmovl_s8(v1);
|
||||
int16x8_t v2_wide = vmovl_s8(v2);
|
||||
|
||||
int16x8_t diff = vsubq_s16(v1_wide, v2_wide);
|
||||
int16x8_t squared_diff = vmulq_s16(diff, diff);
|
||||
int32x4_t sum = vpaddlq_s16(squared_diff);
|
||||
|
||||
sum_scalar += vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) +
|
||||
vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3);
|
||||
}
|
||||
|
||||
// handle leftovers
|
||||
while (pVect1 < pEnd1) {
|
||||
i16 diff = (i16)*pVect1 - (i16)*pVect2;
|
||||
sum_scalar += diff * diff;
|
||||
pVect1++;
|
||||
pVect2++;
|
||||
}
|
||||
|
||||
return sqrtf(sum_scalar);
|
||||
}
|
||||
#endif
|
||||
|
||||
static f32 l2_sqr_float(const void *pVect1v, const void *pVect2v,
|
||||
|
|
@ -235,6 +275,11 @@ static f32 distance_l2_sqr_float(const void *a, const void *b, const void *d) {
|
|||
}
|
||||
|
||||
static f32 distance_l2_sqr_int8(const void *a, const void *b, const void *d) {
|
||||
#ifdef SQLITE_VEC_ENABLE_NEON
|
||||
if ((*(const size_t *)d) > 7) {
|
||||
return l2_sqr_int8_neon(a, b, d);
|
||||
}
|
||||
#endif
|
||||
return l2_sqr_int8(a, b, d);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue