l1 distance (#39)

* initial work on l1

* l1 int8 neon implementation

* tweak l1 int8 and add test

* broken overflow still

* some progress on l1

* change to i32 instead of i64

* remove comment

* ignore poetry stuff

* unrolled l1 int8 and format

* remove comments
This commit is contained in:
Daniel Levi-Minzi 2024-07-23 12:04:15 -04:00 committed by GitHub
parent 6eb2397537
commit 25b85afc89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 267 additions and 15 deletions

View file

@ -230,6 +230,100 @@ static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v,
return sqrtf(sum_scalar);
}
static i32 l1_int8_neon(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) {
i8 *pVect1 = (i8 *)pVect1v;
i8 *pVect2 = (i8 *)pVect2v;
size_t qty = *((size_t *)qty_ptr);
const int8_t *pEnd1 = pVect1 + qty;
int32x4_t acc1 = vdupq_n_s32(0);
int32x4_t acc2 = vdupq_n_s32(0);
int32x4_t acc3 = vdupq_n_s32(0);
int32x4_t acc4 = vdupq_n_s32(0);
while (pVect1 < pEnd1 - 63) {
int8x16_t v1 = vld1q_s8(pVect1);
int8x16_t v2 = vld1q_s8(pVect2);
int8x16_t diff1 = vabdq_s8(v1, v2);
acc1 = vaddq_s32(acc1, vpaddlq_u16(vpaddlq_u8(diff1)));
v1 = vld1q_s8(pVect1 + 16);
v2 = vld1q_s8(pVect2 + 16);
int8x16_t diff2 = vabdq_s8(v1, v2);
acc2 = vaddq_s32(acc2, vpaddlq_u16(vpaddlq_u8(diff2)));
v1 = vld1q_s8(pVect1 + 32);
v2 = vld1q_s8(pVect2 + 32);
int8x16_t diff3 = vabdq_s8(v1, v2);
acc3 = vaddq_s32(acc3, vpaddlq_u16(vpaddlq_u8(diff3)));
v1 = vld1q_s8(pVect1 + 48);
v2 = vld1q_s8(pVect2 + 48);
int8x16_t diff4 = vabdq_s8(v1, v2);
acc4 = vaddq_s32(acc4, vpaddlq_u16(vpaddlq_u8(diff4)));
pVect1 += 64;
pVect2 += 64;
}
while (pVect1 < pEnd1 - 15) {
int8x16_t v1 = vld1q_s8(pVect1);
int8x16_t v2 = vld1q_s8(pVect2);
int8x16_t diff = vabdq_s8(v1, v2);
acc1 = vaddq_s32(acc1, vpaddlq_u16(vpaddlq_u8(diff)));
pVect1 += 16;
pVect2 += 16;
}
int32x4_t acc = vaddq_s32(vaddq_s32(acc1, acc2), vaddq_s32(acc3, acc4));
int32_t sum = 0;
while (pVect1 < pEnd1) {
int32_t diff = abs((int32_t)*pVect1 - (int32_t)*pVect2);
sum += diff;
pVect1++;
pVect2++;
}
return vaddvq_s32(acc) + sum;
}
static double l1_f32_neon(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) {
f32 *pVect1 = (f32 *)pVect1v;
f32 *pVect2 = (f32 *)pVect2v;
size_t qty = *((size_t *)qty_ptr);
const f32 *pEnd1 = pVect1 + qty;
float64x2_t acc = vdupq_n_f64(0);
while (pVect1 < pEnd1 - 3) {
float32x4_t v1 = vld1q_f32(pVect1);
float32x4_t v2 = vld1q_f32(pVect2);
pVect1 += 4;
pVect2 += 4;
// f32x4 -> f64x2 pad for overflow
float64x2_t low_diff = vabdq_f64(vcvt_f64_f32(vget_low_f32(v1)),
vcvt_f64_f32(vget_low_f32(v2)));
float64x2_t high_diff =
vabdq_f64(vcvt_high_f64_f32(v1), vcvt_high_f64_f32(v2));
acc = vaddq_f64(acc, vaddq_f64(low_diff, high_diff));
}
double sum = 0;
while (pVect1 < pEnd1) {
sum += fabs((double)*pVect1 - (double)*pVect2);
pVect1++;
pVect2++;
}
return vaddvq_f64(acc) + sum;
}
#endif
static f32 l2_sqr_float(const void *pVect1v, const void *pVect2v,
@ -286,6 +380,54 @@ static f32 distance_l2_sqr_int8(const void *a, const void *b, const void *d) {
return l2_sqr_int8(a, b, d);
}
static i32 l1_int8(const void *pA, const void *pB, const void *pD) {
i8 *a = (i8 *)pA;
i8 *b = (i8 *)pB;
size_t d = *((size_t *)pD);
i32 res = 0;
for (size_t i = 0; i < d; i++) {
res += abs(*a - *b);
a++;
b++;
}
return res;
}
static i32 distance_l1_int8(const void *a, const void *b, const void *d) {
#ifdef SQLITE_VEC_ENABLE_NEON
if ((*(const size_t *)d) > 15) {
return l1_int8_neon(a, b, d);
}
#endif
return l1_int8(a, b, d);
}
static double l1_f32(const void *pA, const void *pB, const void *pD) {
f32 *a = (f32 *)pA;
f32 *b = (f32 *)pB;
size_t d = *((size_t *)pD);
double res = 0;
for (size_t i = 0; i < d; i++) {
res += fabs((double)*a - (double)*b);
a++;
b++;
}
return res;
}
static double distance_l1_f32(const void *a, const void *b, const void *d) {
#ifdef SQLITE_VEC_ENABLE_NEON
if ((*(const size_t *)d) > 3) {
return l1_f32_neon(a, b, d);
}
#endif
return l1_f32(a, b, d);
}
static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) {
f32 *pVect1 = (f32 *)pVect1v;
@ -1040,6 +1182,48 @@ finish:
bCleanup(b);
return;
}
static void vec_distance_l1(sqlite3_context *context, int argc,
sqlite3_value **argv) {
assert(argc == 2);
int rc;
void *a, *b;
size_t dimensions;
vector_cleanup aCleanup, bCleanup;
char *error;
enum VectorElementType elementType;
rc = ensure_vector_match(argv[0], argv[1], &a, &b, &elementType, &dimensions,
&aCleanup, &bCleanup, &error);
if (rc != SQLITE_OK) {
sqlite3_result_error(context, error, -1);
sqlite3_free(error);
return;
}
switch (elementType) {
case SQLITE_VEC_ELEMENT_TYPE_BIT: {
sqlite3_result_error(
context, "Cannot calculate L1 distance between two bitvectors.", -1);
goto finish;
}
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
double result = distance_l1_f32(a, b, &dimensions);
sqlite3_result_double(context, result);
goto finish;
}
case SQLITE_VEC_ELEMENT_TYPE_INT8: {
i64 result = distance_l1_int8(a, b, &dimensions);
sqlite3_result_int(context, result);
goto finish;
}
}
finish:
aCleanup(a);
bCleanup(b);
return;
}
static void vec_distance_hamming(sqlite3_context *context, int argc,
sqlite3_value **argv) {
assert(argc == 2);
@ -2608,9 +2792,9 @@ static int vec_npy_eachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) {
static int vec_npy_eachClose(sqlite3_vtab_cursor *cur) {
vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur;
if (pCur->file) {
#ifndef SQLITE_VEC_OMIT_FS
#ifndef SQLITE_VEC_OMIT_FS
fclose(pCur->file);
#endif
#endif
pCur->file = NULL;
}
if (pCur->chunksBuffer) {
@ -2664,9 +2848,9 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)pVtabCursor;
if (pCur->file) {
#ifndef SQLITE_VEC_OMIT_FS
#ifndef SQLITE_VEC_OMIT_FS
fclose(pCur->file);
#endif
#endif
pCur->file = NULL;
}
if (pCur->chunksBuffer) {
@ -2679,7 +2863,7 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
struct VecNpyFile *f = NULL;
#ifndef SQLITE_VEC_OMIT_FS
#ifndef SQLITE_VEC_OMIT_FS
if ((f = sqlite3_value_pointer(argv[0], SQLITE_VEC_NPY_FILE_NAME))) {
FILE *file = fopen(f->path, "r");
if (!file) {
@ -2689,15 +2873,15 @@ static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
rc = parse_npy_file(pVtabCursor->pVtab, file, pCur);
if (rc != SQLITE_OK) {
#ifndef SQLITE_VEC_OMIT_FS
#ifndef SQLITE_VEC_OMIT_FS
fclose(file);
#endif
#endif
return rc;
}
} else
#endif
{
#endif
{
const unsigned char *input = sqlite3_value_blob(argv[0]);
int inputLength = sqlite3_value_bytes(argv[0]);
@ -2744,7 +2928,7 @@ static int vec_npy_eachNext(sqlite3_vtab_cursor *cur) {
return SQLITE_OK;
}
#ifndef SQLITE_VEC_OMIT_FS
#ifndef SQLITE_VEC_OMIT_FS
// else: input is a file
pCur->currentChunkIndex++;
if (pCur->currentChunkIndex >= pCur->currentChunkSize) {
@ -2757,7 +2941,7 @@ static int vec_npy_eachNext(sqlite3_vtab_cursor *cur) {
}
pCur->currentChunkIndex = 0;
}
#endif
#endif
return SQLITE_OK;
}
@ -6658,6 +6842,7 @@ __declspec(dllexport)
//{"vec_version", _static_text_func, 0, DEFAULT_FLAGS, (void *) SQLITE_VEC_VERSION },
//{"vec_debug", _static_text_func, 0, DEFAULT_FLAGS, (void *) SQLITE_VEC_DEBUG_STRING },
{"vec_distance_l2", vec_distance_l2, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
{"vec_distance_l1", vec_distance_l1, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, NULL },
{"vec_distance_hamming",vec_distance_hamming, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
{"vec_distance_cosine", vec_distance_cosine, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE, },
{"vec_length", vec_length, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE, },