mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
static updates
This commit is contained in:
parent
e91ccf38ff
commit
a0bc9404ce
3 changed files with 323 additions and 231 deletions
249
sqlite-vec.c
249
sqlite-vec.c
|
|
@ -736,7 +736,7 @@ static int fvec_from_value(sqlite3_value *value, f32 **vector,
|
|||
if (source[offset] == ',') {
|
||||
offset++;
|
||||
continue;
|
||||
} // TODO multiple commas in a row without digits?
|
||||
}
|
||||
if (source[offset] == ']')
|
||||
goto done;
|
||||
break;
|
||||
|
|
@ -2006,7 +2006,7 @@ int parse_vector_column(const char *source, int source_length,
|
|||
struct VectorColumnDefinition *outColumn) {
|
||||
// parses a vector column definition like so:
|
||||
// "abc float[123]", "abc_123 bit[1234]", etc.
|
||||
// TODO: more precise error messages here.
|
||||
// https://github.com/asg017/sqlite-vec/issues/46
|
||||
int rc;
|
||||
struct Vec0Scanner scanner;
|
||||
struct Vec0Token token;
|
||||
|
|
@ -2989,7 +2989,7 @@ static int vec_npy_eachColumnBuffer(vec_npy_each_cursor *pCur,
|
|||
context,
|
||||
&((unsigned char *)
|
||||
pCur->vector)[pCur->iRowid * pCur->nDimensions * sizeof(f32)],
|
||||
pCur->nDimensions * sizeof(f32), SQLITE_STATIC);
|
||||
pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT);
|
||||
|
||||
break;
|
||||
}
|
||||
|
|
@ -3361,7 +3361,6 @@ cleanup:
|
|||
int vec0_get_id_value_from_rowid(vec0_vtab *pVtab, i64 rowid,
|
||||
sqlite3_value **out) {
|
||||
// PERF: different strategy than get_chunk_position?
|
||||
// TODO: test / evidence-of
|
||||
return vec0_get_chunk_position((vec0_vtab *)pVtab, rowid, out, NULL, NULL);
|
||||
}
|
||||
|
||||
|
|
@ -3461,7 +3460,6 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
|
|||
"vectors", chunk_id, 0, &vectorBlob);
|
||||
|
||||
if (rc != SQLITE_OK) {
|
||||
// TODO evidence-of
|
||||
vtab_set_error(&pVtab->base,
|
||||
"Could not fetch vector data for %lld, opening blob failed",
|
||||
rowid);
|
||||
|
|
@ -4490,7 +4488,11 @@ void bitmap_and_inplace(u8 *base, u8 *other, i32 n) {
|
|||
}
|
||||
|
||||
/**
 * @brief Sets or clears a single bit of a bitmap.
 *
 * The bit lives in byte `position / CHAR_BIT`, at bit index
 * `position % CHAR_BIT` within that byte.
 *
 * @param bitmap   Bitmap storage; the addressed byte must be in bounds.
 * @param position Zero-based bit index.
 * @param value    Non-zero sets the bit, zero clears it.
 */
void bitmap_set(u8 *bitmap, i32 position, int value) {
  // Build the mask once. The earlier unconditional `|= value << shift` is
  // gone on purpose: it could only ever set bits, and for any `value` other
  // than 0 or 1 it set bits other than the requested one.
  const u8 mask = (u8)(1u << (position % CHAR_BIT));
  if (value) {
    bitmap[position / CHAR_BIT] |= mask;
  } else {
    bitmap[position / CHAR_BIT] &= (u8)~mask;
  }
}
|
||||
|
||||
int bitmap_get(u8 *bitmap, i32 position) {
|
||||
|
|
@ -4502,6 +4504,11 @@ void bitmap_clear(u8 *bitmap, i32 n) {
|
|||
memset(bitmap, 0, n / CHAR_BIT);
|
||||
}
|
||||
|
||||
/**
 * @brief Sets every bit of a bitmap to 1.
 *
 * @param bitmap Bitmap storage; `n / CHAR_BIT` bytes are written.
 * @param n      Size of the bitmap in bits. Asserted to be a multiple of 8.
 */
void bitmap_fill(u8 *bitmap, i32 n) {
  assert((n % 8) == 0);
  const i32 nbytes = n / CHAR_BIT;
  memset(bitmap, 0xFF, nbytes);
}
|
||||
|
||||
/**
|
||||
* @brief Finds the minimum k items in distances, and writes the indices to
|
||||
* out.
|
||||
|
|
@ -4885,7 +4892,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
|||
if (dimensions != vector_column->dimensions) {
|
||||
vtab_set_error(
|
||||
&p->base,
|
||||
"Dimension mismatch for inserted vector for the \"%.*s\" column. "
|
||||
"Dimension mismatch for query vector for the \"%.*s\" column. "
|
||||
"Expected %d dimensions but received %d.",
|
||||
vector_column->name_length, vector_column->name,
|
||||
vector_column->dimensions, dimensions);
|
||||
|
|
@ -5728,8 +5735,7 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
|||
// a write-able blob of the validity column for the given chunk. Used to mark
|
||||
// validity bit
|
||||
sqlite3_blob *blobChunksValidity = NULL;
|
||||
// buffer for the validity column for the given chunk. TODO maybe not needed
|
||||
// here?
|
||||
// buffer for the validity column for the given chunk. Maybe not needed here?
|
||||
const unsigned char *bufferChunksValidity = NULL;
|
||||
int numReadVectors = 0;
|
||||
|
||||
|
|
@ -5951,9 +5957,9 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite_int64 rowid) {
|
|||
return rc;
|
||||
}
|
||||
|
||||
// 3. zero out rowid in chunks.rowids TODO
|
||||
// 3. zero out rowid in chunks.rowids https://github.com/asg017/sqlite-vec/issues/54
|
||||
|
||||
// 4. zero out any data in vector chunks tables TODO
|
||||
// 4. zero out any data in vector chunks tables https://github.com/asg017/sqlite-vec/issues/54
|
||||
|
||||
// 5. delete from _rowids table
|
||||
rc = vec0Update_Delete_DeleteRowids(p, rowid);
|
||||
|
|
@ -5975,9 +5981,7 @@ int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset,
|
|||
enum VectorElementType elementType;
|
||||
void *vector;
|
||||
vector_cleanup cleanup = vector_cleanup_noop;
|
||||
// TODO: Can't update non f32, bc subtypes are stripped from UPDATEs.
|
||||
// Need to 1) create a less strict vector_from_value, or 2) wait
|
||||
// for this to resolve: https://sqlite.org/forum/forumpost/65317ce9c6
|
||||
// https://github.com/asg017/sqlite-vec/issues/53
|
||||
rc = vector_from_value(valueVector, &vector, &dimensions, &elementType,
|
||||
&cleanup, &pzError);
|
||||
if (rc != SQLITE_OK) {
|
||||
|
|
@ -6449,12 +6453,36 @@ typedef enum {
|
|||
VEC_SBE__QUERYPLAN_KNN = 2
|
||||
} vec_sbe_query_plan;
|
||||
|
||||
struct sbe_query_knn_data {
|
||||
i64 k;
|
||||
i64 k_used;
|
||||
// Array of rowids of size k. Must be freed with sqlite3_free().
|
||||
i32 *rowids;
|
||||
// Array of distances of size k. Must be freed with sqlite3_free().
|
||||
f32 *distances;
|
||||
i64 current_idx;
|
||||
};
|
||||
void sbe_query_knn_data_clear(struct sbe_query_knn_data *knn_data) {
|
||||
if (!knn_data)
|
||||
return;
|
||||
|
||||
if (knn_data->rowids) {
|
||||
sqlite3_free(knn_data->rowids);
|
||||
knn_data->rowids = NULL;
|
||||
}
|
||||
if (knn_data->distances) {
|
||||
sqlite3_free(knn_data->distances);
|
||||
knn_data->distances = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
typedef struct vec_static_blob_entries_cursor vec_static_blob_entries_cursor;
|
||||
struct vec_static_blob_entries_cursor {
|
||||
sqlite3_vtab_cursor base;
|
||||
sqlite3_int64 iRowid;
|
||||
vec_sbe_query_plan query_plan;
|
||||
struct vec0_query_knn_data *knn_data;
|
||||
struct sbe_query_knn_data *knn_data;
|
||||
};
|
||||
|
||||
static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc,
|
||||
|
|
@ -6519,6 +6547,7 @@ static int vec_static_blob_entriesOpen(sqlite3_vtab *p,
|
|||
|
||||
/**
 * @brief xClose: destroys a vec_static_blob_entries cursor.
 *
 * Releases the KNN state (including the rowid and distance arrays it owns)
 * and then the cursor itself.
 *
 * @param cur Cursor to destroy.
 * @return SQLITE_OK always.
 */
static int vec_static_blob_entriesClose(sqlite3_vtab_cursor *cur) {
  vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
  // Freeing only the struct would leak knn_data->rowids and
  // knn_data->distances, so release those first. The helper accepts NULL.
  // NOTE(review): assumes knn_data is zero-initialized when allocated, so
  // unset array pointers are NULL - confirm in xFilter.
  sbe_query_knn_data_clear(pCur->knn_data);
  sqlite3_free(pCur->knn_data);
  sqlite3_free(pCur);
  return SQLITE_OK;
}
|
||||
|
|
@ -6528,7 +6557,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
|
|||
vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVTab;
|
||||
int iMatchTerm = -1;
|
||||
int iLimitTerm = -1;
|
||||
// int iRowidTerm = -1; // TODO point query
|
||||
// int iRowidTerm = -1; // https://github.com/asg017/sqlite-vec/issues/47
|
||||
int iKTerm = -1;
|
||||
|
||||
for (int i = 0; i < pIdxInfo->nConstraint; i++) {
|
||||
|
|
@ -6540,7 +6569,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
|
|||
if (op == SQLITE_INDEX_CONSTRAINT_MATCH &&
|
||||
iColumn == VEC_STATIC_BLOB_ENTRIES_VECTOR) {
|
||||
if (iMatchTerm > -1) {
|
||||
// TODO only 1 match operator at a time
|
||||
// https://github.com/asg017/sqlite-vec/issues/51
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
iMatchTerm = i;
|
||||
|
|
@ -6555,7 +6584,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
|
|||
}
|
||||
if (iMatchTerm >= 0) {
|
||||
if (iLimitTerm < 0 && iKTerm < 0) {
|
||||
// TODO: error, match on vector1 should require a limit for KNN
|
||||
// https://github.com/asg017/sqlite-vec/issues/51
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
if (iLimitTerm >= 0 && iKTerm >= 0) {
|
||||
|
|
@ -6566,7 +6595,7 @@ static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab,
|
|||
return SQLITE_CONSTRAINT;
|
||||
}
|
||||
if (pIdxInfo->nOrderBy > 1) {
|
||||
// TODO error, orderByConsumed is all or nothing, only 1 order by allowed
|
||||
// https://github.com/asg017/sqlite-vec/issues/51
|
||||
vtab_set_error(pVTab, "more than 1 ORDER BY clause provided");
|
||||
return SQLITE_CONSTRAINT;
|
||||
}
|
||||
|
|
@ -6615,8 +6644,9 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
|
|||
(vec_static_blob_entries_vtab *)pCur->base.pVtab;
|
||||
|
||||
if (idxNum == VEC_SBE__QUERYPLAN_KNN) {
|
||||
assert(argc == 2);
|
||||
pCur->query_plan = VEC_SBE__QUERYPLAN_KNN;
|
||||
struct vec0_query_knn_data *knn_data;
|
||||
struct sbe_query_knn_data *knn_data;
|
||||
knn_data = sqlite3_malloc(sizeof(*knn_data));
|
||||
if (!knn_data) {
|
||||
return SQLITE_NOMEM;
|
||||
|
|
@ -6642,7 +6672,7 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
|
|||
|
||||
i64 k = min(sqlite3_value_int64(argv[1]), (i64)p->blob->nvectors);
|
||||
if (k < 0) {
|
||||
// TODO handle
|
||||
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
if (k == 0) {
|
||||
|
|
@ -6651,24 +6681,38 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
|
|||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
i64 *topk_rowids = sqlite3_malloc(k * sizeof(i64));
|
||||
size_t bsize = (p->blob->nvectors + 7) & ~7;
|
||||
|
||||
|
||||
i32 *topk_rowids = sqlite3_malloc(k * sizeof(i32));
|
||||
if (!topk_rowids) {
|
||||
// TODO handle
|
||||
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
f32 *distances = sqlite3_malloc(p->blob->nvectors * sizeof(f32));
|
||||
f32 *distances = sqlite3_malloc(bsize * sizeof(f32));
|
||||
if (!distances) {
|
||||
// TODO handle
|
||||
// HANDLE https://github.com/asg017/sqlite-vec/issues/55
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < p->blob->nvectors; i++) {
|
||||
// https://github.com/asg017/sqlite-vec/issues/52
|
||||
float *v = ((float *)p->blob->p) + (i * p->blob->dimensions);
|
||||
distances[i] =
|
||||
distance_l2_sqr_float(v, (float *)queryVector, &p->blob->dimensions);
|
||||
// TODO other metrics
|
||||
}
|
||||
// min_idx(distances, k, topk_rowids, k); // TODO
|
||||
u8 * candidates = bitmap_new(bsize);
|
||||
assert(candidates);
|
||||
|
||||
u8 * taken = bitmap_new(bsize);
|
||||
assert(taken);
|
||||
|
||||
bitmap_fill(candidates, bsize);
|
||||
for(size_t i = bsize; i >= p->blob->nvectors; i--) {
|
||||
bitmap_set(candidates, i, 0);
|
||||
}
|
||||
i32 k_used = 0;
|
||||
min_idx(distances, bsize, candidates, topk_rowids, k, taken, &k_used);
|
||||
knn_data->current_idx = 0;
|
||||
knn_data->distances = distances;
|
||||
knn_data->k = k;
|
||||
|
|
@ -6686,9 +6730,19 @@ static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor,
|
|||
static int vec_static_blob_entriesRowid(sqlite3_vtab_cursor *cur,
|
||||
sqlite_int64 *pRowid) {
|
||||
vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
|
||||
switch (pCur->query_plan) {
|
||||
case VEC_SBE__QUERYPLAN_FULLSCAN: {
|
||||
*pRowid = pCur->iRowid;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
case VEC_SBE__QUERYPLAN_KNN: {
|
||||
i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx];
|
||||
*pRowid = (sqlite3_int64) rowid;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static int vec_static_blob_entriesNext(sqlite3_vtab_cursor *cur) {
|
||||
vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur;
|
||||
|
|
@ -6732,7 +6786,7 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur,
|
|||
context,
|
||||
((unsigned char *)p->blob->p) +
|
||||
(pCur->iRowid * p->blob->dimensions * sizeof(float)),
|
||||
p->blob->dimensions * sizeof(float), SQLITE_STATIC);
|
||||
p->blob->dimensions * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_result_subtype(context, p->blob->element_type);
|
||||
break;
|
||||
}
|
||||
|
|
@ -6742,11 +6796,10 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur,
|
|||
switch (i) {
|
||||
case VEC_STATIC_BLOB_ENTRIES_VECTOR: {
|
||||
i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx];
|
||||
|
||||
sqlite3_result_blob(context,
|
||||
((unsigned char *)p->blob->p) +
|
||||
(rowid * p->blob->dimensions * sizeof(float)),
|
||||
p->blob->dimensions * sizeof(float), SQLITE_STATIC);
|
||||
p->blob->dimensions * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_result_subtype(context, p->blob->element_type);
|
||||
break;
|
||||
}
|
||||
|
|
@ -6758,7 +6811,7 @@ static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur,
|
|||
|
||||
static sqlite3_module vec_static_blob_entriesModule = {
|
||||
/* iVersion */ 3,
|
||||
/* xCreate */ vec_static_blob_entriesCreate, // TODO rm?
|
||||
/* xCreate */ vec_static_blob_entriesCreate, // handle rm? https://github.com/asg017/sqlite-vec/issues/55
|
||||
/* xConnect */ vec_static_blob_entriesConnect,
|
||||
/* xBestIndex */ vec_static_blob_entriesBestIndex,
|
||||
/* xDisconnect */ vec_static_blob_entriesDisconnect,
|
||||
|
|
@ -6787,91 +6840,6 @@ static sqlite3_module vec_static_blob_entriesModule = {
|
|||
};
|
||||
#pragma endregion
|
||||
|
||||
/**
 * @brief Touches every memory-mapped page of a database file.
 *
 * Opens a transaction on the database, looks up its page size, then fetches
 * each page through the VFS xFetch method and reads its first and last byte.
 * Any transaction this function opened is closed before it returns.
 *
 * @param db  Database connection; must be in autocommit mode.
 * @param zDb Schema name of the database to warm, or NULL to use unqualified
 *            names.
 * @return SQLITE_OK on success; SQLITE_MISUSE if the connection is not in
 *         autocommit mode; SQLITE_NOMEM if a statement could not be built;
 *         SQLITE_ERROR if no page size could be read; otherwise the first
 *         SQLite error encountered.
 */
int sqlite3_mmap_warm(sqlite3 *db, const char *zDb) {
  int rc = SQLITE_OK;
  char *zSql = 0;
  int pgsz = 0;
  unsigned int nTotal = 0;

  if (0 == sqlite3_get_autocommit(db))
    return SQLITE_MISUSE;

  /* Open a read-only transaction on the file in question */
  zSql = sqlite3_mprintf("BEGIN; SELECT * FROM %s%q%ssqlite_schema",
                         (zDb ? "'" : ""), (zDb ? zDb : ""), (zDb ? "'." : ""));
  if (zSql == 0)
    return SQLITE_NOMEM;
  rc = sqlite3_exec(db, zSql, 0, 0, 0);
  sqlite3_free(zSql);

  /* Find the SQLite page size of the file */
  if (rc == SQLITE_OK) {
    zSql = sqlite3_mprintf("PRAGMA %s%q%spage_size", (zDb ? "'" : ""),
                           (zDb ? zDb : ""), (zDb ? "'." : ""));
    if (zSql == 0) {
      rc = SQLITE_NOMEM;
    } else {
      sqlite3_stmt *pPgsz = 0;
      rc = sqlite3_prepare_v2(db, zSql, -1, &pPgsz, 0);
      sqlite3_free(zSql);
      if (rc == SQLITE_OK) {
        if (sqlite3_step(pPgsz) == SQLITE_ROW) {
          pgsz = sqlite3_column_int(pPgsz, 0);
        }
        rc = sqlite3_finalize(pPgsz);
      }
      if (rc == SQLITE_OK && pgsz == 0) {
        rc = SQLITE_ERROR;
      }
    }
  }

  /* Touch each mmap'd page of the file */
  if (rc == SQLITE_OK) {
    sqlite3_file *pFd = 0;
    rc = sqlite3_file_control(db, zDb, SQLITE_FCNTL_FILE_POINTER, &pFd);
    if (rc == SQLITE_OK && pFd->pMethods && pFd->pMethods->iVersion >= 3) {
      i64 iPg = 1;
      sqlite3_io_methods const *p = pFd->pMethods;
      while (1) {
        unsigned char *pMap;
        rc = p->xFetch(pFd, pgsz * iPg, pgsz, (void **)&pMap);
        if (rc != SQLITE_OK || pMap == 0)
          break;

        /* Reading the first and last byte touches the mapped page. */
        nTotal += (unsigned int)pMap[0];
        nTotal += (unsigned int)pMap[pgsz - 1];

        rc = p->xUnfetch(pFd, pgsz * iPg, (void *)pMap);
        if (rc != SQLITE_OK)
          break;
        iPg++;
      }
      /* iPg is 64-bit, so it must be formatted with %lld; %d would read it
      ** as an int and misalign the %s argument that follows. */
      sqlite3_log(SQLITE_OK,
                  "sqlite3_mmap_warm_cache: Warmed up %lld pages of %s",
                  (long long)(iPg == 1 ? 0 : iPg),
                  sqlite3_db_filename(db, zDb));
    }
  }

  /* The connection was in autocommit mode on entry, so any transaction still
  ** open here was started by the BEGIN above. Close it on every path, not
  ** only on success, so an early failure does not leave it open. */
  if (0 == sqlite3_get_autocommit(db)) {
    int rc2 = sqlite3_exec(db, "END", 0, 0, 0);
    if (rc == SQLITE_OK)
      rc = rc2;
  }

  (void)nTotal;
  return rc;
}
|
||||
|
||||
#ifdef _WIN32
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
int sqlite3_vec_warm_mmap(sqlite3 *db, char **pzErrMsg,
|
||||
const sqlite3_api_routines *pApi) {
|
||||
UNUSED_PARAMETER(pzErrMsg);
|
||||
SQLITE_EXTENSION_INIT2(pApi);
|
||||
return sqlite3_mmap_warm(db, NULL);
|
||||
}
|
||||
|
||||
#ifdef SQLITE_VEC_ENABLE_AVX
|
||||
#define SQLITE_VEC_DEBUG_BUILD_AVX "avx"
|
||||
#else
|
||||
|
|
@ -6907,12 +6875,6 @@ __declspec(dllexport)
|
|||
const sqlite3_api_routines *pApi) {
|
||||
SQLITE_EXTENSION_INIT2(pApi);
|
||||
int rc = SQLITE_OK;
|
||||
vec_static_blob_data *static_blob_data;
|
||||
static_blob_data = sqlite3_malloc(sizeof(*static_blob_data));
|
||||
if (!static_blob_data) {
|
||||
return SQLITE_NOMEM;
|
||||
}
|
||||
memset(static_blob_data, 0, sizeof(*static_blob_data));
|
||||
|
||||
#define DEFAULT_FLAGS (SQLITE_UTF8 | SQLITE_INNOCUOUS | SQLITE_DETERMINISTIC)
|
||||
|
||||
|
|
@ -6953,7 +6915,6 @@ __declspec(dllexport)
|
|||
{"vec_int8", vec_int8, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
|
||||
{"vec_quantize_int8", vec_quantize_int8, 2, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
|
||||
{"vec_quantize_binary", vec_quantize_binary, 1, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, },
|
||||
{"vec_static_blob_from_raw", vec_static_blob_from_raw, 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE },
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
|
|
@ -6989,13 +6950,7 @@ __declspec(dllexport)
|
|||
return rc;
|
||||
}
|
||||
}
|
||||
rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule,
|
||||
static_blob_data, sqlite3_free);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_create_module_v2(db, "vec_static_blob_entries",
|
||||
&vec_static_blob_entriesModule,
|
||||
static_blob_data, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
|
|
@ -7014,23 +6969,33 @@ __declspec(dllexport)
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef SQLITE_VEC_ENABLE_TRACE_ENTRYPOINT
|
||||
|
||||
int trace(unsigned int x, void *p1, void *p2, void *p3) {
|
||||
if (x == SQLITE_TRACE_STMT) {
|
||||
sqlite3_stmt *stmt = (sqlite3_stmt *)p2;
|
||||
char *zSql = sqlite3_expanded_sql(stmt);
|
||||
printf("%s\n", zSql);
|
||||
}
|
||||
}
|
||||
#ifdef _WIN32
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
int trace_debug(sqlite3 *db, char **pzErrMsg,
|
||||
int sqlite3_vec_static_blobs_init(sqlite3 *db, char **pzErrMsg,
|
||||
const sqlite3_api_routines *pApi) {
|
||||
UNUSED_PARAMETER(pzErrMsg);
|
||||
SQLITE_EXTENSION_INIT2(pApi);
|
||||
sqlite3_trace_v2(db, SQLITE_TRACE_STMT, trace, NULL);
|
||||
return SQLITE_OK;
|
||||
|
||||
int rc = SQLITE_OK;
|
||||
vec_static_blob_data *static_blob_data;
|
||||
static_blob_data = sqlite3_malloc(sizeof(*static_blob_data));
|
||||
if (!static_blob_data) {
|
||||
return SQLITE_NOMEM;
|
||||
}
|
||||
memset(static_blob_data, 0, sizeof(*static_blob_data));
|
||||
|
||||
rc = sqlite3_create_function_v2(db, "vec_static_blob_from_raw", 4, DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE,
|
||||
NULL, vec_static_blob_from_raw, NULL, NULL, NULL);
|
||||
if(rc != SQLITE_OK) return rc;
|
||||
|
||||
rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule,
|
||||
static_blob_data, sqlite3_free);
|
||||
if(rc != SQLITE_OK) return rc;
|
||||
rc = sqlite3_create_module_v2(db, "vec_static_blob_entries",
|
||||
&vec_static_blob_entriesModule,
|
||||
static_blob_data, NULL);
|
||||
if(rc != SQLITE_OK) return rc;
|
||||
return rc;
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -80,8 +80,6 @@ def connect(ext, path=":memory:", extra_entrypoint=None):
|
|||
|
||||
db = connect(EXT_PATH)
|
||||
|
||||
# db.load_extension(EXT_PATH, entrypoint="trace_debug")
|
||||
|
||||
|
||||
def explain_query_plan(sql):
|
||||
return db.execute("explain query plan " + sql).fetchone()["detail"]
|
||||
|
|
@ -113,7 +111,6 @@ FUNCTIONS = [
|
|||
"vec_quantize_binary",
|
||||
"vec_quantize_int8",
|
||||
"vec_slice",
|
||||
"vec_static_blob_from_raw",
|
||||
"vec_sub",
|
||||
"vec_to_json",
|
||||
"vec_type",
|
||||
|
|
@ -123,24 +120,150 @@ MODULES = [
|
|||
"vec0",
|
||||
"vec_each",
|
||||
"vec_npy_each",
|
||||
"vec_static_blob_entries",
|
||||
"vec_static_blobs",
|
||||
#"vec_static_blob_entries",
|
||||
#"vec_static_blobs",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO")
def test_vec_static_blob_from_raw():
    """Placeholder: no dedicated test for vec_static_blob_from_raw() yet."""
|
||||
|
||||
def register_numpy(db, name: str, array):
    """Expose an array as a vec_static_blob_entries virtual table.

    Reads the data pointer, shape and dtype string from the array's
    ``__array_interface__``, inserts a row into ``temp.vec_static_blobs``
    built with ``vec_static_blob_from_raw``, then creates a virtual table
    named ``name`` on top of it.

    Args:
        db: Connection-like object with an ``execute`` method.
        name: Name for both the static blob row and the virtual table.
        array: Object exposing ``__array_interface__`` with a two-element
            shape, read as (nvectors, dimensions), and typestr ``"<f4"``.

    Raises:
        AssertionError: If the array's typestr is not ``"<f4"``.
    """
    interface = array.__array_interface__
    ptr = interface["data"][0]
    nvectors, dimensions = interface["shape"]
    element_type = interface["typestr"]

    assert element_type == "<f4"

    # printf('%w') doubles any embedded double quotes, so the name can be
    # placed inside a double-quoted SQL identifier.
    escaped_row = db.execute("select printf('%w', ?)", [name]).fetchone()
    name_escaped = escaped_row[0]

    insert_sql = (
        "\n"
        "insert into temp.vec_static_blobs(name, data)\n"
        "select ?, vec_static_blob_from_raw(?, ?, ?, ?)\n"
    )
    insert_params = [name, ptr, element_type, dimensions, nvectors]
    db.execute(insert_sql, insert_params)

    create_sql = f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
    db.execute(create_sql)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO")
def test_vec_static_blobs():
    """Placeholder: no dedicated test for the vec_static_blobs module yet."""
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO")
|
||||
def test_vec_static_blob_entries():
|
||||
pass
|
||||
db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_static_blobs_init")
|
||||
|
||||
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)
|
||||
y = np.array([[0.2, 0.3], [0.9, 0.8], [0.6, 0.5]], dtype=np.float32)
|
||||
z = np.array(
|
||||
[
|
||||
[0.1, 0.1, 0.1, 0.1],
|
||||
[0.2, 0.2, 0.2, 0.2],
|
||||
[0.3, 0.3, 0.3, 0.3],
|
||||
[0.4, 0.4, 0.4, 0.4],
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
register_numpy(db, "x", x)
|
||||
register_numpy(db, "y", y)
|
||||
register_numpy(db, "z", z)
|
||||
assert execute_all(
|
||||
db, "select *, dimensions, count from temp.vec_static_blobs;"
|
||||
) == [
|
||||
{
|
||||
"count": 2,
|
||||
"data": None,
|
||||
"dimensions": 4,
|
||||
"name": "x",
|
||||
},
|
||||
{
|
||||
"count": 3,
|
||||
"data": None,
|
||||
"dimensions": 2,
|
||||
"name": "y",
|
||||
},
|
||||
{
|
||||
"count": 5,
|
||||
"data": None,
|
||||
"dimensions": 4,
|
||||
"name": "z",
|
||||
},
|
||||
]
|
||||
|
||||
assert execute_all(db, "select vec_to_json(vector) from x;") == [
|
||||
{
|
||||
"vec_to_json(vector)": "[0.100000,0.200000,0.300000,0.400000]",
|
||||
},
|
||||
{
|
||||
"vec_to_json(vector)": "[0.900000,0.800000,0.700000,0.600000]",
|
||||
},
|
||||
]
|
||||
assert execute_all(db, "select (vector) from y limit 2;") == [
|
||||
{
|
||||
"vector": b"\xcd\xccL>\x9a\x99\x99>",
|
||||
},
|
||||
{
|
||||
"vector": b"fff?\xcd\xccL?",
|
||||
},
|
||||
]
|
||||
assert execute_all(db, "select rowid, (vector) from z") == [
|
||||
{
|
||||
"rowid": 0,
|
||||
"vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=",
|
||||
},
|
||||
{
|
||||
"rowid": 1,
|
||||
"vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>",
|
||||
},
|
||||
{
|
||||
"rowid": 2,
|
||||
"vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>",
|
||||
},
|
||||
{
|
||||
"rowid": 3,
|
||||
"vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>",
|
||||
},
|
||||
{
|
||||
"rowid": 4,
|
||||
"vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?",
|
||||
},
|
||||
]
|
||||
assert execute_all(
|
||||
db,
|
||||
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
|
||||
[np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)],
|
||||
) == [
|
||||
{
|
||||
"rowid": 2,
|
||||
"v": "[0.300000,0.300000,0.300000,0.300000]",
|
||||
},
|
||||
{
|
||||
"rowid": 3,
|
||||
"v": "[0.400000,0.400000,0.400000,0.400000]",
|
||||
},
|
||||
{
|
||||
"rowid": 1,
|
||||
"v": "[0.200000,0.200000,0.200000,0.200000]",
|
||||
},
|
||||
]
|
||||
assert execute_all(
|
||||
db,
|
||||
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
|
||||
[np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)],
|
||||
) == [
|
||||
{
|
||||
"rowid": 4,
|
||||
"v": "[0.500000,0.500000,0.500000,0.500000]",
|
||||
},
|
||||
{
|
||||
"rowid": 3,
|
||||
"v": "[0.400000,0.400000,0.400000,0.400000]",
|
||||
},
|
||||
{
|
||||
"rowid": 2,
|
||||
"v": "[0.300000,0.300000,0.300000,0.300000]",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_funcs():
|
||||
|
|
@ -1872,7 +1995,7 @@ def test_vec0_knn():
|
|||
db.execute("select * from v where aaa match vec_bit(X'AA') and k = 10")
|
||||
|
||||
with _raises(
|
||||
'Dimension mismatch for inserted vector for the "aaa" column. Expected 8 dimensions but received 1.'
|
||||
'Dimension mismatch for query vector for the "aaa" column. Expected 8 dimensions but received 1.'
|
||||
):
|
||||
db.execute("select * from v where aaa match vec_f32('[.1]') and k = 10")
|
||||
|
||||
|
|
@ -2120,36 +2243,42 @@ def test_smoke():
|
|||
|
||||
db.execute("insert into vec_xyz(rowid, a) select 2, X'0000000000000040'")
|
||||
chunk = db.execute("select * from vec_xyz_chunks").fetchone()
|
||||
assert chunk[
|
||||
"rowids"
|
||||
] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + bytearray(
|
||||
int(1024 * 8) - 8 * 2
|
||||
assert (
|
||||
chunk["rowids"]
|
||||
== b"\x01\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ b"\x02\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ bytearray(int(1024 * 8) - 8 * 2)
|
||||
)
|
||||
assert chunk["chunk_id"] == 1
|
||||
assert chunk["validity"] == b"\x03" + bytearray(int(1024 / 8) - 1)
|
||||
vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone()
|
||||
assert vchunk["rowid"] == 1
|
||||
assert vchunk[
|
||||
"vectors"
|
||||
] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + bytearray(
|
||||
int(1024 * 4 * 2) - (2 * 4 * 2)
|
||||
assert (
|
||||
vchunk["vectors"]
|
||||
== b"\x00\x00\x00\x00\x00\x00\x80\x3f"
|
||||
+ b"\x00\x00\x00\x00\x00\x00\x00\x40"
|
||||
+ bytearray(int(1024 * 4 * 2) - (2 * 4 * 2))
|
||||
)
|
||||
|
||||
db.execute("insert into vec_xyz(rowid, a) select 3, X'00000000000080bf'")
|
||||
chunk = db.execute("select * from vec_xyz_chunks").fetchone()
|
||||
assert chunk["chunk_id"] == 1
|
||||
assert chunk["validity"] == b"\x07" + bytearray(int(1024 / 8) - 1)
|
||||
assert chunk[
|
||||
"rowids"
|
||||
] == b"\x01\x00\x00\x00\x00\x00\x00\x00" + b"\x02\x00\x00\x00\x00\x00\x00\x00" + b"\x03\x00\x00\x00\x00\x00\x00\x00" + bytearray(
|
||||
int(1024 * 8) - 8 * 3
|
||||
assert (
|
||||
chunk["rowids"]
|
||||
== b"\x01\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ b"\x02\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ b"\x03\x00\x00\x00\x00\x00\x00\x00"
|
||||
+ bytearray(int(1024 * 8) - 8 * 3)
|
||||
)
|
||||
vchunk = db.execute("select * from vec_xyz_vector_chunks00").fetchone()
|
||||
assert vchunk["rowid"] == 1
|
||||
assert vchunk[
|
||||
"vectors"
|
||||
] == b"\x00\x00\x00\x00\x00\x00\x80\x3f" + b"\x00\x00\x00\x00\x00\x00\x00\x40" + b"\x00\x00\x00\x00\x00\x00\x80\xbf" + bytearray(
|
||||
int(1024 * 4 * 2) - (2 * 4 * 3)
|
||||
assert (
|
||||
vchunk["vectors"]
|
||||
== b"\x00\x00\x00\x00\x00\x00\x80\x3f"
|
||||
+ b"\x00\x00\x00\x00\x00\x00\x00\x40"
|
||||
+ b"\x00\x00\x00\x00\x00\x00\x80\xbf"
|
||||
+ bytearray(int(1024 * 4 * 2) - (2 * 4 * 3))
|
||||
)
|
||||
|
||||
# db.execute("select * from vec_xyz")
|
||||
|
|
@ -2192,8 +2321,7 @@ def test_vec0_stress_small_chunks():
|
|||
{"rowid": 994, "a": _f32([99.4] * 8)},
|
||||
{"rowid": 993, "a": _f32([99.3] * 8)},
|
||||
]
|
||||
assert (
|
||||
execute_all(
|
||||
assert execute_all(
|
||||
db,
|
||||
"""
|
||||
select rowid, a, distance
|
||||
|
|
@ -2203,8 +2331,7 @@ def test_vec0_stress_small_chunks():
|
|||
order by distance
|
||||
""",
|
||||
[_f32([50.0] * 8)],
|
||||
)
|
||||
== [
|
||||
) == [
|
||||
{
|
||||
"a": _f32([500 * 0.1] * 8),
|
||||
"distance": 0.0,
|
||||
|
|
@ -2251,7 +2378,6 @@ def test_vec0_stress_small_chunks():
|
|||
"rowid": 504,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_vec0_distance_metric():
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ db = sqlite3.connect(":memory:")
|
|||
|
||||
db.enable_load_extension(True)
|
||||
db.load_extension("./dist/vec0")
|
||||
db.execute("select load_extension('./dist/vec0', 'sqlite3_vec_raw_init')")
|
||||
db.enable_load_extension(False)
|
||||
|
||||
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue