make format

This commit is contained in:
Alex Garcia 2024-06-13 16:32:57 -07:00
parent dcb8bf5e53
commit df48ac2416

View file

@ -1,6 +1,7 @@
#include "sqlite-vec.h" #include "sqlite-vec.h"
#include <assert.h> #include <assert.h>
#include <errno.h> #include <errno.h>
#include <float.h>
#include <inttypes.h> #include <inttypes.h>
#include <limits.h> #include <limits.h>
#include <math.h> #include <math.h>
@ -9,7 +10,6 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <float.h>
#include "sqlite3ext.h" #include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1 SQLITE_EXTENSION_INIT1
@ -186,7 +186,8 @@ static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v,
sum3 = vfmaq_f32(sum3, diff, diff); sum3 = vfmaq_f32(sum3, diff, diff);
} }
f32 sum_scalar = vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3))); f32 sum_scalar =
vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)));
const f32 *pEnd2 = pVect1 + (qty - (qty16 << 4)); const f32 *pEnd2 = pVect1 + (qty - (qty16 << 4));
while (pVect1 < pEnd2) { while (pVect1 < pEnd2) {
f32 diff = *pVect1 - *pVect2; f32 diff = *pVect1 - *pVect2;
@ -419,7 +420,8 @@ struct Array {
* @param array * @param array
* @param element_size * @param element_size
* @param init_capacity * @param init_capacity
* @return SQLITE_OK on success, error code on failure. Only error is SQLITE_NOMEM * @return SQLITE_OK on success, error code on failure. Only error is
* SQLITE_NOMEM
*/ */
int array_init(struct Array *array, size_t element_size, size_t init_capacity) { int array_init(struct Array *array, size_t element_size, size_t init_capacity) {
void *z = sqlite3_malloc(element_size * init_capacity); void *z = sqlite3_malloc(element_size * init_capacity);
@ -658,10 +660,8 @@ static int int8_vec_from_value(sqlite3_value *value, i8 **vector,
errno = 0; errno = 0;
long result = strtol(ptr, &endptr, 10); long result = strtol(ptr, &endptr, 10);
if ((errno != 0 && result == 0) if ((errno != 0 && result == 0) ||
|| (errno == ERANGE && (errno == ERANGE && (result == LONG_MAX || result == LONG_MIN))) {
(result == LONG_MAX || result == LONG_MIN))
) {
sqlite3_free(x.z); sqlite3_free(x.z);
*pzErr = sqlite3_mprintf("JSON parsing error"); *pzErr = sqlite3_mprintf("JSON parsing error");
return SQLITE_ERROR; return SQLITE_ERROR;
@ -678,7 +678,8 @@ static int int8_vec_from_value(sqlite3_value *value, i8 **vector,
if (result < INT8_MIN || result > INT8_MAX) { if (result < INT8_MIN || result > INT8_MAX) {
sqlite3_free(x.z); sqlite3_free(x.z);
*pzErr = sqlite3_mprintf("JSON parsing error: value out of range for int8"); *pzErr =
sqlite3_mprintf("JSON parsing error: value out of range for int8");
return SQLITE_ERROR; return SQLITE_ERROR;
} }
@ -1126,7 +1127,8 @@ static void vec_quantize_binary(sqlite3_context *context, int argc,
sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free); sqlite3_result_blob(context, out, dimensions / CHAR_BIT, sqlite3_free);
sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT); sqlite3_result_subtype(context, SQLITE_VEC_ELEMENT_TYPE_BIT);
} else { } else {
sqlite3_result_error(context, "Can only binary quantize float or int8 vectors", -1); sqlite3_result_error(context,
"Can only binary quantize float or int8 vectors", -1);
return; return;
} }
} }
@ -1373,8 +1375,7 @@ static void vec_to_json(sqlite3_context *context, int argc,
f32 value = ((f32 *)vector)[i]; f32 value = ((f32 *)vector)[i];
if (isnan(value)) { if (isnan(value)) {
sqlite3_str_appendall(str, "null"); sqlite3_str_appendall(str, "null");
} } else {
else {
sqlite3_str_appendf(str, "%f", value); sqlite3_str_appendf(str, "%f", value);
} }
@ -4002,8 +4003,7 @@ static int vec0Eof(sqlite3_vtab_cursor *cur) {
case SQLITE_VEC0_QUERYPLAN_KNN: { case SQLITE_VEC0_QUERYPLAN_KNN: {
todo_assert(pCur->knn_data); todo_assert(pCur->knn_data);
return (pCur->knn_data->current_idx >= pCur->knn_data->k) || return (pCur->knn_data->current_idx >= pCur->knn_data->k) ||
(pCur->knn_data->distances[pCur->knn_data->current_idx] == (pCur->knn_data->distances[pCur->knn_data->current_idx] == FLT_MAX);
FLT_MAX);
} }
case SQLITE_VEC0_QUERYPLAN_POINT: { case SQLITE_VEC0_QUERYPLAN_POINT: {
todo_assert(pCur->point_data); todo_assert(pCur->point_data);
@ -4655,7 +4655,8 @@ struct static_blob_definition {
size_t nvectors; size_t nvectors;
enum VectorElementType element_type; enum VectorElementType element_type;
}; };
static void vec_static_blob_from_raw(sqlite3_context *context, int argc, sqlite3_value **argv) { static void vec_static_blob_from_raw(sqlite3_context *context, int argc,
sqlite3_value **argv) {
struct static_blob_definition *p; struct static_blob_definition *p;
p = sqlite3_malloc(sizeof(*p)); p = sqlite3_malloc(sizeof(*p));
todo_assert(p); todo_assert(p);
@ -4663,7 +4664,8 @@ static void vec_static_blob_from_raw(sqlite3_context *context, int argc, sqlite3
p->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32; p->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32;
p->dimensions = sqlite3_value_int64(argv[2]); p->dimensions = sqlite3_value_int64(argv[2]);
p->nvectors = sqlite3_value_int64(argv[3]); p->nvectors = sqlite3_value_int64(argv[3]);
sqlite3_result_pointer(context, p, POINTER_NAME_STATIC_BLOB_DEF, sqlite3_free); sqlite3_result_pointer(context, p, POINTER_NAME_STATIC_BLOB_DEF,
sqlite3_free);
} }
#pragma region vec_static_blobs() table function #pragma region vec_static_blobs() table function
@ -4696,8 +4698,8 @@ struct vec_static_blobs_cursor {
}; };
static int vec_static_blobsConnect(sqlite3 *db, void *pAux, int argc, static int vec_static_blobsConnect(sqlite3 *db, void *pAux, int argc,
const char *const *argv, sqlite3_vtab **ppVtab, const char *const *argv,
char **pzErr) { sqlite3_vtab **ppVtab, char **pzErr) {
vec_static_blobs_vtab *pNew; vec_static_blobs_vtab *pNew;
#define VEC_STATIC_BLOBS_NAME 0 #define VEC_STATIC_BLOBS_NAME 0
#define VEC_STATIC_BLOBS_DATA 1 #define VEC_STATIC_BLOBS_DATA 1
@ -4722,8 +4724,8 @@ static int vec_static_blobsDisconnect(sqlite3_vtab *pVtab) {
return SQLITE_OK; return SQLITE_OK;
} }
static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc,
sqlite_int64 *pRowid) { sqlite3_value **argv, sqlite_int64 *pRowid) {
vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pVTab; vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pVTab;
// DELETE operation // DELETE operation
if (argc == 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) { if (argc == 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) {
@ -4740,8 +4742,10 @@ static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc, sqlite3_value *
break; break;
} }
} }
if(idx < 0) abort(); if (idx < 0)
struct static_blob_definition * def = sqlite3_value_pointer(argv[2 + VEC_STATIC_BLOBS_DATA], POINTER_NAME_STATIC_BLOB_DEF); abort();
struct static_blob_definition *def = sqlite3_value_pointer(
argv[2 + VEC_STATIC_BLOBS_DATA], POINTER_NAME_STATIC_BLOB_DEF);
p->data->static_blobs[idx].p = def->p; p->data->static_blobs[idx].p = def->p;
p->data->static_blobs[idx].dimensions = def->dimensions; p->data->static_blobs[idx].dimensions = def->dimensions;
p->data->static_blobs[idx].nvectors = def->nvectors; p->data->static_blobs[idx].nvectors = def->nvectors;
@ -4756,7 +4760,8 @@ static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc, sqlite3_value *
return SQLITE_ERROR; return SQLITE_ERROR;
} }
static int vec_static_blobsOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { static int vec_static_blobsOpen(sqlite3_vtab *p,
sqlite3_vtab_cursor **ppCursor) {
vec_static_blobs_cursor *pCur; vec_static_blobs_cursor *pCur;
pCur = sqlite3_malloc(sizeof(*pCur)); pCur = sqlite3_malloc(sizeof(*pCur));
if (pCur == 0) if (pCur == 0)
@ -4772,7 +4777,8 @@ static int vec_static_blobsClose(sqlite3_vtab_cursor *cur) {
return SQLITE_OK; return SQLITE_OK;
} }
static int vec_static_blobsBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { static int vec_static_blobsBestIndex(sqlite3_vtab *pVTab,
sqlite3_index_info *pIdxInfo) {
pIdxInfo->idxNum = 1; pIdxInfo->idxNum = 1;
pIdxInfo->estimatedCost = (double)10; pIdxInfo->estimatedCost = (double)10;
pIdxInfo->estimatedRows = 10; pIdxInfo->estimatedRows = 10;
@ -4781,14 +4787,16 @@ static int vec_static_blobsBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pI
static int vec_static_blobsNext(sqlite3_vtab_cursor *cur); static int vec_static_blobsNext(sqlite3_vtab_cursor *cur);
static int vec_static_blobsFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, static int vec_static_blobsFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
const char *idxStr, int argc, sqlite3_value **argv) { const char *idxStr, int argc,
sqlite3_value **argv) {
vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)pVtabCursor; vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)pVtabCursor;
pCur->iRowid = -1; pCur->iRowid = -1;
vec_static_blobsNext(pVtabCursor); vec_static_blobsNext(pVtabCursor);
return SQLITE_OK; return SQLITE_OK;
} }
static int vec_static_blobsRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { static int vec_static_blobsRowid(sqlite3_vtab_cursor *cur,
sqlite_int64 *pRowid) {
vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur; vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur;
*pRowid = pCur->iRowid; *pRowid = pCur->iRowid;
return SQLITE_OK; return SQLITE_OK;
@ -4812,19 +4820,21 @@ static int vec_static_blobsEof(sqlite3_vtab_cursor *cur) {
return pCur->iRowid >= MAX_STATIC_BLOBS; return pCur->iRowid >= MAX_STATIC_BLOBS;
} }
static int vec_static_blobsColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context, static int vec_static_blobsColumn(sqlite3_vtab_cursor *cur,
int i) { sqlite3_context *context, int i) {
vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur; vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur;
vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)cur->pVtab; vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)cur->pVtab;
switch (i) { switch (i) {
case VEC_STATIC_BLOBS_NAME: case VEC_STATIC_BLOBS_NAME:
sqlite3_result_text(context, p->data->static_blobs[pCur->iRowid].name, -1, SQLITE_TRANSIENT); sqlite3_result_text(context, p->data->static_blobs[pCur->iRowid].name, -1,
SQLITE_TRANSIENT);
break; break;
case VEC_STATIC_BLOBS_DATA: case VEC_STATIC_BLOBS_DATA:
sqlite3_result_null(context); sqlite3_result_null(context);
break; break;
case VEC_STATIC_BLOBS_DIMENSIONS: case VEC_STATIC_BLOBS_DIMENSIONS:
sqlite3_result_int64(context, p->data->static_blobs[pCur->iRowid].dimensions); sqlite3_result_int64(context,
p->data->static_blobs[pCur->iRowid].dimensions);
break; break;
case VEC_STATIC_BLOBS_COUNT: case VEC_STATIC_BLOBS_COUNT:
sqlite3_result_int64(context, p->data->static_blobs[pCur->iRowid].nvectors); sqlite3_result_int64(context, p->data->static_blobs[pCur->iRowid].nvectors);
@ -4833,7 +4843,6 @@ static int vec_static_blobsColumn(sqlite3_vtab_cursor *cur, sqlite3_context *con
return SQLITE_OK; return SQLITE_OK;
} }
static sqlite3_module vec_static_blobsModule = { static sqlite3_module vec_static_blobsModule = {
/* iVersion */ 3, /* iVersion */ 3,
/* xCreate */ 0, /* xCreate */ 0,
@ -4861,7 +4870,6 @@ static sqlite3_module vec_static_blobsModule = {
/* xShadowName */ 0}; /* xShadowName */ 0};
#pragma endregion #pragma endregion
#pragma region vec_static_blob_entries() table function #pragma region vec_static_blob_entries() table function
typedef struct vec_static_blob_entries_vtab vec_static_blob_entries_vtab; typedef struct vec_static_blob_entries_vtab vec_static_blob_entries_vtab;
@ -4882,20 +4890,22 @@ struct vec_static_blob_entries_cursor {
struct vec0_query_knn_data *knn_data; struct vec0_query_knn_data *knn_data;
}; };
static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc, static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc,
const char *const *argv, sqlite3_vtab **ppVtab, const char *const *argv,
char **pzErr) { sqlite3_vtab **ppVtab, char **pzErr) {
vec_static_blob_data *blob_data = pAux; vec_static_blob_data *blob_data = pAux;
int idx = -1; int idx = -1;
for (int i = 0; i < MAX_STATIC_BLOBS; i++) { for (int i = 0; i < MAX_STATIC_BLOBS; i++) {
if(!blob_data->static_blobs[i].name) continue; if (!blob_data->static_blobs[i].name)
if(strncmp(blob_data->static_blobs[i].name, argv[3], strlen(blob_data->static_blobs[i].name))==0) { continue;
if (strncmp(blob_data->static_blobs[i].name, argv[3],
strlen(blob_data->static_blobs[i].name)) == 0) {
idx = i; idx = i;
break; break;
} }
} }
if(idx < 0) abort(); if (idx < 0)
abort();
vec_static_blob_entries_vtab *pNew; vec_static_blob_entries_vtab *pNew;
#define VEC_STATIC_BLOB_ENTRIES_VECTOR 0 #define VEC_STATIC_BLOB_ENTRIES_VECTOR 0
#define VEC_STATIC_BLOB_ENTRIES_DISTANCE 1 #define VEC_STATIC_BLOB_ENTRIES_DISTANCE 1
@ -4914,8 +4924,8 @@ static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc,
} }
static int vec_static_blob_entriesCreate(sqlite3 *db, void *pAux, int argc, static int vec_static_blob_entriesCreate(sqlite3 *db, void *pAux, int argc,
const char *const *argv, sqlite3_vtab **ppVtab, const char *const *argv,
char **pzErr) { sqlite3_vtab **ppVtab, char **pzErr) {
vec_static_blob_entriesConnect(db, pAux, argc, argv, ppVtab, pzErr); vec_static_blob_entriesConnect(db, pAux, argc, argv, ppVtab, pzErr);
} }
@ -4925,7 +4935,8 @@ static int vec_static_blob_entriesDisconnect(sqlite3_vtab *pVtab) {
return SQLITE_OK; return SQLITE_OK;
} }
static int vec_static_blob_entriesOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { static int vec_static_blob_entriesOpen(sqlite3_vtab *p,
sqlite3_vtab_cursor **ppCursor) {
vec_static_blob_entries_cursor *pCur; vec_static_blob_entries_cursor *pCur;
pCur = sqlite3_malloc(sizeof(*pCur)); pCur = sqlite3_malloc(sizeof(*pCur));
if (pCur == 0) if (pCur == 0)
@ -4941,7 +4952,8 @@ static int vec_static_blob_entriesClose(sqlite3_vtab_cursor *cur) {
return SQLITE_OK; return SQLITE_OK;
} }
static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
sqlite3_index_info *pIdxInfo) {
vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVTab; vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVTab;
int iMatchTerm = -1; int iMatchTerm = -1;
int iLimitTerm = -1; int iLimitTerm = -1;
@ -4954,7 +4966,8 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, sqlite3_index_i
int iColumn = pIdxInfo->aConstraint[i].iColumn; int iColumn = pIdxInfo->aConstraint[i].iColumn;
int op = pIdxInfo->aConstraint[i].op; int op = pIdxInfo->aConstraint[i].op;
if (op == SQLITE_INDEX_CONSTRAINT_MATCH && iColumn == VEC_STATIC_BLOB_ENTRIES_VECTOR) { if (op == SQLITE_INDEX_CONSTRAINT_MATCH &&
iColumn == VEC_STATIC_BLOB_ENTRIES_VECTOR) {
if (iMatchTerm > -1) { if (iMatchTerm > -1) {
// TODO only 1 match operator at a time // TODO only 1 match operator at a time
return SQLITE_ERROR; return SQLITE_ERROR;
@ -4964,7 +4977,8 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, sqlite3_index_i
if (op == SQLITE_INDEX_CONSTRAINT_LIMIT) { if (op == SQLITE_INDEX_CONSTRAINT_LIMIT) {
iLimitTerm = i; iLimitTerm = i;
} }
if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == VEC_STATIC_BLOB_ENTRIES_K) { if (op == SQLITE_INDEX_CONSTRAINT_EQ &&
iColumn == VEC_STATIC_BLOB_ENTRIES_K) {
iKTerm = i; iKTerm = i;
} }
} }
@ -5010,8 +5024,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, sqlite3_index_i
pIdxInfo->aConstraintUsage[iKTerm].omit = 1; pIdxInfo->aConstraintUsage[iKTerm].omit = 1;
} }
} } else {
else {
pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_FULLSCAN; pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_FULLSCAN;
pIdxInfo->estimatedCost = (double)p->blob->nvectors; pIdxInfo->estimatedCost = (double)p->blob->nvectors;
pIdxInfo->estimatedRows = p->blob->nvectors; pIdxInfo->estimatedRows = p->blob->nvectors;
@ -5019,10 +5032,13 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, sqlite3_index_i
return SQLITE_OK; return SQLITE_OK;
} }
static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
const char *idxStr, int argc, sqlite3_value **argv) { int idxNum, const char *idxStr,
vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)pVtabCursor; int argc, sqlite3_value **argv) {
vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pCur->base.pVtab; vec_static_blob_entries_cursor *pCur =
(vec_static_blob_entries_cursor *)pVtabCursor;
vec_static_blob_entries_vtab *p =
(vec_static_blob_entries_vtab *)pCur->base.pVtab;
if (idxNum == VEC_SBE__QUERYPLAN_KNN) { if (idxNum == VEC_SBE__QUERYPLAN_KNN) {
pCur->query_plan = VEC_SBE__QUERYPLAN_KNN; pCur->query_plan = VEC_SBE__QUERYPLAN_KNN;
@ -5038,7 +5054,8 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, int i
enum VectorElementType elementType; enum VectorElementType elementType;
vector_cleanup cleanup; vector_cleanup cleanup;
char *err; char *err;
int rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, &cleanup, &err); int rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType,
&cleanup, &err);
todo_assert(elementType == p->blob->element_type); todo_assert(elementType == p->blob->element_type);
todo_assert(dimensions == p->blob->dimensions); todo_assert(dimensions == p->blob->dimensions);
@ -5059,7 +5076,8 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, int i
for (size_t i = 0; i < p->blob->nvectors; i++) { for (size_t i = 0; i < p->blob->nvectors; i++) {
float *v = ((float *)p->blob->p) + (i * p->blob->dimensions); float *v = ((float *)p->blob->p) + (i * p->blob->dimensions);
distances[i] = distance_l2_sqr_float(v, (float *) queryVector, &p->blob->dimensions); distances[i] =
distance_l2_sqr_float(v, (float *)queryVector, &p->blob->dimensions);
} }
min_idx(distances, k, topk_rowids, k); min_idx(distances, k, topk_rowids, k);
knn_data->current_idx = 0; knn_data->current_idx = 0;
@ -5068,8 +5086,7 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, int i
knn_data->rowids = topk_rowids; knn_data->rowids = topk_rowids;
pCur->knn_data = knn_data; pCur->knn_data = knn_data;
} } else {
else {
pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN; pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN;
pCur->iRowid = 0; pCur->iRowid = 0;
} }
@ -5077,7 +5094,8 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, int i
return SQLITE_OK; return SQLITE_OK;
} }
static int vec_static_blob_entriesRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { static int vec_static_blob_entriesRowid(sqlite3_vtab_cursor *cur,
sqlite_int64 *pRowid) {
vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
*pRowid = pCur->iRowid; *pRowid = pCur->iRowid;
return SQLITE_OK; return SQLITE_OK;
@ -5095,12 +5113,12 @@ static int vec_static_blob_entriesNext(sqlite3_vtab_cursor *cur) {
return SQLITE_OK; return SQLITE_OK;
} }
} }
} }
static int vec_static_blob_entriesEof(sqlite3_vtab_cursor *cur) { static int vec_static_blob_entriesEof(sqlite3_vtab_cursor *cur) {
vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
vec_static_blob_entries_vtab * p = (vec_static_blob_entries_vtab *) pCur->base.pVtab; vec_static_blob_entries_vtab *p =
(vec_static_blob_entries_vtab *)pCur->base.pVtab;
switch (pCur->query_plan) { switch (pCur->query_plan) {
case VEC_SBE__QUERYPLAN_FULLSCAN: { case VEC_SBE__QUERYPLAN_FULLSCAN: {
return (size_t)pCur->iRowid >= p->blob->nvectors; return (size_t)pCur->iRowid >= p->blob->nvectors;
@ -5109,11 +5127,10 @@ static int vec_static_blob_entriesEof(sqlite3_vtab_cursor *cur) {
return pCur->knn_data->current_idx >= pCur->knn_data->k; return pCur->knn_data->current_idx >= pCur->knn_data->k;
} }
} }
} }
static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context, static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur,
int i) { sqlite3_context *context, int i) {
vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)cur->pVtab; vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)cur->pVtab;
@ -5125,9 +5142,7 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, sqlite3_conte
sqlite3_result_blob( sqlite3_result_blob(
context, context,
p->blob->p + (pCur->iRowid * p->blob->dimensions * sizeof(float)), p->blob->p + (pCur->iRowid * p->blob->dimensions * sizeof(float)),
p->blob->dimensions * sizeof(float), p->blob->dimensions * sizeof(float), SQLITE_STATIC);
SQLITE_STATIC
);
sqlite3_result_subtype(context, p->blob->element_type); sqlite3_result_subtype(context, p->blob->element_type);
break; break;
} }
@ -5139,11 +5154,8 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, sqlite3_conte
i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx]; i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx];
sqlite3_result_blob( sqlite3_result_blob(
context, context, p->blob->p + (rowid * p->blob->dimensions * sizeof(float)),
p->blob->p + (rowid* p->blob->dimensions * sizeof(float)), p->blob->dimensions * sizeof(float), SQLITE_STATIC);
p->blob->dimensions * sizeof(float),
SQLITE_STATIC
);
sqlite3_result_subtype(context, p->blob->element_type); sqlite3_result_subtype(context, p->blob->element_type);
break; break;
} }
@ -5153,7 +5165,6 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, sqlite3_conte
} }
} }
static sqlite3_module vec_static_blob_entriesModule = { static sqlite3_module vec_static_blob_entriesModule = {
/* iVersion */ 3, /* iVersion */ 3,
/* xCreate */ vec_static_blob_entriesCreate, /* xCreate */ vec_static_blob_entriesCreate,
@ -5195,7 +5206,9 @@ void dethrone2(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size,
size_t ptrA = 0; size_t ptrA = 0;
size_t ptrB = 0; size_t ptrB = 0;
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
if (ptrA < chunk_size && (ptrB >= k || chunk_distances[chunk_top_idx[ptrA]] < base_distances[ptrB])) { if (ptrA < chunk_size &&
(ptrB >= k ||
chunk_distances[chunk_top_idx[ptrA]] < base_distances[ptrB])) {
(*out_rowids)[i] = chunk_rowids[chunk_top_idx[ptrA]]; (*out_rowids)[i] = chunk_rowids[chunk_top_idx[ptrA]];
(*out_distances)[i] = chunk_distances[chunk_top_idx[ptrA]]; (*out_distances)[i] = chunk_distances[chunk_top_idx[ptrA]];
ptrA++; ptrA++;
@ -5207,7 +5220,6 @@ void dethrone2(int k, f32 *base_distances, i64 *base_rowids, size_t chunk_size,
} }
} }
typedef struct vec_expo_vtab vec_expo_vtab; typedef struct vec_expo_vtab vec_expo_vtab;
struct vec_expo_vtab { struct vec_expo_vtab {
sqlite3_vtab base; sqlite3_vtab base;
@ -5224,7 +5236,6 @@ struct vec_expo_cursor {
struct vec0_query_knn_data *knn_data; struct vec0_query_knn_data *knn_data;
}; };
static int vec_expoConnect(sqlite3 *db, void *pAux, int argc, static int vec_expoConnect(sqlite3 *db, void *pAux, int argc,
const char *const *argv, sqlite3_vtab **ppVtab, const char *const *argv, sqlite3_vtab **ppVtab,
char **pzErr) { char **pzErr) {
@ -5276,7 +5287,8 @@ static int vec_expoClose(sqlite3_vtab_cursor *cur) {
return SQLITE_OK; return SQLITE_OK;
} }
static int vec_expoBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) { static int vec_expoBestIndex(sqlite3_vtab *pVTab,
sqlite3_index_info *pIdxInfo) {
vec_expo_vtab *p = (vec_expo_vtab *)pVTab; vec_expo_vtab *p = (vec_expo_vtab *)pVTab;
int iMatchTerm = -1; int iMatchTerm = -1;
int iLimitTerm = -1; int iLimitTerm = -1;
@ -5345,8 +5357,7 @@ static int vec_expoBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo)
pIdxInfo->aConstraintUsage[iKTerm].omit = 1; pIdxInfo->aConstraintUsage[iKTerm].omit = 1;
} }
} } else {
else {
pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_FULLSCAN; pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_FULLSCAN;
pIdxInfo->estimatedCost = 10000.0; pIdxInfo->estimatedCost = 10000.0;
pIdxInfo->estimatedRows = 10000; pIdxInfo->estimatedRows = 10000;
@ -5398,7 +5409,8 @@ static int vec_expoFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
assert(rc == SQLITE_OK); assert(rc == SQLITE_OK);
sqlite3_blob *baseVectorsBlob; sqlite3_blob *baseVectorsBlob;
sqlite3_blob_open(p->db, "main", p->table, p->column, 1, 0, &baseVectorsBlob); sqlite3_blob_open(p->db, "main", p->table, p->column, 1, 0,
&baseVectorsBlob);
int chunk_size = 200; int chunk_size = 200;
float *chunk = sqlite3_malloc(dimensions * chunk_size * sizeof(float)); float *chunk = sqlite3_malloc(dimensions * chunk_size * sizeof(float));
@ -5412,8 +5424,6 @@ static int vec_expoFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
i64 *chunk_rowids = sqlite3_malloc(chunk_size * sizeof(i64)); i64 *chunk_rowids = sqlite3_malloc(chunk_size * sizeof(i64));
todo_assert(chunk_rowids); todo_assert(chunk_rowids);
while (true) { while (true) {
int nused = 0; int nused = 0;
for (int i = 0; i < chunk_size; i++) { for (int i = 0; i < chunk_size; i++) {
@ -5428,13 +5438,16 @@ static int vec_expoFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
chunk_rowids[i] = rowid; chunk_rowids[i] = rowid;
rc = sqlite3_blob_reopen(baseVectorsBlob, rowid); rc = sqlite3_blob_reopen(baseVectorsBlob, rowid);
assert(rc == SQLITE_OK); assert(rc == SQLITE_OK);
assert(sqlite3_blob_bytes(baseVectorsBlob) == dimensions * sizeof(float)); assert(sqlite3_blob_bytes(baseVectorsBlob) ==
sqlite3_blob_read(baseVectorsBlob, &chunk[i * dimensions], dimensions * sizeof(float), 0); dimensions * sizeof(float));
sqlite3_blob_read(baseVectorsBlob, &chunk[i * dimensions],
dimensions * sizeof(float), 0);
} }
for (int i = 0; i < nused; i++) { for (int i = 0; i < nused; i++) {
const f32 *base_i = (chunk) + (i * dimensions); const f32 *base_i = (chunk) + (i * dimensions);
chunk_distances[i] = distance_l2_sqr_float(base_i, (f32 *)queryVector, &dimensions); chunk_distances[i] =
distance_l2_sqr_float(base_i, (f32 *)queryVector, &dimensions);
} }
i32 *chunk_top_idxs = sqlite3_malloc(nused * sizeof(i32)); i32 *chunk_top_idxs = sqlite3_malloc(nused * sizeof(i32));
@ -5443,8 +5456,8 @@ static int vec_expoFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
i64 *out_rowids; i64 *out_rowids;
f32 *out_distances; f32 *out_distances;
dethrone2(k, topk_distances, topk_rowids, /*chunk_size*/ nused, chunk_top_idxs, dethrone2(k, topk_distances, topk_rowids, /*chunk_size*/ nused,
chunk_distances, chunk_rowids, chunk_top_idxs, chunk_distances, chunk_rowids,
&out_rowids, &out_distances); &out_rowids, &out_distances);
for (int i = 0; i < k; i++) { for (int i = 0; i < k; i++) {
@ -5455,21 +5468,20 @@ static int vec_expoFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
sqlite3_free(out_distances); sqlite3_free(out_distances);
sqlite3_free(chunk_top_idxs); sqlite3_free(chunk_top_idxs);
if(nused < chunk_size) break; if (nused < chunk_size)
break;
} }
sqlite3_blob_close(baseVectorsBlob); sqlite3_blob_close(baseVectorsBlob);
sqlite3_finalize(stmtRowids); sqlite3_finalize(stmtRowids);
cleanup(queryVector); cleanup(queryVector);
knn_data->current_idx = 0; knn_data->current_idx = 0;
knn_data->k = k; knn_data->k = k;
knn_data->rowids = topk_rowids; knn_data->rowids = topk_rowids;
knn_data->distances = topk_distances; knn_data->distances = topk_distances;
pCur->knn_data = knn_data; pCur->knn_data = knn_data;
} } else {
else {
pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN; pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN;
pCur->iRowid = 0; pCur->iRowid = 0;
} }
@ -5505,7 +5517,6 @@ static int vec_expoNext(sqlite3_vtab_cursor *cur) {
return SQLITE_OK; return SQLITE_OK;
} }
} }
} }
static int vec_expoEof(sqlite3_vtab_cursor *cur) { static int vec_expoEof(sqlite3_vtab_cursor *cur) {
@ -5519,7 +5530,6 @@ static int vec_expoEof(sqlite3_vtab_cursor *cur) {
return pCur->knn_data->current_idx >= pCur->knn_data->k; return pCur->knn_data->current_idx >= pCur->knn_data->k;
} }
} }
} }
static int vec_expoColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context, static int vec_expoColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context,
@ -5537,7 +5547,8 @@ static int vec_expoColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context,
break; break;
} }
case VEC_EXPO_DISTANCE: { case VEC_EXPO_DISTANCE: {
sqlite3_result_double(context, pCur->knn_data->distances[pCur->knn_data->current_idx]); sqlite3_result_double(
context, pCur->knn_data->distances[pCur->knn_data->current_idx]);
break; break;
} }
} }
@ -5546,7 +5557,6 @@ static int vec_expoColumn(sqlite3_vtab_cursor *cur, sqlite3_context *context,
} }
} }
static sqlite3_module vec_expoModule = { static sqlite3_module vec_expoModule = {
/* iVersion */ 3, /* iVersion */ 3,
/* xCreate */ vec_expoCreate, /* xCreate */ vec_expoCreate,
@ -5768,9 +5778,12 @@ __declspec(dllexport)
} }
} }
#ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL #ifdef SQLITE_VEC_ENABLE_EXPERIMENTAL
rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule, static_blob_data, sqlite3_free); rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule,
static_blob_data, sqlite3_free);
assert(rc == SQLITE_OK); assert(rc == SQLITE_OK);
rc = sqlite3_create_module_v2(db, "vec_static_blob_entries", &vec_static_blob_entriesModule, static_blob_data, NULL); rc = sqlite3_create_module_v2(db, "vec_static_blob_entries",
&vec_static_blob_entriesModule,
static_blob_data, NULL);
assert(rc == SQLITE_OK); assert(rc == SQLITE_OK);
rc = sqlite3_create_module_v2(db, "vec_expo", &vec_expoModule, NULL, NULL); rc = sqlite3_create_module_v2(db, "vec_expo", &vec_expoModule, NULL, NULL);
assert(rc == SQLITE_OK); assert(rc == SQLITE_OK);