From ba0db0b6d6d644bf1d5eedb74a378b4683d91353 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Sun, 29 Mar 2026 19:45:54 -0700 Subject: [PATCH] Add rescore index for ANN queries Add rescore index type: stores full-precision float vectors in a rowid-keyed shadow table, quantizes to int8 for fast initial scan, then rescores top candidates with original vectors. Includes config parser, shadow table management, insert/delete support, KNN integration, compile flag (SQLITE_VEC_ENABLE_RESCORE), fuzz targets, and tests. --- Makefile | 2 +- benchmarks-ann/Makefile | 13 +- benchmarks-ann/bench.py | 33 ++ sqlite-vec-rescore.c | 662 ++++++++++++++++++++++++++++ sqlite-vec.c | 435 +++++++++++++++++- tests/fuzz/.gitignore | 5 + tests/fuzz/Makefile | 26 +- tests/fuzz/rescore-create.c | 36 ++ tests/fuzz/rescore-create.dict | 20 + tests/fuzz/rescore-interleave.c | 151 +++++++ tests/fuzz/rescore-knn-deep.c | 178 ++++++++ tests/fuzz/rescore-operations.c | 96 ++++ tests/fuzz/rescore-quantize-edge.c | 177 ++++++++ tests/fuzz/rescore-quantize.c | 54 +++ tests/fuzz/rescore-shadow-corrupt.c | 230 ++++++++++ tests/sqlite-vec-internal.h | 25 ++ tests/test-rescore-mutations.py | 470 ++++++++++++++++++++ tests/test-rescore.py | 568 ++++++++++++++++++++++++ tests/test-unit.c | 205 +++++++++ 19 files changed, 3378 insertions(+), 8 deletions(-) create mode 100644 sqlite-vec-rescore.c create mode 100644 tests/fuzz/rescore-create.c create mode 100644 tests/fuzz/rescore-create.dict create mode 100644 tests/fuzz/rescore-interleave.c create mode 100644 tests/fuzz/rescore-knn-deep.c create mode 100644 tests/fuzz/rescore-operations.c create mode 100644 tests/fuzz/rescore-quantize-edge.c create mode 100644 tests/fuzz/rescore-quantize.c create mode 100644 tests/fuzz/rescore-shadow-corrupt.c create mode 100644 tests/test-rescore-mutations.py create mode 100644 tests/test-rescore.py diff --git a/Makefile b/Makefile index 051590e..b604171 100644 --- a/Makefile +++ b/Makefile @@ -202,7 +202,7 @@ test-loadable-watch: watchexec --exts c,py,Makefile --clear -- make test-loadable test-unit: - $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit + $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit fuzz-build: $(MAKE) -C tests/fuzz all diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile index 59e2dcd..762abea 100644 --- a/benchmarks-ann/Makefile +++ b/benchmarks-ann/Makefile @@ -21,9 +21,14 @@ BASELINES = \ # ANNOY_CONFIGS = \ # "annoy-t50:type=annoy,n_trees=50" -ALL_CONFIGS = $(BASELINES) +RESCORE_CONFIGS = \ + "rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" \ + "rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \ + "rescore-int8-os8:type=rescore,quantizer=int8,oversample=8" -.PHONY: seed ground-truth bench-smoke bench-10k bench-50k bench-100k bench-all \ +ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) + +.PHONY: seed ground-truth bench-smoke bench-rescore bench-10k bench-50k bench-100k bench-all \ report clean # --- Data preparation --- @@ -40,6 +45,10 @@ bench-smoke: seed $(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \ $(BASELINES) +bench-rescore: seed + $(BENCH) --subset-size 10000 -k 10 -o runs/rescore \ + $(RESCORE_CONFIGS) + # --- Standard sizes --- bench-10k: seed $(BENCH) --subset-size 10000 -k 10 -o runs/10k $(ALL_CONFIGS) diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py index 93f8f82..c1179d6 100644 --- a/benchmarks-ann/bench.py +++ b/benchmarks-ann/bench.py @@ -140,6 +140,39 @@ INDEX_REGISTRY["baseline"] = { } +# ============================================================================ +# Rescore implementation +# ============================================================================ + + +def _rescore_create_table_sql(params): + quantizer = params.get("quantizer", "bit") + oversample = params.get("oversample", 8) + return ( + f"CREATE VIRTUAL TABLE vec_items USING vec0(" + f" chunk_size=256," + f" id integer primary key," + f" embedding float[768] distance_metric=cosine" + f" indexed by rescore(quantizer={quantizer}, oversample={oversample}))" + ) + + +def _rescore_describe(params): + q = params.get("quantizer", "bit") + os = params.get("oversample", 8) + return f"rescore {q} (os={os})" + + +INDEX_REGISTRY["rescore"] = { + "defaults": {"quantizer": "bit", "oversample": 8}, + "create_table_sql": _rescore_create_table_sql, + "insert_sql": None, + "post_insert_hook": None, + "run_query": None, # default MATCH query works — rescore is automatic + "describe": _rescore_describe, +} + + # ============================================================================ # Config parsing # ============================================================================ diff --git a/sqlite-vec-rescore.c b/sqlite-vec-rescore.c new file mode 100644 index 0000000..a45f52f --- /dev/null +++ b/sqlite-vec-rescore.c @@ -0,0 +1,662 @@ +/** + * sqlite-vec-rescore.c — Rescore index logic for sqlite-vec. + * + * This file is #included into sqlite-vec.c after the vec0_vtab definition. + * All functions receive a vec0_vtab *p and access p->vector_columns[i].rescore. + * + * Shadow tables per rescore-enabled vector column: + * _rescore_chunks{NN} — quantized vectors in chunk layout (for coarse scan) + * _rescore_vectors{NN} — float vectors keyed by rowid (for fast rescore lookup) + */ + +// ============================================================================ +// Shadow table lifecycle +// ============================================================================ + +static int rescore_create_tables(vec0_vtab *p, sqlite3 *db, char **pzErr) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE) + continue; + + // Quantized chunk table (same layout as _vector_chunks) + char *zSql = sqlite3_mprintf( + "CREATE TABLE \"%w\".\"%w_rescore_chunks%02d\"" + "(rowid PRIMARY KEY, vectors BLOB NOT NULL)", + p->schemaName, p->tableName, i); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + int rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + *pzErr = sqlite3_mprintf( + "Could not create '_rescore_chunks%02d' shadow table: %s", i, + sqlite3_errmsg(db)); + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); + + // Float vector table (rowid-keyed for fast random access) + zSql = sqlite3_mprintf( + "CREATE TABLE \"%w\".\"%w_rescore_vectors%02d\"" + "(rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL)", + p->schemaName, p->tableName, i); + if (!zSql) + return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + *pzErr = sqlite3_mprintf( + "Could not create '_rescore_vectors%02d' shadow table: %s", i, + sqlite3_errmsg(db)); + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); + } + return SQLITE_OK; +} + +static int rescore_drop_tables(vec0_vtab *p) { + for (int i = 0; i < p->numVectorColumns; i++) { + sqlite3_stmt *stmt; + int rc; + char *zSql; + + if (p->shadowRescoreChunksNames[i]) { + zSql = sqlite3_mprintf("DROP TABLE IF EXISTS \"%w\".\"%w\"", + p->schemaName, p->shadowRescoreChunksNames[i]); + if (!zSql) + return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); + } + + if (p->shadowRescoreVectorsNames[i]) { + zSql = sqlite3_mprintf("DROP TABLE IF EXISTS \"%w\".\"%w\"", + p->schemaName, p->shadowRescoreVectorsNames[i]); + if (!zSql) + return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); + } + } + return SQLITE_OK; +} + +static size_t rescore_quantized_byte_size(struct VectorColumnDefinition *col) { + switch (col->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: + return col->dimensions / CHAR_BIT; + case VEC0_RESCORE_QUANTIZER_INT8: + return col->dimensions; + default: + return 0; + } +} + +/** + * Insert a new chunk row into each _rescore_chunks{NN} table with a zeroblob. + */ +static int rescore_new_chunk(vec0_vtab *p, i64 chunk_rowid) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE) + continue; + size_t quantized_size = + rescore_quantized_byte_size(&p->vector_columns[i]); + i64 blob_size = (i64)p->chunk_size * (i64)quantized_size; + + char *zSql = sqlite3_mprintf( + "INSERT INTO \"%w\".\"%w\"(_rowid_, rowid, vectors) VALUES (?, ?, ?)", + p->schemaName, p->shadowRescoreChunksNames[i]); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + sqlite3_finalize(stmt); + return rc; + } + sqlite3_bind_int64(stmt, 1, chunk_rowid); + sqlite3_bind_int64(stmt, 2, chunk_rowid); + sqlite3_bind_zeroblob64(stmt, 3, blob_size); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) + return rc; + } + return SQLITE_OK; +} + +// ============================================================================ +// Quantization +// ============================================================================ + +static void rescore_quantize_float_to_bit(const float *src, uint8_t *dst, + size_t dimensions) { + memset(dst, 0, dimensions / CHAR_BIT); + for (size_t i = 0; i < dimensions; i++) { + if (src[i] >= 0.0f) { + dst[i / CHAR_BIT] |= (1 << (i % CHAR_BIT)); + } + } +} + +static void rescore_quantize_float_to_int8(const float *src, int8_t *dst, + size_t dimensions) { + float vmin = src[0], vmax = src[0]; + for (size_t i = 1; i < dimensions; i++) { + if (src[i] < vmin) vmin = src[i]; + if (src[i] > vmax) vmax = src[i]; + } + float range = vmax - vmin; + if (range == 0.0f) { + memset(dst, 0, dimensions); + return; + } + float scale = 255.0f / range; + for (size_t i = 0; i < dimensions; i++) { + float v = (src[i] - vmin) * scale - 128.0f; + if (v < -128.0f) v = -128.0f; + if (v > 127.0f) v = 127.0f; + dst[i] = (int8_t)v; + } +} + +// ============================================================================ +// Insert path +// ============================================================================ + +/** + * Quantize float vector to _rescore_chunks and store in _rescore_vectors. + */ +static int rescore_on_insert(vec0_vtab *p, i64 chunk_rowid, i64 chunk_offset, + i64 rowid, void *vectorDatas[]) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE) + continue; + + struct VectorColumnDefinition *col = &p->vector_columns[i]; + size_t qsize = rescore_quantized_byte_size(col); + size_t fsize = vector_column_byte_size(*col); + int rc; + + // 1. Write quantized vector to _rescore_chunks blob + { + void *qbuf = sqlite3_malloc(qsize); + if (!qbuf) + return SQLITE_NOMEM; + + switch (col->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: + rescore_quantize_float_to_bit((const float *)vectorDatas[i], + (uint8_t *)qbuf, col->dimensions); + break; + case VEC0_RESCORE_QUANTIZER_INT8: + rescore_quantize_float_to_int8((const float *)vectorDatas[i], + (int8_t *)qbuf, col->dimensions); + break; + } + + sqlite3_blob *blob = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreChunksNames[i], "vectors", + chunk_rowid, 1, &blob); + if (rc != SQLITE_OK) { + sqlite3_free(qbuf); + return rc; + } + rc = sqlite3_blob_write(blob, qbuf, qsize, chunk_offset * qsize); + sqlite3_free(qbuf); + int brc = sqlite3_blob_close(blob); + if (rc != SQLITE_OK) + return rc; + if (brc != SQLITE_OK) + return brc; + } + + // 2. Insert float vector into _rescore_vectors (rowid-keyed) + { + char *zSql = sqlite3_mprintf( + "INSERT INTO \"%w\".\"%w\"(rowid, vector) VALUES (?, ?)", + p->schemaName, p->shadowRescoreVectorsNames[i]); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + sqlite3_finalize(stmt); + return rc; + } + sqlite3_bind_int64(stmt, 1, rowid); + sqlite3_bind_blob(stmt, 2, vectorDatas[i], fsize, SQLITE_TRANSIENT); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) + return SQLITE_ERROR; + } + } + return SQLITE_OK; +} + +// ============================================================================ +// Delete path +// ============================================================================ + +/** + * Zero out quantized vector in _rescore_chunks and delete from _rescore_vectors. + */ +static int rescore_on_delete(vec0_vtab *p, i64 chunk_id, u64 chunk_offset, + i64 rowid) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE) + continue; + int rc; + + // 1. Zero out quantized data in _rescore_chunks + { + size_t qsize = rescore_quantized_byte_size(&p->vector_columns[i]); + void *zeroBuf = sqlite3_malloc(qsize); + if (!zeroBuf) + return SQLITE_NOMEM; + memset(zeroBuf, 0, qsize); + + sqlite3_blob *blob = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreChunksNames[i], "vectors", + chunk_id, 1, &blob); + if (rc != SQLITE_OK) { + sqlite3_free(zeroBuf); + return rc; + } + rc = sqlite3_blob_write(blob, zeroBuf, qsize, chunk_offset * qsize); + sqlite3_free(zeroBuf); + int brc = sqlite3_blob_close(blob); + if (rc != SQLITE_OK) + return rc; + if (brc != SQLITE_OK) + return brc; + } + + // 2. Delete from _rescore_vectors + { + char *zSql = sqlite3_mprintf( + "DELETE FROM \"%w\".\"%w\" WHERE rowid = ?", + p->schemaName, p->shadowRescoreVectorsNames[i]); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) + return rc; + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) + return SQLITE_ERROR; + } + } + return SQLITE_OK; +} + +/** + * Delete a chunk row from _rescore_chunks{NN} tables. + * (_rescore_vectors rows were already deleted per-row in rescore_on_delete) + */ +static int rescore_delete_chunk(vec0_vtab *p, i64 chunk_id) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (!p->shadowRescoreChunksNames[i]) + continue; + char *zSql = sqlite3_mprintf( + "DELETE FROM \"%w\".\"%w\" WHERE rowid = ?", + p->schemaName, p->shadowRescoreChunksNames[i]); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) + return rc; + sqlite3_bind_int64(stmt, 1, chunk_id); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) + return SQLITE_ERROR; + } + return SQLITE_OK; +} + +// ============================================================================ +// KNN rescore query +// ============================================================================ + +/** + * Phase 1: Coarse scan of quantized chunks → top k*oversample candidates (rowids). + * Phase 2: For each candidate, blob_open _rescore_vectors by rowid, read float + * vector, compute float distance. Sort, return top k. + * + * Phase 2 is fast because _rescore_vectors has INTEGER PRIMARY KEY, so + * sqlite3_blob_open/reopen addresses rows directly by rowid — no index lookup. + */ +static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur, + struct VectorColumnDefinition *vector_column, + int vectorColumnIdx, struct Array *arrayRowidsIn, + struct Array *aMetadataIn, const char *idxStr, int argc, + sqlite3_value **argv, void *queryVector, i64 k, + struct vec0_query_knn_data *knn_data) { + (void)pCur; + (void)aMetadataIn; + int rc = SQLITE_OK; + int oversample = vector_column->rescore.oversample; + i64 k_oversample = k * oversample; + if (k_oversample > 4096) + k_oversample = 4096; + + size_t qdim = vector_column->dimensions; + size_t qsize = rescore_quantized_byte_size(vector_column); + size_t fsize = vector_column_byte_size(*vector_column); + + // Quantize the query vector + void *quantizedQuery = sqlite3_malloc(qsize); + if (!quantizedQuery) + return SQLITE_NOMEM; + + switch (vector_column->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: + rescore_quantize_float_to_bit((const float *)queryVector, + (uint8_t *)quantizedQuery, qdim); + break; + case VEC0_RESCORE_QUANTIZER_INT8: + rescore_quantize_float_to_int8((const float *)queryVector, + (int8_t *)quantizedQuery, qdim); + break; + } + + // Phase 1: Scan quantized chunks for k*oversample candidates + sqlite3_stmt *stmtChunks = NULL; + rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks); + if (rc != SQLITE_OK) { + sqlite3_free(quantizedQuery); + return rc; + } + + i64 *cand_rowids = sqlite3_malloc(k_oversample * sizeof(i64)); + f32 *cand_distances = sqlite3_malloc(k_oversample * sizeof(f32)); + i64 *tmp_rowids = sqlite3_malloc(k_oversample * sizeof(i64)); + f32 *tmp_distances = sqlite3_malloc(k_oversample * sizeof(f32)); + f32 *chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(f32)); + i32 *chunk_topk_idxs = sqlite3_malloc(k_oversample * sizeof(i32)); + u8 *b = sqlite3_malloc(p->chunk_size / CHAR_BIT); + u8 *bTaken = sqlite3_malloc(p->chunk_size / CHAR_BIT); + u8 *bmRowids = NULL; + void *baseVectors = sqlite3_malloc((i64)p->chunk_size * (i64)qsize); + + if (!cand_rowids || !cand_distances || !tmp_rowids || !tmp_distances || + !chunk_distances || !chunk_topk_idxs || !b || !bTaken || !baseVectors) { + rc = SQLITE_NOMEM; + goto cleanup; + } + memset(cand_rowids, 0, k_oversample * sizeof(i64)); + memset(cand_distances, 0, k_oversample * sizeof(f32)); + + if (arrayRowidsIn) { + bmRowids = sqlite3_malloc(p->chunk_size / CHAR_BIT); + if (!bmRowids) { + rc = SQLITE_NOMEM; + goto cleanup; + } + } + + i64 cand_used = 0; + + while (1) { + rc = sqlite3_step(stmtChunks); + if (rc == SQLITE_DONE) + break; + if (rc != SQLITE_ROW) { + rc = SQLITE_ERROR; + goto cleanup; + } + + i64 chunk_id = sqlite3_column_int64(stmtChunks, 0); + unsigned char *chunkValidity = + (unsigned char *)sqlite3_column_blob(stmtChunks, 1); + i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2); + + memset(chunk_distances, 0, p->chunk_size * sizeof(f32)); + memset(chunk_topk_idxs, 0, k_oversample * sizeof(i32)); + bitmap_copy(b, chunkValidity, p->chunk_size); + + if (arrayRowidsIn) { + bitmap_clear(bmRowids, p->chunk_size); + for (int j = 0; j < p->chunk_size; j++) { + if (!bitmap_get(chunkValidity, j)) + continue; + i64 rid = chunkRowids[j]; + void *found = bsearch(&rid, arrayRowidsIn->z, arrayRowidsIn->length, + sizeof(i64), _cmp); + bitmap_set(bmRowids, j, found ? 1 : 0); + } + bitmap_and_inplace(b, bmRowids, p->chunk_size); + } + + // Read quantized vectors + sqlite3_blob *blobQ = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreChunksNames[vectorColumnIdx], + "vectors", chunk_id, 0, &blobQ); + if (rc != SQLITE_OK) + goto cleanup; + rc = sqlite3_blob_read(blobQ, baseVectors, + (i64)p->chunk_size * (i64)qsize, 0); + sqlite3_blob_close(blobQ); + if (rc != SQLITE_OK) + goto cleanup; + + // Compute quantized distances + for (int j = 0; j < p->chunk_size; j++) { + if (!bitmap_get(b, j)) + continue; + f32 dist; + switch (vector_column->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: { + const u8 *base_j = ((u8 *)baseVectors) + (j * (qdim / CHAR_BIT)); + dist = distance_hamming(base_j, (u8 *)quantizedQuery, &qdim); + break; + } + case VEC0_RESCORE_QUANTIZER_INT8: { + const i8 *base_j = ((i8 *)baseVectors) + (j * qdim); + switch (vector_column->distance_metric) { + case VEC0_DISTANCE_METRIC_L2: + dist = distance_l2_sqr_int8(base_j, (i8 *)quantizedQuery, &qdim); + break; + case VEC0_DISTANCE_METRIC_COSINE: + dist = distance_cosine_int8(base_j, (i8 *)quantizedQuery, &qdim); + break; + case VEC0_DISTANCE_METRIC_L1: + dist = (f32)distance_l1_int8(base_j, (i8 *)quantizedQuery, &qdim); + break; + } + break; + } + } + chunk_distances[j] = dist; + } + + int used1; + min_idx(chunk_distances, p->chunk_size, b, chunk_topk_idxs, + min(k_oversample, p->chunk_size), bTaken, &used1); + + i64 merged_used; + merge_sorted_lists(cand_distances, cand_rowids, cand_used, chunk_distances, + chunkRowids, chunk_topk_idxs, + min(min(k_oversample, p->chunk_size), used1), + tmp_distances, tmp_rowids, k_oversample, &merged_used); + + for (i64 j = 0; j < merged_used; j++) { + cand_rowids[j] = tmp_rowids[j]; + cand_distances[j] = tmp_distances[j]; + } + cand_used = merged_used; + } + rc = SQLITE_OK; + + // Phase 2: Rescore candidates using _rescore_vectors (rowid-keyed) + if (cand_used == 0) { + knn_data->current_idx = 0; + knn_data->k = 0; + knn_data->rowids = NULL; + knn_data->distances = NULL; + knn_data->k_used = 0; + goto cleanup; + } + { + f32 *float_distances = sqlite3_malloc(cand_used * sizeof(f32)); + void *fBuf = sqlite3_malloc(fsize); + if (!float_distances || !fBuf) { + sqlite3_free(float_distances); + sqlite3_free(fBuf); + rc = SQLITE_NOMEM; + goto cleanup; + } + + // Open blob on _rescore_vectors, then reopen for each candidate rowid. + // blob_reopen is O(1) for INTEGER PRIMARY KEY tables. + sqlite3_blob *blobFloat = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreVectorsNames[vectorColumnIdx], + "vector", cand_rowids[0], 0, &blobFloat); + if (rc != SQLITE_OK) { + sqlite3_free(float_distances); + sqlite3_free(fBuf); + goto cleanup; + } + + rc = sqlite3_blob_read(blobFloat, fBuf, fsize, 0); + if (rc != SQLITE_OK) { + sqlite3_blob_close(blobFloat); + sqlite3_free(float_distances); + sqlite3_free(fBuf); + goto cleanup; + } + float_distances[0] = + vec0_distance_full(fBuf, queryVector, vector_column->dimensions, + vector_column->element_type, + vector_column->distance_metric); + + for (i64 j = 1; j < cand_used; j++) { + rc = sqlite3_blob_reopen(blobFloat, cand_rowids[j]); + if (rc != SQLITE_OK) { + sqlite3_blob_close(blobFloat); + sqlite3_free(float_distances); + sqlite3_free(fBuf); + goto cleanup; + } + rc = sqlite3_blob_read(blobFloat, fBuf, fsize, 0); + if (rc != SQLITE_OK) { + sqlite3_blob_close(blobFloat); + sqlite3_free(float_distances); + sqlite3_free(fBuf); + goto cleanup; + } + float_distances[j] = + vec0_distance_full(fBuf, queryVector, vector_column->dimensions, + vector_column->element_type, + vector_column->distance_metric); + } + sqlite3_blob_close(blobFloat); + sqlite3_free(fBuf); + + // Sort by float distance + for (i64 a = 0; a + 1 < cand_used; a++) { + i64 minIdx = a; + for (i64 c = a + 1; c < cand_used; c++) { + if (float_distances[c] < float_distances[minIdx]) + minIdx = c; + } + if (minIdx != a) { + f32 td = float_distances[a]; + float_distances[a] = float_distances[minIdx]; + float_distances[minIdx] = td; + i64 tr = cand_rowids[a]; + cand_rowids[a] = cand_rowids[minIdx]; + cand_rowids[minIdx] = tr; + } + } + + i64 result_k = min(k, cand_used); + i64 *out_rowids = sqlite3_malloc(result_k * sizeof(i64)); + f32 *out_distances = sqlite3_malloc(result_k * sizeof(f32)); + if (!out_rowids || !out_distances) { + sqlite3_free(out_rowids); + sqlite3_free(out_distances); + sqlite3_free(float_distances); + rc = SQLITE_NOMEM; + goto cleanup; + } + for (i64 j = 0; j < result_k; j++) { + out_rowids[j] = cand_rowids[j]; + out_distances[j] = float_distances[j]; + } + + knn_data->current_idx = 0; + knn_data->k = result_k; + knn_data->rowids = out_rowids; + knn_data->distances = out_distances; + knn_data->k_used = result_k; + + sqlite3_free(float_distances); + } + +cleanup: + sqlite3_finalize(stmtChunks); + sqlite3_free(quantizedQuery); + sqlite3_free(cand_rowids); + sqlite3_free(cand_distances); + sqlite3_free(tmp_rowids); + sqlite3_free(tmp_distances); + sqlite3_free(chunk_distances); + sqlite3_free(chunk_topk_idxs); + sqlite3_free(b); + sqlite3_free(bTaken); + sqlite3_free(bmRowids); + sqlite3_free(baseVectors); + return rc; +} + +#ifdef SQLITE_VEC_TEST +void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) { + rescore_quantize_float_to_bit(src, dst, dim); +} +void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim) { + rescore_quantize_float_to_int8(src, dst, dim); +} +size_t _test_rescore_quantized_byte_size_bit(size_t dimensions) { + struct VectorColumnDefinition col; + memset(&col, 0, sizeof(col)); + col.dimensions = dimensions; + col.rescore.quantizer_type = VEC0_RESCORE_QUANTIZER_BIT; + return rescore_quantized_byte_size(&col); +} +size_t _test_rescore_quantized_byte_size_int8(size_t dimensions) { + struct VectorColumnDefinition col; + memset(&col, 0, sizeof(col)); + col.dimensions = dimensions; + col.rescore.quantizer_type = VEC0_RESCORE_QUANTIZER_INT8; + return rescore_quantized_byte_size(&col); +} +#endif diff --git a/sqlite-vec.c b/sqlite-vec.c index 390123b..ff9e0da 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -112,6 +112,10 @@ typedef size_t usize; #define countof(x) (sizeof(x) / sizeof((x)[0])) #define min(a, b) (((a) <= (b)) ? (a) : (b)) +#ifndef SQLITE_VEC_ENABLE_RESCORE +#define SQLITE_VEC_ENABLE_RESCORE 1 +#endif + enum VectorElementType { // clang-format off SQLITE_VEC_ELEMENT_TYPE_FLOAT32 = 223 + 0, @@ -2532,8 +2536,23 @@ static f32 vec0_distance_full( enum Vec0IndexType { VEC0_INDEX_TYPE_FLAT = 1, +#if SQLITE_VEC_ENABLE_RESCORE + VEC0_INDEX_TYPE_RESCORE = 2, +#endif }; +#if SQLITE_VEC_ENABLE_RESCORE +enum Vec0RescoreQuantizerType { + VEC0_RESCORE_QUANTIZER_BIT = 1, + VEC0_RESCORE_QUANTIZER_INT8 = 2, +}; + +struct Vec0RescoreConfig { + enum Vec0RescoreQuantizerType quantizer_type; + int oversample; +}; +#endif + struct VectorColumnDefinition { char *name; int name_length; @@ -2541,6 +2560,9 @@ struct VectorColumnDefinition { enum VectorElementType element_type; enum Vec0DistanceMetrics distance_metric; enum Vec0IndexType index_type; +#if SQLITE_VEC_ENABLE_RESCORE + struct Vec0RescoreConfig rescore; +#endif }; struct Vec0PartitionColumnDefinition { @@ -2577,6 +2599,111 @@ size_t vector_column_byte_size(struct VectorColumnDefinition column) { return vector_byte_size(column.element_type, column.dimensions); } +#if SQLITE_VEC_ENABLE_RESCORE +/** + * @brief Parse rescore options from an "INDEXED BY rescore(...)" clause. + * + * @param scanner Scanner positioned right after the opening '(' of rescore(...) + * @param outConfig Output rescore config + * @param pzErr Error message output + * @return int SQLITE_OK on success, SQLITE_ERROR on error. + */ +static int vec0_parse_rescore_options(struct Vec0Scanner *scanner, + struct Vec0RescoreConfig *outConfig, + char **pzErr) { + struct Vec0Token token; + int rc; + int hasQuantizer = 0; + outConfig->oversample = 8; + outConfig->quantizer_type = 0; + + while (1) { + rc = vec0_scanner_next(scanner, &token); + if (rc == VEC0_TOKEN_RESULT_EOF) { + break; + } + // ')' closes rescore options + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) { + break; + } + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_IDENTIFIER) { + *pzErr = sqlite3_mprintf("Expected option name in rescore(...)"); + return SQLITE_ERROR; + } + + char *key = token.start; + int keyLength = token.end - token.start; + + // expect '=' + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ) { + *pzErr = sqlite3_mprintf("Expected '=' after option name in rescore(...)"); + return SQLITE_ERROR; + } + + // value + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME) { + *pzErr = sqlite3_mprintf("Expected value after '=' in rescore(...)"); + return SQLITE_ERROR; + } + + if (sqlite3_strnicmp(key, "quantizer", keyLength) == 0) { + if (token.token_type != TOKEN_TYPE_IDENTIFIER) { + *pzErr = sqlite3_mprintf("Expected identifier for quantizer value in rescore(...)"); + return SQLITE_ERROR; + } + int valLen = token.end - token.start; + if (sqlite3_strnicmp(token.start, "bit", valLen) == 0) { + outConfig->quantizer_type = VEC0_RESCORE_QUANTIZER_BIT; + } else if (sqlite3_strnicmp(token.start, "int8", valLen) == 0) { + outConfig->quantizer_type = VEC0_RESCORE_QUANTIZER_INT8; + } else { + *pzErr = sqlite3_mprintf("Unknown quantizer type '%.*s' in rescore(...). Expected 'bit' or 'int8'.", valLen, token.start); + return SQLITE_ERROR; + } + hasQuantizer = 1; + } else if (sqlite3_strnicmp(key, "oversample", keyLength) == 0) { + if (token.token_type != TOKEN_TYPE_DIGIT) { + *pzErr = sqlite3_mprintf("Expected integer for oversample value in rescore(...)"); + return SQLITE_ERROR; + } + outConfig->oversample = atoi(token.start); + if (outConfig->oversample <= 0 || outConfig->oversample > 128) { + *pzErr = sqlite3_mprintf("oversample in rescore(...) must be between 1 and 128, got %d", outConfig->oversample); + return SQLITE_ERROR; + } + } else { + *pzErr = sqlite3_mprintf("Unknown option '%.*s' in rescore(...)", keyLength, key); + return SQLITE_ERROR; + } + + // optional comma between options + rc = vec0_scanner_next(scanner, &token); + if (rc == VEC0_TOKEN_RESULT_EOF) { + break; + } + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) { + break; + } + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_COMMA) { + continue; + } + // If it's not a comma or rparen, it might be the next key — push back isn't + // possible with this scanner, so we'll treat unexpected tokens as errors + *pzErr = sqlite3_mprintf("Unexpected token in rescore(...) options"); + return SQLITE_ERROR; + } + + if (!hasQuantizer) { + *pzErr = sqlite3_mprintf("rescore(...) requires a 'quantizer' option (quantizer=bit or quantizer=int8)"); + return SQLITE_ERROR; + } + + return SQLITE_OK; +} +#endif /* SQLITE_VEC_ENABLE_RESCORE */ + /** * @brief Parse an vec0 vtab argv[i] column definition and see if * it's a vector column defintion, ex `contents_embedding float[768]`. @@ -2601,6 +2728,10 @@ int vec0_parse_vector_column(const char *source, int source_length, enum VectorElementType elementType; enum Vec0DistanceMetrics distanceMetric = VEC0_DISTANCE_METRIC_L2; enum Vec0IndexType indexType = VEC0_INDEX_TYPE_FLAT; +#if SQLITE_VEC_ENABLE_RESCORE + struct Vec0RescoreConfig rescoreConfig; + memset(&rescoreConfig, 0, sizeof(rescoreConfig)); +#endif int dimensions; vec0_scanner_init(&scanner, source, source_length); @@ -2704,6 +2835,7 @@ int vec0_parse_vector_column(const char *source, int source_length, return SQLITE_ERROR; } } + // INDEXED BY flat() | rescore(...) else if (sqlite3_strnicmp(key, "indexed", keyLength) == 0) { // expect "by" rc = vec0_scanner_next(&scanner, &token); @@ -2733,7 +2865,32 @@ int vec0_parse_vector_column(const char *source, int source_length, token.token_type != TOKEN_TYPE_RPAREN) { return SQLITE_ERROR; } - } else { + } +#if SQLITE_VEC_ENABLE_RESCORE + else if (sqlite3_strnicmp(token.start, "rescore", indexNameLen) == 0) { + indexType = VEC0_INDEX_TYPE_RESCORE; + if (elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + return SQLITE_ERROR; + } + // expect '(' + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_LPAREN) { + return SQLITE_ERROR; + } + char *rescoreErr = NULL; + rc = vec0_parse_rescore_options(&scanner, &rescoreConfig, &rescoreErr); + if (rc != SQLITE_OK) { + if (rescoreErr) sqlite3_free(rescoreErr); + return SQLITE_ERROR; + } + // validate dimensions for bit quantizer + if (rescoreConfig.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT && + (dimensions % CHAR_BIT) != 0) { + return SQLITE_ERROR; + } + } +#endif + else { // unknown index type return SQLITE_ERROR; } @@ -2753,6 +2910,9 @@ int vec0_parse_vector_column(const char *source, int source_length, outColumn->element_type = elementType; outColumn->dimensions = dimensions; outColumn->index_type = indexType; +#if SQLITE_VEC_ENABLE_RESCORE + outColumn->rescore = rescoreConfig; +#endif return SQLITE_OK; } @@ -3093,6 +3253,19 @@ struct vec0_vtab { // The first numVectorColumns entries must be freed with sqlite3_free() char *shadowVectorChunksNames[VEC0_MAX_VECTOR_COLUMNS]; +#if SQLITE_VEC_ENABLE_RESCORE + // Name of all rescore chunk shadow tables, ie `_rescore_chunks00` + // Only populated for vector columns with rescore enabled. + // Must be freed with sqlite3_free() + char *shadowRescoreChunksNames[VEC0_MAX_VECTOR_COLUMNS]; + + // Name of all rescore vector shadow tables, ie `_rescore_vectors00` + // Rowid-keyed table for fast random-access float vector reads during rescore. + // Only populated for vector columns with rescore enabled. + // Must be freed with sqlite3_free() + char *shadowRescoreVectorsNames[VEC0_MAX_VECTOR_COLUMNS]; +#endif + // Name of all metadata chunk shadow tables, ie `_metadatachunks00` // Only the first numMetadataColumns entries will be available. // The first numMetadataColumns entries must be freed with sqlite3_free() @@ -3162,6 +3335,18 @@ struct vec0_vtab { sqlite3_stmt *stmtRowidsGetChunkPosition; }; +#if SQLITE_VEC_ENABLE_RESCORE +// Forward declarations for rescore functions (defined in sqlite-vec-rescore.c, +// included later after all helpers they depend on are defined). +static int rescore_create_tables(vec0_vtab *p, sqlite3 *db, char **pzErr); +static int rescore_drop_tables(vec0_vtab *p); +static int rescore_new_chunk(vec0_vtab *p, i64 chunk_rowid); +static int rescore_on_insert(vec0_vtab *p, i64 chunk_rowid, i64 chunk_offset, + i64 rowid, void *vectorDatas[]); +static int rescore_on_delete(vec0_vtab *p, i64 chunk_id, u64 chunk_offset, i64 rowid); +static int rescore_delete_chunk(vec0_vtab *p, i64 chunk_id); +#endif + /** * @brief Finalize all the sqlite3_stmt members in a vec0_vtab. * @@ -3201,6 +3386,14 @@ void vec0_free(vec0_vtab *p) { sqlite3_free(p->shadowVectorChunksNames[i]); p->shadowVectorChunksNames[i] = NULL; +#if SQLITE_VEC_ENABLE_RESCORE + sqlite3_free(p->shadowRescoreChunksNames[i]); + p->shadowRescoreChunksNames[i] = NULL; + + sqlite3_free(p->shadowRescoreVectorsNames[i]); + p->shadowRescoreVectorsNames[i] = NULL; +#endif + sqlite3_free(p->vector_columns[i].name); p->vector_columns[i].name = NULL; } @@ -3493,6 +3686,41 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx, assert((vector_column_idx >= 0) && (vector_column_idx < pVtab->numVectorColumns)); +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns store float vectors in _rescore_vectors (rowid-keyed) + if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE) { + size = vector_column_byte_size(p->vector_columns[vector_column_idx]); + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreVectorsNames[vector_column_idx], + "vector", rowid, 0, &vectorBlob); + if (rc != SQLITE_OK) { + vtab_set_error(&pVtab->base, + "Could not fetch vector data for %lld from rescore vectors", + rowid); + rc = SQLITE_ERROR; + goto cleanup; + } + buf = sqlite3_malloc(size); + if (!buf) { + rc = SQLITE_NOMEM; + goto cleanup; + } + rc = sqlite3_blob_read(vectorBlob, buf, size, 0); + if (rc != SQLITE_OK) { + sqlite3_free(buf); + buf = NULL; + rc = SQLITE_ERROR; + goto cleanup; + } + *outVector = buf; + if (outVectorSize) { + *outVectorSize = size; + } + rc = SQLITE_OK; + goto cleanup; + } +#endif /* SQLITE_VEC_ENABLE_RESCORE */ + rc = vec0_get_chunk_position(pVtab, rowid, NULL, &chunk_id, &chunk_offset); if (rc == SQLITE_EMPTY) { vtab_set_error(&pVtab->base, "Could not find a row with rowid %lld", rowid); @@ -4096,6 +4324,14 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_value ** partitionKeyValues, i64 *chunk continue; } int vector_column_idx = p->user_column_idxs[i]; + +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns don't use _vector_chunks for float storage + if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE) { + continue; + } +#endif + i64 vectorsSize = p->chunk_size * vector_column_byte_size(p->vector_columns[vector_column_idx]); @@ -4126,6 +4362,14 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_value ** partitionKeyValues, i64 *chunk } } +#if SQLITE_VEC_ENABLE_RESCORE + // Create new rescore chunks for each rescore-enabled vector column + rc = rescore_new_chunk(p, rowid); + if (rc != SQLITE_OK) { + return rc; + } +#endif + // Step 3: Create new metadata chunks for each metadata column for (int i = 0; i < vec0_num_defined_user_columns(p); i++) { if(p->user_column_kinds[i] != SQLITE_VEC0_USER_COLUMN_KIND_METADATA) { @@ -4487,6 +4731,35 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, goto error; } +#if SQLITE_VEC_ENABLE_RESCORE + { + int hasRescore = 0; + for (int i = 0; i < numVectorColumns; i++) { + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) { + hasRescore = 1; + break; + } + } + if (hasRescore) { + if (numAuxiliaryColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "Auxiliary columns are not supported with rescore indexes"); + goto error; + } + if (numMetadataColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "Metadata columns are not supported with rescore indexes"); + goto error; + } + if (numPartitionColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "Partition key columns are not supported with rescore indexes"); + goto error; + } + } + } +#endif + sqlite3_str *createStr = sqlite3_str_new(NULL); sqlite3_str_appendall(createStr, "CREATE TABLE x("); if (pkColumnName) { @@ -4577,6 +4850,20 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, if (!pNew->shadowVectorChunksNames[i]) { goto error; } +#if SQLITE_VEC_ENABLE_RESCORE + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) { + pNew->shadowRescoreChunksNames[i] = + sqlite3_mprintf("%s_rescore_chunks%02d", tableName, i); + if (!pNew->shadowRescoreChunksNames[i]) { + goto error; + } + pNew->shadowRescoreVectorsNames[i] = + sqlite3_mprintf("%s_rescore_vectors%02d", tableName, i); + if (!pNew->shadowRescoreVectorsNames[i]) { + goto error; + } + } +#endif } for (int i = 0; i < pNew->numMetadataColumns; i++) { pNew->shadowMetadataChunksNames[i] = @@ -4700,6 +4987,11 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_finalize(stmt); for (int i = 0; i < pNew->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns don't use _vector_chunks + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE, pNew->schemaName, pNew->tableName, i); if (!zSql) { @@ -4718,6 +5010,13 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_finalize(stmt); } +#if SQLITE_VEC_ENABLE_RESCORE + rc = rescore_create_tables(pNew, db, pzErr); + if (rc != SQLITE_OK) { + goto error; + } +#endif + // See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY" // without INTEGER type issue applies here. for (int i = 0; i < pNew->numMetadataColumns; i++) { @@ -4852,6 +5151,10 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { sqlite3_finalize(stmt); for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif zSql = sqlite3_mprintf("DROP TABLE \"%w\".\"%w\"", p->schemaName, p->shadowVectorChunksNames[i]); rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); @@ -4863,6 +5166,13 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { sqlite3_finalize(stmt); } +#if SQLITE_VEC_ENABLE_RESCORE + rc = rescore_drop_tables(p); + if (rc != SQLITE_OK) { + goto done; + } +#endif + if(p->numAuxiliaryColumns > 0) { zSql = sqlite3_mprintf("DROP TABLE " VEC0_SHADOW_AUXILIARY_NAME, p->schemaName, p->tableName); rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); @@ -6624,6 +6934,10 @@ cleanup: return rc; } +#if SQLITE_VEC_ENABLE_RESCORE +#include "sqlite-vec-rescore.c" +#endif + int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { assert(argc == (strlen(idxStr)-1) / 4); @@ -6856,6 +7170,21 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, } #endif +#if SQLITE_VEC_ENABLE_RESCORE + // Dispatch to rescore KNN path if this vector column has rescore enabled + if (vector_column->index_type == VEC0_INDEX_TYPE_RESCORE) { + rc = rescore_knn(p, pCur, vector_column, vectorColumnIdx, arrayRowidsIn, + aMetadataIn, idxStr, argc, argv, queryVector, k, knn_data); + if (rc != SQLITE_OK) { + goto cleanup; + } + pCur->knn_data = knn_data; + pCur->query_plan = VEC0_QUERY_PLAN_KNN; + rc = SQLITE_OK; + goto cleanup; + } +#endif + rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks); if (rc != SQLITE_OK) { // IMP: V06942_23781 @@ -7680,6 +8009,12 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid, // Go insert the vector data into the vector chunk shadow tables for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns store float vectors in _rescore_vectors instead + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif + sqlite3_blob *blobVectors; rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i], "vectors", chunk_rowid, 1, &blobVectors); @@ -8082,6 +8417,13 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, goto cleanup; } +#if SQLITE_VEC_ENABLE_RESCORE + rc = rescore_on_insert(p, chunk_rowid, chunk_offset, rowid, vectorDatas); + if (rc != SQLITE_OK) { + goto cleanup; + } +#endif + if(p->numAuxiliaryColumns > 0) { sqlite3_stmt *stmt; sqlite3_str * s = sqlite3_str_new(NULL); @@ -8272,6 +8614,11 @@ int vec0Update_Delete_ClearVectors(vec0_vtab *p, i64 chunk_id, u64 chunk_offset) { int rc, brc; for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns don't use _vector_chunks + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif sqlite3_blob *blobVectors = NULL; size_t n = vector_column_byte_size(p->vector_columns[i]); @@ -8383,6 +8730,10 @@ int vec0Update_Delete_DeleteChunkIfEmpty(vec0_vtab *p, i64 chunk_id, // Delete from each _vector_chunksNN for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif zSql = sqlite3_mprintf( "DELETE FROM " VEC0_SHADOW_VECTOR_N_NAME " WHERE rowid = ?", p->schemaName, p->tableName, i); @@ -8399,6 +8750,12 @@ int vec0Update_Delete_DeleteChunkIfEmpty(vec0_vtab *p, i64 chunk_id, return SQLITE_ERROR; } +#if SQLITE_VEC_ENABLE_RESCORE + rc = rescore_delete_chunk(p, chunk_id); + if (rc != SQLITE_OK) + return rc; +#endif + // Delete from each _metadatachunksNN for (int i = 0; i < p->numMetadataColumns; i++) { zSql = sqlite3_mprintf( @@ -8606,6 +8963,14 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) { return rc; } +#if SQLITE_VEC_ENABLE_RESCORE + // 4b. zero out quantized data in rescore chunk tables, delete from rescore vectors + rc = rescore_on_delete(p, chunk_id, chunk_offset, rowid); + if (rc != SQLITE_OK) { + return rc; + } +#endif + // 5. delete from _rowids table rc = vec0Update_Delete_DeleteRowids(p, rowid); if (rc != SQLITE_OK) { @@ -8663,8 +9028,11 @@ int vec0Update_UpdateAuxColumn(vec0_vtab *p, int auxiliary_column_idx, sqlite3_v } int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset, - int i, sqlite3_value *valueVector) { + int i, sqlite3_value *valueVector, i64 rowid) { int rc; +#if !SQLITE_VEC_ENABLE_RESCORE + UNUSED_PARAMETER(rowid); +#endif sqlite3_blob *blobVectors = NULL; @@ -8708,6 +9076,59 @@ int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset, goto cleanup; } +#if SQLITE_VEC_ENABLE_RESCORE + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) { + // For rescore columns, update _rescore_vectors and _rescore_chunks + struct VectorColumnDefinition *col = &p->vector_columns[i]; + size_t qsize = rescore_quantized_byte_size(col); + size_t fsize = vector_column_byte_size(*col); + + // 1. Update quantized chunk + { + void *qbuf = sqlite3_malloc(qsize); + if (!qbuf) { rc = SQLITE_NOMEM; goto cleanup; } + switch (col->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: + rescore_quantize_float_to_bit((const float *)vector, (uint8_t *)qbuf, col->dimensions); + break; + case VEC0_RESCORE_QUANTIZER_INT8: + rescore_quantize_float_to_int8((const float *)vector, (int8_t *)qbuf, col->dimensions); + break; + } + sqlite3_blob *blobQ = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreChunksNames[i], "vectors", + chunk_id, 1, &blobQ); + if (rc != SQLITE_OK) { sqlite3_free(qbuf); goto cleanup; } + rc = sqlite3_blob_write(blobQ, qbuf, qsize, chunk_offset * qsize); + sqlite3_free(qbuf); + int brc2 = sqlite3_blob_close(blobQ); + if (rc != SQLITE_OK) goto cleanup; + if (brc2 != SQLITE_OK) { rc = brc2; goto cleanup; } + } + + // 2. Update float vector in _rescore_vectors (keyed by user rowid) + { + char *zSql = sqlite3_mprintf( + "UPDATE \"%w\".\"%w\" SET vector = ? WHERE rowid = ?", + p->schemaName, p->shadowRescoreVectorsNames[i]); + if (!zSql) { rc = SQLITE_NOMEM; goto cleanup; } + sqlite3_stmt *stmtUp; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtUp, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) goto cleanup; + sqlite3_bind_blob(stmtUp, 1, vector, fsize, SQLITE_TRANSIENT); + sqlite3_bind_int64(stmtUp, 2, rowid); + rc = sqlite3_step(stmtUp); + sqlite3_finalize(stmtUp); + if (rc != SQLITE_DONE) { rc = SQLITE_ERROR; goto cleanup; } + } + + rc = SQLITE_OK; + goto cleanup; + } +#endif + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i], "vectors", chunk_id, 1, &blobVectors); if (rc != SQLITE_OK) { @@ -8839,7 +9260,7 @@ int vec0Update_Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv) { } rc = vec0Update_UpdateVectorColumn(p, chunk_id, chunk_offset, vector_idx, - valueVector); + valueVector, rowid); if (rc != SQLITE_OK) { return SQLITE_ERROR; } @@ -8997,9 +9418,15 @@ static sqlite3_module vec0Module = { #else #define SQLITE_VEC_DEBUG_BUILD_NEON "" #endif +#if SQLITE_VEC_ENABLE_RESCORE +#define SQLITE_VEC_DEBUG_BUILD_RESCORE "rescore" +#else +#define SQLITE_VEC_DEBUG_BUILD_RESCORE "" +#endif #define SQLITE_VEC_DEBUG_BUILD \ - SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON + SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \ + SQLITE_VEC_DEBUG_BUILD_RESCORE #define SQLITE_VEC_DEBUG_STRING \ "Version: " SQLITE_VEC_VERSION "\n" \ diff --git a/tests/fuzz/.gitignore b/tests/fuzz/.gitignore index 757d1ac..b9c7d30 100644 --- a/tests/fuzz/.gitignore +++ b/tests/fuzz/.gitignore @@ -1,2 +1,7 @@ *.dSYM targets/ +corpus/ +crash-* +leak-* +timeout-* +*.log diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile index 21629ef..0030c2e 100644 --- a/tests/fuzz/Makefile +++ b/tests/fuzz/Makefile @@ -72,10 +72,34 @@ $(TARGET_DIR)/vec_mismatch: vec-mismatch.c $(FUZZ_SRCS) | $(TARGET_DIR) $(TARGET_DIR)/vec0_delete_completeness: vec0-delete-completeness.c $(FUZZ_SRCS) | $(TARGET_DIR) $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ +$(TARGET_DIR)/rescore_operations: rescore-operations.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_create: rescore-create.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_quantize: rescore-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_shadow_corrupt: rescore-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_knn_deep: rescore-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_quantize_edge: rescore-quantize-edge.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + FUZZ_TARGETS = vec0_create exec json numpy \ shadow_corrupt vec0_operations scalar_functions \ vec0_create_full metadata_columns vec_each vec_mismatch \ - vec0_delete_completeness + vec0_delete_completeness \ + rescore_operations rescore_create rescore_quantize \ + rescore_shadow_corrupt rescore_knn_deep \ + rescore_quantize_edge rescore_interleave all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS)) diff --git a/tests/fuzz/rescore-create.c b/tests/fuzz/rescore-create.c new file mode 100644 index 0000000..3e69d6d --- /dev/null +++ b/tests/fuzz/rescore-create.c @@ -0,0 +1,36 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + int rc = SQLITE_OK; + sqlite3 *db; + sqlite3_stmt *stmt; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + sqlite3_str *s = sqlite3_str_new(NULL); + assert(s); + sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[128] indexed by rescore("); + sqlite3_str_appendf(s, "%.*s", (int)size, data); + sqlite3_str_appendall(s, "))"); + const char *zSql = sqlite3_str_finish(s); + assert(zSql); + + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL); + sqlite3_free((void *)zSql); + if (rc == SQLITE_OK) { + sqlite3_step(stmt); + } + sqlite3_finalize(stmt); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/rescore-create.dict b/tests/fuzz/rescore-create.dict new file mode 100644 index 0000000..a8adf71 --- /dev/null +++ b/tests/fuzz/rescore-create.dict @@ -0,0 +1,20 @@ +"rescore" +"quantizer" +"bit" +"int8" +"oversample" +"indexed" +"by" +"float" +"(" +")" +"," +"=" +"[" +"]" +"1" +"8" +"16" +"128" +"256" +"1024" diff --git a/tests/fuzz/rescore-interleave.c b/tests/fuzz/rescore-interleave.c new file mode 100644 index 0000000..74e8b8d --- /dev/null +++ b/tests/fuzz/rescore-interleave.c @@ -0,0 +1,151 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/** + * Fuzz target: interleaved insert/update/delete/KNN operations on rescore + * tables with BOTH quantizer types, exercising the int8 quantizer path + * and the update code path that the existing rescore-operations.c misses. + * + * Key differences from rescore-operations.c: + * - Tests BOTH bit and int8 quantizers (the existing target only tests bit) + * - Fuzz-controlled query vectors (not fixed [1,0,0,...]) + * - Exercises the UPDATE path (line 9080+ in sqlite-vec.c) + * - Tests with 16 dimensions (more realistic, exercises more of the + * quantization loop) + * - Interleaves KNN between mutations to stress the blob_reopen path + * when _rescore_vectors rows have been deleted/modified + */ +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 8) return 0; + + int rc; + sqlite3 *db; + sqlite3_stmt *stmtInsert = NULL; + sqlite3_stmt *stmtUpdate = NULL; + sqlite3_stmt *stmtDelete = NULL; + sqlite3_stmt *stmtKnn = NULL; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* Use first byte to pick quantizer */ + int use_int8 = data[0] & 1; + data++; size--; + + const char *create_sql = use_int8 + ? "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=int8))" + : "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=bit))"; + + rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "UPDATE v SET emb = ? WHERE rowid = ?", -1, &stmtUpdate, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 5", -1, &stmtKnn, NULL); + + if (!stmtInsert || !stmtUpdate || !stmtDelete || !stmtKnn) + goto cleanup; + + size_t i = 0; + while (i + 2 <= size) { + uint8_t op = data[i++] % 5; /* 5 operations now */ + uint8_t rowid_byte = data[i++]; + int64_t rowid = (int64_t)(rowid_byte % 24) + 1; + + switch (op) { + case 0: { + /* INSERT: consume bytes for 16 floats */ + float vec[16] = {0}; + for (int j = 0; j < 16 && i < size; j++, i++) { + vec[j] = (float)((int8_t)data[i]) / 8.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 1: { + /* DELETE */ + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 2: { + /* KNN with fuzz-controlled query vector */ + float qvec[16] = {0}; + for (int j = 0; j < 16 && i < size; j++, i++) { + qvec[j] = (float)((int8_t)data[i]) / 4.0f; + } + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) { + (void)sqlite3_column_int64(stmtKnn, 0); + (void)sqlite3_column_double(stmtKnn, 1); + } + break; + } + case 3: { + /* UPDATE: modify an existing vector (exercises rescore update path) */ + float vec[16] = {0}; + for (int j = 0; j < 16 && i < size; j++, i++) { + vec[j] = (float)((int8_t)data[i]) / 6.0f; + } + sqlite3_reset(stmtUpdate); + sqlite3_bind_blob(stmtUpdate, 1, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_bind_int64(stmtUpdate, 2, rowid); + sqlite3_step(stmtUpdate); + break; + } + case 4: { + /* INSERT then immediately UPDATE same row (stresses blob lifecycle) */ + float vec1[16] = {0}; + float vec2[16] = {0}; + for (int j = 0; j < 16 && i < size; j++, i++) { + vec1[j] = (float)((int8_t)data[i]) / 10.0f; + vec2[j] = -vec1[j]; /* opposite direction */ + } + /* Insert */ + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec1, sizeof(vec1), SQLITE_TRANSIENT); + if (sqlite3_step(stmtInsert) == SQLITE_DONE) { + /* Only update if insert succeeded (rowid might already exist) */ + sqlite3_reset(stmtUpdate); + sqlite3_bind_blob(stmtUpdate, 1, vec2, sizeof(vec2), SQLITE_TRANSIENT); + sqlite3_bind_int64(stmtUpdate, 2, rowid); + sqlite3_step(stmtUpdate); + } + break; + } + } + } + + /* Final consistency check: full scan must not crash */ + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtUpdate); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/rescore-knn-deep.c b/tests/fuzz/rescore-knn-deep.c new file mode 100644 index 0000000..8ff3c37 --- /dev/null +++ b/tests/fuzz/rescore-knn-deep.c @@ -0,0 +1,178 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/** + * Fuzz target: deep exercise of rescore KNN with fuzz-controlled query vectors + * and both quantizer types (bit + int8), multiple distance metrics. + * + * The existing rescore-operations.c only tests bit quantizer with a fixed + * query vector. This target: + * - Tests both bit and int8 quantizers + * - Uses fuzz-controlled query vectors (hits NaN/Inf/denormal paths) + * - Tests all distance metrics with int8 (L2, cosine, L1) + * - Exercises large LIMIT values (oversample multiplication) + * - Tests KNN with rowid IN constraints + * - Exercises the insert->query->update->query->delete->query cycle + */ +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 20) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* Use first 4 bytes for configuration */ + uint8_t config = data[0]; + uint8_t num_inserts = (data[1] % 20) + 3; /* 3..22 inserts */ + uint8_t limit_val = (data[2] % 50) + 1; /* 1..50 for LIMIT */ + uint8_t metric_choice = data[3] % 3; + data += 4; + size -= 4; + + int use_int8 = config & 1; + const char *metric_str; + switch (metric_choice) { + case 0: metric_str = ""; break; /* default L2 */ + case 1: metric_str = " distance_metric=cosine"; break; + case 2: metric_str = " distance_metric=l1"; break; + default: metric_str = ""; break; + } + + /* Build CREATE TABLE statement */ + char create_sql[256]; + if (use_int8) { + snprintf(create_sql, sizeof(create_sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=int8)%s)", metric_str); + } else { + /* bit quantizer ignores distance_metric for the coarse pass (always hamming), + but the float rescore phase uses the specified metric */ + snprintf(create_sql, sizeof(create_sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=bit)%s)", metric_str); + } + + rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + /* Insert vectors using fuzz data */ + { + sqlite3_stmt *ins = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL); + if (!ins) { sqlite3_close(db); return 0; } + + size_t cursor = 0; + for (int i = 0; i < num_inserts && cursor + 1 < size; i++) { + float vec[16]; + for (int j = 0; j < 16; j++) { + if (cursor < size) { + /* Map fuzz byte to float -- includes potential for + interesting float values via reinterpretation */ + int8_t sb = (int8_t)data[cursor++]; + vec[j] = (float)sb / 5.0f; + } else { + vec[j] = 0.0f; + } + } + sqlite3_reset(ins); + sqlite3_bind_int64(ins, 1, (sqlite3_int64)(i + 1)); + sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(ins); + } + sqlite3_finalize(ins); + } + + /* Build a fuzz-controlled query vector from remaining data */ + float qvec[16] = {0}; + { + size_t cursor = 0; + for (int j = 0; j < 16 && cursor < size; j++) { + int8_t sb = (int8_t)data[cursor++]; + qvec[j] = (float)sb / 3.0f; + } + } + + /* KNN query with fuzz-controlled vector and LIMIT */ + { + char knn_sql[256]; + snprintf(knn_sql, sizeof(knn_sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT %d", (int)limit_val); + + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, knn_sql, -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) { + /* Read results to ensure distance computation didn't produce garbage + that crashes the cursor iteration */ + (void)sqlite3_column_int64(knn, 0); + (void)sqlite3_column_double(knn, 1); + } + sqlite3_finalize(knn); + } + } + + /* Update some vectors, then query again */ + { + float uvec[16]; + for (int j = 0; j < 16; j++) uvec[j] = qvec[15 - j]; /* reverse of query */ + sqlite3_stmt *upd = NULL; + sqlite3_prepare_v2(db, + "UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL); + if (upd) { + sqlite3_bind_blob(upd, 1, uvec, sizeof(uvec), SQLITE_STATIC); + sqlite3_step(upd); + sqlite3_finalize(upd); + } + } + + /* Second KNN after update */ + { + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 10", -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) {} + sqlite3_finalize(knn); + } + } + + /* Delete half the rows, then KNN again */ + for (int i = 1; i <= num_inserts; i += 2) { + char del_sql[64]; + snprintf(del_sql, sizeof(del_sql), + "DELETE FROM v WHERE rowid = %d", i); + sqlite3_exec(db, del_sql, NULL, NULL, NULL); + } + + /* Third KNN after deletes -- exercises distance computation over + zeroed-out slots in the quantized chunk */ + { + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 5", -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) {} + sqlite3_finalize(knn); + } + } + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/rescore-operations.c b/tests/fuzz/rescore-operations.c new file mode 100644 index 0000000..4bb7ff1 --- /dev/null +++ b/tests/fuzz/rescore-operations.c @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 6) return 0; + + int rc; + sqlite3 *db; + sqlite3_stmt *stmtInsert = NULL; + sqlite3_stmt *stmtDelete = NULL; + sqlite3_stmt *stmtKnn = NULL; + sqlite3_stmt *stmtScan = NULL; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + rc = sqlite3_exec(db, + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[8] indexed by rescore(quantizer=bit))", + NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? ORDER BY distance LIMIT 3", + -1, &stmtKnn, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid FROM v", -1, &stmtScan, NULL); + + if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup; + + size_t i = 0; + while (i + 2 <= size) { + uint8_t op = data[i++] % 4; + uint8_t rowid_byte = data[i++]; + int64_t rowid = (int64_t)(rowid_byte % 32) + 1; + + switch (op) { + case 0: { + // INSERT: consume 32 bytes for 8 floats, or use what's left + float vec[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + for (int j = 0; j < 8 && i < size; j++, i++) { + vec[j] = (float)((int8_t)data[i]) / 10.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 1: { + // DELETE + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 2: { + // KNN query with a fixed query vector + float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 3: { + // Full scan + sqlite3_reset(stmtScan); + while (sqlite3_step(stmtScan) == SQLITE_ROW) {} + break; + } + } + } + + // Final operations -- must not crash regardless of prior state + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_finalize(stmtScan); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/rescore-quantize-edge.c b/tests/fuzz/rescore-quantize-edge.c new file mode 100644 index 0000000..4ab9e20 --- /dev/null +++ b/tests/fuzz/rescore-quantize-edge.c @@ -0,0 +1,177 @@ +#include +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/* Test wrappers from sqlite-vec-rescore.c (SQLITE_VEC_TEST build) */ +extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim); +extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim); +extern size_t _test_rescore_quantized_byte_size_bit(size_t dimensions); +extern size_t _test_rescore_quantized_byte_size_int8(size_t dimensions); + +/** + * Fuzz target: edge cases in rescore quantization functions. + * + * The existing rescore-quantize.c only tests dimensions that are multiples of 8 + * and never passes special float values. This target: + * + * - Tests rescore_quantized_byte_size with arbitrary dimension values + * (including 0, 1, 7, MAX values -- looking for integer division issues) + * - Passes raw float reinterpretation of fuzz bytes (NaN, Inf, denormals, + * negative zero -- these are the values that break min/max/range logic) + * - Tests the int8 quantizer with all-identical values (range=0 branch) + * - Tests the int8 quantizer with extreme ranges (overflow in scale calc) + * - Tests bit quantizer with exact float threshold (0.0f boundary) + */ +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 8) return 0; + + uint8_t mode = data[0] % 5; + data++; size--; + + switch (mode) { + case 0: { + /* Test rescore_quantized_byte_size with fuzz-controlled dimensions. + This function does dimensions / CHAR_BIT for bit, dimensions for int8. + We're checking it doesn't do anything weird with edge values. */ + if (size < sizeof(size_t)) return 0; + size_t dim; + memcpy(&dim, data, sizeof(dim)); + + /* These should never crash, just return values */ + size_t bit_size = _test_rescore_quantized_byte_size_bit(dim); + size_t int8_size = _test_rescore_quantized_byte_size_int8(dim); + + /* Verify basic invariants */ + (void)bit_size; + (void)int8_size; + break; + } + + case 1: { + /* Bit quantize with raw reinterpreted floats (NaN, Inf, denormal). + The key check: src[i] >= 0.0f -- NaN comparison is always false, + so NaN should produce 0-bits. But denormals and -0.0f are tricky. */ + size_t num_floats = size / sizeof(float); + if (num_floats == 0) return 0; + /* Round to multiple of 8 for bit quantizer */ + size_t dim = (num_floats / 8) * 8; + if (dim == 0) return 0; + + const float *src = (const float *)data; + size_t bit_bytes = dim / 8; + uint8_t *dst = (uint8_t *)malloc(bit_bytes); + if (!dst) return 0; + + _test_rescore_quantize_float_to_bit(src, dst, dim); + + /* Verify: for each bit, if src >= 0 then bit should be set */ + for (size_t i = 0; i < dim; i++) { + int bit_set = (dst[i / 8] >> (i % 8)) & 1; + if (src[i] >= 0.0f) { + assert(bit_set == 1); + } else if (src[i] < 0.0f) { + /* Definitely negative -- bit must be 0 */ + assert(bit_set == 0); + } + /* NaN: comparison is false, so bit_set should be 0 */ + } + + free(dst); + break; + } + + case 2: { + /* Int8 quantize with raw reinterpreted floats. + The dangerous paths: + - All values identical (range == 0) -> memset path + - vmin/vmax with NaN (NaN < anything is false, NaN > anything is false) + - Extreme range causing scale = 255/range to be Inf or 0 + - denormals near the clamping boundaries */ + size_t num_floats = size / sizeof(float); + if (num_floats == 0) return 0; + + const float *src = (const float *)data; + int8_t *dst = (int8_t *)malloc(num_floats); + if (!dst) return 0; + + _test_rescore_quantize_float_to_int8(src, dst, num_floats); + + /* Output must always be in [-128, 127] (trivially true for int8_t, + but check the actual clamping logic worked) */ + for (size_t i = 0; i < num_floats; i++) { + assert(dst[i] >= -128 && dst[i] <= 127); + } + + free(dst); + break; + } + + case 3: { + /* Int8 quantize stress: all-same values (range=0 branch) */ + size_t dim = (size < 64) ? size : 64; + if (dim == 0) return 0; + + float *src = (float *)malloc(dim * sizeof(float)); + int8_t *dst = (int8_t *)malloc(dim); + if (!src || !dst) { free(src); free(dst); return 0; } + + /* Fill with a single value derived from fuzz data */ + float val; + memcpy(&val, data, sizeof(float) < size ? sizeof(float) : size); + for (size_t i = 0; i < dim; i++) src[i] = val; + + _test_rescore_quantize_float_to_int8(src, dst, dim); + + /* All outputs should be 0 when range == 0 */ + for (size_t i = 0; i < dim; i++) { + assert(dst[i] == 0); + } + + free(src); + free(dst); + break; + } + + case 4: { + /* Int8 quantize with extreme range: one huge positive, one huge negative. + Tests scale = 255.0f / range overflow to Inf, then v * Inf = Inf, + then clamping to [-128, 127]. */ + if (size < 2 * sizeof(float)) return 0; + + float extreme[2]; + memcpy(extreme, data, 2 * sizeof(float)); + + /* Only test if both are finite (NaN/Inf tested in case 2) */ + if (!isfinite(extreme[0]) || !isfinite(extreme[1])) return 0; + + /* Build a vector with these two extreme values plus some fuzz */ + size_t dim = 16; + float src[16]; + src[0] = extreme[0]; + src[1] = extreme[1]; + for (size_t i = 2; i < dim; i++) { + if (2 * sizeof(float) + (i - 2) < size) { + src[i] = (float)((int8_t)data[2 * sizeof(float) + (i - 2)]) * 1000.0f; + } else { + src[i] = 0.0f; + } + } + + int8_t dst[16]; + _test_rescore_quantize_float_to_int8(src, dst, dim); + + for (size_t i = 0; i < dim; i++) { + assert(dst[i] >= -128 && dst[i] <= 127); + } + break; + } + } + + return 0; +} diff --git a/tests/fuzz/rescore-quantize.c b/tests/fuzz/rescore-quantize.c new file mode 100644 index 0000000..6aad445 --- /dev/null +++ b/tests/fuzz/rescore-quantize.c @@ -0,0 +1,54 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/* These are SQLITE_VEC_TEST wrappers defined in sqlite-vec-rescore.c */ +extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim); +extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim); + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + /* Need at least 4 bytes for one float */ + if (size < 4) return 0; + + /* Use the input as an array of floats. Dimensions must be a multiple of 8 + * for the bit quantizer. */ + size_t num_floats = size / sizeof(float); + if (num_floats == 0) return 0; + + /* Round down to multiple of 8 for bit quantizer compatibility */ + size_t dim = (num_floats / 8) * 8; + if (dim == 0) dim = 8; + if (dim > num_floats) return 0; + + const float *src = (const float *)data; + + /* Allocate output buffers */ + size_t bit_bytes = dim / 8; + uint8_t *bit_dst = (uint8_t *)malloc(bit_bytes); + int8_t *int8_dst = (int8_t *)malloc(dim); + if (!bit_dst || !int8_dst) { + free(bit_dst); + free(int8_dst); + return 0; + } + + /* Test bit quantization */ + _test_rescore_quantize_float_to_bit(src, bit_dst, dim); + + /* Test int8 quantization */ + _test_rescore_quantize_float_to_int8(src, int8_dst, dim); + + /* Verify int8 output is in range */ + for (size_t i = 0; i < dim; i++) { + assert(int8_dst[i] >= -128 && int8_dst[i] <= 127); + } + + free(bit_dst); + free(int8_dst); + return 0; +} diff --git a/tests/fuzz/rescore-shadow-corrupt.c b/tests/fuzz/rescore-shadow-corrupt.c new file mode 100644 index 0000000..edd87ef --- /dev/null +++ b/tests/fuzz/rescore-shadow-corrupt.c @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/** + * Fuzz target: corrupt rescore shadow tables then exercise KNN/read/write. + * + * This targets the dangerous code paths in rescore_knn (Phase 1 + 2): + * - sqlite3_blob_read into baseVectors with potentially wrong-sized blobs + * - distance computation on corrupted/partial quantized data + * - blob_reopen on _rescore_vectors with missing/corrupted rows + * - insert/delete after corruption (blob_write to wrong offsets) + * + * The existing shadow-corrupt.c only tests vec0 without rescore. + */ +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 4) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* Pick quantizer type from first byte */ + int use_int8 = data[0] & 1; + int target = (data[1] % 8); + const uint8_t *payload = data + 2; + int payload_size = (int)(size - 2); + + const char *create_sql = use_int8 + ? "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=int8))" + : "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=bit))"; + + rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + /* Insert several vectors so there's a full chunk to corrupt */ + { + sqlite3_stmt *ins = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL); + if (!ins) { sqlite3_close(db); return 0; } + + for (int i = 1; i <= 8; i++) { + float vec[16]; + for (int j = 0; j < 16; j++) vec[j] = (float)(i * 10 + j) / 100.0f; + sqlite3_reset(ins); + sqlite3_bind_int64(ins, 1, i); + sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(ins); + } + sqlite3_finalize(ins); + } + + /* Now corrupt rescore shadow tables based on fuzz input */ + sqlite3_stmt *stmt = NULL; + + switch (target) { + case 0: { + /* Corrupt _rescore_chunks00 vectors blob with fuzz data */ + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + break; + } + case 1: { + /* Corrupt _rescore_vectors00 vector blob for a specific row */ + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 3", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + break; + } + case 2: { + /* Truncate the quantized chunk blob to wrong size */ + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_chunks00 SET vectors = X'DEADBEEF' WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + break; + } + case 3: { + /* Delete rows from _rescore_vectors (orphan the float vectors) */ + sqlite3_exec(db, + "DELETE FROM v_rescore_vectors00 WHERE rowid IN (2, 4, 6)", + NULL, NULL, NULL); + break; + } + case 4: { + /* Delete the chunk row entirely from _rescore_chunks */ + sqlite3_exec(db, + "DELETE FROM v_rescore_chunks00 WHERE rowid = 1", + NULL, NULL, NULL); + break; + } + case 5: { + /* Set vectors to NULL in _rescore_chunks */ + sqlite3_exec(db, + "UPDATE v_rescore_chunks00 SET vectors = NULL WHERE rowid = 1", + NULL, NULL, NULL); + break; + } + case 6: { + /* Set vector to NULL in _rescore_vectors */ + sqlite3_exec(db, + "UPDATE v_rescore_vectors00 SET vector = NULL WHERE rowid = 3", + NULL, NULL, NULL); + break; + } + case 7: { + /* Corrupt BOTH tables with fuzz data */ + int half = payload_size / 2; + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, half, SQLITE_STATIC); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload + half, + payload_size - half, SQLITE_STATIC); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + break; + } + } + + /* Exercise ALL read/write paths -- NONE should crash */ + + /* KNN query (triggers rescore_knn Phase 1 + Phase 2) */ + { + float qvec[16] = {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1}; + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 5", -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) {} + sqlite3_finalize(knn); + } + } + + /* Full scan (triggers reading from _rescore_vectors) */ + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + + /* Point lookups */ + sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL); + sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 3", NULL, NULL, NULL); + + /* Insert after corruption */ + { + float vec[16] = {0}; + sqlite3_stmt *ins = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (99, ?)", -1, &ins, NULL); + if (ins) { + sqlite3_bind_blob(ins, 1, vec, sizeof(vec), SQLITE_STATIC); + sqlite3_step(ins); + sqlite3_finalize(ins); + } + } + + /* Delete after corruption */ + sqlite3_exec(db, "DELETE FROM v WHERE rowid = 5", NULL, NULL, NULL); + + /* Update after corruption */ + { + float vec[16] = {1,1,1,1, 1,1,1,1, 1,1,1,1, 1,1,1,1}; + sqlite3_stmt *upd = NULL; + sqlite3_prepare_v2(db, + "UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL); + if (upd) { + sqlite3_bind_blob(upd, 1, vec, sizeof(vec), SQLITE_STATIC); + sqlite3_step(upd); + sqlite3_finalize(upd); + } + } + + /* KNN again after modifications to corrupted state */ + { + float qvec[16] = {0,0,0,0, 0,0,0,0, 1,1,1,1, 1,1,1,1}; + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 3", -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) {} + sqlite3_finalize(knn); + } + } + + sqlite3_exec(db, "DROP TABLE v", NULL, NULL, NULL); + sqlite3_close(db); + return 0; +} diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h index a02c72a..cbc2c08 100644 --- a/tests/sqlite-vec-internal.h +++ b/tests/sqlite-vec-internal.h @@ -65,8 +65,23 @@ enum Vec0DistanceMetrics { enum Vec0IndexType { VEC0_INDEX_TYPE_FLAT = 1, +#ifdef SQLITE_VEC_ENABLE_RESCORE + VEC0_INDEX_TYPE_RESCORE = 2, +#endif }; +#ifdef SQLITE_VEC_ENABLE_RESCORE +enum Vec0RescoreQuantizerType { + VEC0_RESCORE_QUANTIZER_BIT = 1, + VEC0_RESCORE_QUANTIZER_INT8 = 2, +}; + +struct Vec0RescoreConfig { + enum Vec0RescoreQuantizerType quantizer_type; + int oversample; +}; +#endif + struct VectorColumnDefinition { char *name; int name_length; @@ -74,6 +89,9 @@ struct VectorColumnDefinition { enum VectorElementType element_type; enum Vec0DistanceMetrics distance_metric; enum Vec0IndexType index_type; +#ifdef SQLITE_VEC_ENABLE_RESCORE + struct Vec0RescoreConfig rescore; +#endif }; int vec0_parse_vector_column(const char *source, int source_length, @@ -88,6 +106,13 @@ int vec0_parse_partition_key_definition(const char *source, int source_length, float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims); float _test_distance_cosine_float(const float *a, const float *b, size_t dims); float _test_distance_hamming(const unsigned char *a, const unsigned char *b, size_t dims); + +#ifdef SQLITE_VEC_ENABLE_RESCORE +void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim); +void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim); +size_t _test_rescore_quantized_byte_size_bit(size_t dimensions); +size_t _test_rescore_quantized_byte_size_int8(size_t dimensions); +#endif #endif #endif /* SQLITE_VEC_INTERNAL_H */ diff --git a/tests/test-rescore-mutations.py b/tests/test-rescore-mutations.py new file mode 100644 index 0000000..28495c2 --- /dev/null +++ b/tests/test-rescore-mutations.py @@ -0,0 +1,470 @@ +"""Mutation and edge-case tests for the rescore index feature.""" +import struct +import sqlite3 +import pytest +import math +import random + + +@pytest.fixture() +def db(): + db = sqlite3.connect(":memory:") + db.row_factory = sqlite3.Row + db.enable_load_extension(True) + db.load_extension("dist/vec0") + db.enable_load_extension(False) + return db + + +def float_vec(values): + """Pack a list of floats into a blob for sqlite-vec.""" + return struct.pack(f"{len(values)}f", *values) + + +def unpack_float_vec(blob): + """Unpack a float vector blob.""" + n = len(blob) // 4 + return list(struct.unpack(f"{n}f", blob)) + + +# ============================================================================ +# Error cases: rescore + aux/metadata/partition +# ============================================================================ + + +def test_create_error_with_aux_column(db): + """Rescore should reject auxiliary columns.""" + with pytest.raises(sqlite3.OperationalError, match="Auxiliary columns"): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)," + " +extra text" + ")" + ) + + +def test_create_error_with_metadata_column(db): + """Rescore should reject metadata columns.""" + with pytest.raises(sqlite3.OperationalError, match="Metadata columns"): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)," + " genre text" + ")" + ) + + +def test_create_error_with_partition_key(db): + """Rescore should reject partition key columns.""" + with pytest.raises(sqlite3.OperationalError, match="Partition key"): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)," + " user_id integer partition key" + ")" + ) + + +# ============================================================================ +# Insert / batch / delete / update mutations +# ============================================================================ + + +def test_insert_single_verify_knn(db): + """Insert a single row and verify KNN returns it.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 1 + assert rows[0]["rowid"] == 1 + assert rows[0]["distance"] < 0.01 + + +def test_insert_large_batch(db): + """Insert 200+ rows (multiple chunks with default chunk_size=1024) and verify count and KNN.""" + dim = 16 + n = 200 + random.seed(99) + db.execute( + f"CREATE VIRTUAL TABLE t USING vec0(" + f" embedding float[{dim}] indexed by rescore(quantizer=int8)" + f")" + ) + for i in range(n): + v = [random.gauss(0, 1) for _ in range(dim)] + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec(v)], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == n + + # KNN should return results + query = float_vec([random.gauss(0, 1) for _ in range(dim)]) + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [query], + ).fetchall() + assert len(rows) == 10 + + +def test_delete_all_rows(db): + """Delete every row, verify count=0, KNN returns empty.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + assert db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"] == 20 + + for i in range(20): + db.execute("DELETE FROM t WHERE rowid = ?", [i + 1]) + + assert db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"] == 0 + + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", + [float_vec([0.0] * 8)], + ).fetchall() + assert len(rows) == 0 + + +def test_delete_then_reinsert_same_rowid(db): + """Delete rowid=1, re-insert rowid=1 with different vector, verify KNN uses new vector.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + # Insert rowid=1 near origin, rowid=2 far from origin + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([0.1] * 8)], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([100.0] * 8)], + ) + + # KNN to [0]*8 -> rowid 1 is closer + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([0.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 1 + + # Delete rowid=1, re-insert with vector far from origin + db.execute("DELETE FROM t WHERE rowid = 1") + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([200.0] * 8)], + ) + + # Now KNN to [0]*8 -> rowid 2 should be closer + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([0.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 2 + + +def test_update_vector(db): + """UPDATE the vector column and verify KNN reflects new value.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([0.0] * 8)], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([10.0] * 8)], + ) + + # Update rowid=1 to be far away + db.execute( + "UPDATE t SET embedding = ? WHERE rowid = 1", + [float_vec([100.0] * 8)], + ) + + # Now KNN to [0]*8 -> rowid 2 should be closest + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([0.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 2 + + +def test_knn_after_delete_all_but_one(db): + """Insert 50 rows, delete 49, KNN should only return the survivor.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + for i in range(50): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + # Delete all except rowid=25 + for i in range(50): + if i + 1 != 25: + db.execute("DELETE FROM t WHERE rowid = ?", [i + 1]) + + assert db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"] == 1 + + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [float_vec([0.0] * 8)], + ).fetchall() + assert len(rows) == 1 + assert rows[0]["rowid"] == 25 + + +# ============================================================================ +# Edge cases +# ============================================================================ + + +def test_single_row_knn(db): + """Table with exactly 1 row. LIMIT 1 returns it; LIMIT 5 returns 1.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 1 + + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 1 + + +def test_knn_with_all_identical_vectors(db): + """All vectors are the same. All distances should be equal.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + vec = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0] + for i in range(10): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec(vec)], + ) + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [float_vec(vec)], + ).fetchall() + assert len(rows) == 10 + # All distances should be ~0 (exact match) + for r in rows: + assert r["distance"] < 0.01 + + +def test_zero_vector_insert(db): + """Insert the zero vector [0,0,...,0]. Should not crash quantization.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([0.0] * 8)], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 1 + + # Also test int8 quantizer with zero vector + db.execute( + "CREATE VIRTUAL TABLE t2 USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t2(rowid, embedding) VALUES (1, ?)", + [float_vec([0.0] * 8)], + ) + row = db.execute("SELECT count(*) as cnt FROM t2").fetchone() + assert row["cnt"] == 1 + + +def test_very_large_values(db): + """Insert vectors with very large float values. Quantization should not crash.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1e30] * 8)], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([1e30, -1e30, 1e30, -1e30, 1e30, -1e30, 1e30, -1e30])], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 2 + + +def test_negative_values(db): + """Insert vectors with all negative values. Bit quantization maps all to 0.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8])], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 2 + + # KNN should still work + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2", + [float_vec([-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8])], + ).fetchall() + assert len(rows) == 2 + assert rows[0]["rowid"] == 2 + + +def test_single_dimension(db): + """Single-dimension vector (edge case for quantization).""" + # int8 quantizer (bit needs dim divisible by 8) + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([5.0] * 8)]) + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 1 + + +# ============================================================================ +# vec_debug() verification +# ============================================================================ + + +def test_vec_debug_contains_rescore(db): + """vec_debug() should contain 'rescore' in build flags when compiled with SQLITE_VEC_ENABLE_RESCORE.""" + row = db.execute("SELECT vec_debug() as d").fetchone() + assert "rescore" in row["d"] + + +# ============================================================================ +# Insert batch recall test +# ============================================================================ + + +def test_insert_batch_recall(db): + """Insert 150 rows and verify KNN recall is reasonable (>0.6).""" + dim = 16 + n = 150 + k = 10 + random.seed(77) + + db.execute( + f"CREATE VIRTUAL TABLE t_rescore USING vec0(" + f" embedding float[{dim}] indexed by rescore(quantizer=int8, oversample=16)" + f")" + ) + db.execute( + f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])" + ) + + vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)] + for i, v in enumerate(vectors): + blob = float_vec(v) + db.execute( + "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob] + ) + db.execute( + "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob] + ) + + query = float_vec([random.gauss(0, 1) for _ in range(dim)]) + + rescore_rows = db.execute( + "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + flat_rows = db.execute( + "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + + rescore_ids = {r["rowid"] for r in rescore_rows} + flat_ids = {r["rowid"] for r in flat_rows} + recall = len(rescore_ids & flat_ids) / k + assert recall >= 0.6, f"Recall too low: {recall}" + + +# ============================================================================ +# Distance metric variants +# ============================================================================ + + +def test_knn_int8_cosine(db): + """Rescore with quantizer=int8 and distance_metric=cosine.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (3, ?)", + [float_vec([1.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert rows[0]["rowid"] == 1 + assert rows[0]["distance"] < 0.01 diff --git a/tests/test-rescore.py b/tests/test-rescore.py new file mode 100644 index 0000000..5025857 --- /dev/null +++ b/tests/test-rescore.py @@ -0,0 +1,568 @@ +"""Tests for the rescore index feature in sqlite-vec.""" +import struct +import sqlite3 +import pytest +import math +import random + + +@pytest.fixture() +def db(): + db = sqlite3.connect(":memory:") + db.row_factory = sqlite3.Row + db.enable_load_extension(True) + db.load_extension("dist/vec0") + db.enable_load_extension(False) + return db + + +def float_vec(values): + """Pack a list of floats into a blob for sqlite-vec.""" + return struct.pack(f"{len(values)}f", *values) + + +def unpack_float_vec(blob): + """Unpack a float vector blob.""" + n = len(blob) // 4 + return list(struct.unpack(f"{n}f", blob)) + + +# ============================================================================ +# Creation tests +# ============================================================================ + + +def test_create_bit(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + # Table exists and has the right structure + row = db.execute( + "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" + ).fetchone() + assert row["cnt"] > 0 + + +def test_create_int8(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=int8)" + ")" + ) + row = db.execute( + "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" + ).fetchone() + assert row["cnt"] > 0 + + +def test_create_with_oversample(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit, oversample=16)" + ")" + ) + row = db.execute( + "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" + ).fetchone() + assert row["cnt"] > 0 + + +def test_create_with_distance_metric(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] distance_metric=cosine indexed by rescore(quantizer=bit)" + ")" + ) + row = db.execute( + "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" + ).fetchone() + assert row["cnt"] > 0 + + +def test_create_error_missing_quantizer(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(oversample=8)" + ")" + ) + + +def test_create_error_invalid_quantizer(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=float)" + ")" + ) + + +def test_create_error_on_bit_column(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding bit[1024] indexed by rescore(quantizer=bit)" + ")" + ) + + +def test_create_error_on_int8_column(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding int8[128] indexed by rescore(quantizer=bit)" + ")" + ) + + +def test_create_error_bad_oversample_zero(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit, oversample=0)" + ")" + ) + + +def test_create_error_bad_oversample_too_large(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit, oversample=999)" + ")" + ) + + +def test_create_error_bit_dim_not_divisible_by_8(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[100] indexed by rescore(quantizer=bit)" + ")" + ) + + +# ============================================================================ +# Shadow table tests +# ============================================================================ + + +def test_shadow_tables_exist(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name" + ).fetchall() + ] + assert "t_rescore_chunks00" in tables + assert "t_rescore_vectors00" in tables + # Rescore columns don't create _vector_chunks + assert "t_vector_chunks00" not in tables + + +def test_drop_cleans_up(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("DROP TABLE t") + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%'" + ).fetchall() + ] + assert len(tables) == 0 + + +# ============================================================================ +# Insert tests +# ============================================================================ + + +def test_insert_single(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 1 + + +def test_insert_multiple(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 10 + + +# ============================================================================ +# Delete tests +# ============================================================================ + + +def test_delete_single(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + db.execute("DELETE FROM t WHERE rowid = 1") + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 0 + + +def test_delete_and_reinsert(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + db.execute("DELETE FROM t WHERE rowid = 1") + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)] + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 1 + + +def test_point_query_returns_float(db): + """SELECT by rowid should return the original float vector, not quantized.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(vals)]) + row = db.execute("SELECT embedding FROM t WHERE rowid = 1").fetchone() + result = unpack_float_vec(row["embedding"]) + for a, b in zip(result, vals): + assert abs(a - b) < 1e-6 + + +# ============================================================================ +# KNN tests +# ============================================================================ + + +def test_knn_basic_bit(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + # Insert vectors where [1,0,0,...] is closest to query [1,0,0,...] + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (3, ?)", + [float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert len(rows) == 1 + assert rows[0]["rowid"] == 1 + + +def test_knn_basic_int8(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (3, ?)", + [float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert len(rows) == 1 + assert rows[0]["rowid"] == 1 + + +def test_knn_returns_float_distances(db): + """KNN should return float-precision distances, not quantized distances.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + v1 = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + v2 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(v1)]) + db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(v2)]) + + query = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2", + [float_vec(query)], + ).fetchall() + + # First result should be exact match with distance ~0 + assert rows[0]["rowid"] == 1 + assert rows[0]["distance"] < 0.01 + + # Second result should have a float distance + # sqrt((1-0)^2 + (0.5-0)^2 + (0-1)^2) = sqrt(2.25) = 1.5 + assert abs(rows[1]["distance"] - 1.5) < 0.01 + + +def test_knn_recall(db): + """With enough vectors, rescore should achieve good recall (>0.9).""" + dim = 32 + n = 1000 + k = 10 + random.seed(42) + + db.execute( + "CREATE VIRTUAL TABLE t_rescore USING vec0(" + f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample=16)" + ")" + ) + db.execute( + f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])" + ) + + vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)] + for i, v in enumerate(vectors): + blob = float_vec(v) + db.execute( + "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob] + ) + db.execute( + "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob] + ) + + query = float_vec([random.gauss(0, 1) for _ in range(dim)]) + + rescore_rows = db.execute( + "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + flat_rows = db.execute( + "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + + rescore_ids = {r["rowid"] for r in rescore_rows} + flat_ids = {r["rowid"] for r in flat_rows} + recall = len(rescore_ids & flat_ids) / k + assert recall >= 0.7, f"Recall too low: {recall}" + + +def test_knn_cosine(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=bit)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert rows[0]["rowid"] == 1 + # cosine distance of identical vectors should be ~0 + assert rows[0]["distance"] < 0.01 + + +def test_knn_empty_table(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 0 + + +def test_knn_k_larger_than_n(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)]) + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 2 + + +# ============================================================================ +# Integration / edge case tests +# ============================================================================ + + +def test_knn_with_rowid_in(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + for i in range(5): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + # Only search within rowids 1, 3, 5 + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? AND rowid IN (1, 3, 5) ORDER BY distance LIMIT 3", + [float_vec([0.0] * 8)], + ).fetchall() + result_ids = {r["rowid"] for r in rows} + assert result_ids <= {1, 3, 5} + + +def test_knn_after_deletes(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + # Delete the closest match (rowid 1 = [0,0,...]) + db.execute("DELETE FROM t WHERE rowid = 1") + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", + [float_vec([0.0] * 8)], + ).fetchall() + # Verify ordering: rowid 2 ([1]*8) should be closest, then 3 ([2]*8), etc. + assert len(rows) >= 2 + assert rows[0]["distance"] <= rows[1]["distance"] + # rowid 2 = [1,1,...] → L2 = sqrt(8) ≈ 2.83, rowid 3 = [2,2,...] → L2 = sqrt(32) ≈ 5.66 + assert rows[0]["rowid"] == 2, f"Expected rowid 2, got {rows[0]['rowid']} with dist={rows[0]['distance']}" + + +def test_oversample_effect(db): + """Higher oversample should give equal or better recall.""" + dim = 32 + n = 500 + k = 10 + random.seed(123) + + vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)] + query = float_vec([random.gauss(0, 1) for _ in range(dim)]) + + recalls = [] + for oversample in [2, 16]: + tname = f"t_os{oversample}" + db.execute( + f"CREATE VIRTUAL TABLE {tname} USING vec0(" + f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample={oversample})" + ")" + ) + for i, v in enumerate(vectors): + db.execute( + f"INSERT INTO {tname}(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec(v)], + ) + rows = db.execute( + f"SELECT rowid FROM {tname} WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + recalls.append({r["rowid"] for r in rows}) + + # Also get ground truth + db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])") + for i, v in enumerate(vectors): + db.execute( + "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec(v)], + ) + gt_rows = db.execute( + "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + gt_ids = {r["rowid"] for r in gt_rows} + + recall_low = len(recalls[0] & gt_ids) / k + recall_high = len(recalls[1] & gt_ids) / k + assert recall_high >= recall_low + + +def test_multiple_vector_columns(db): + """One column with rescore, one without.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " v1 float[8] indexed by rescore(quantizer=bit)," + " v2 float[8]" + ")" + ) + db.execute( + "INSERT INTO t(rowid, v1, v2) VALUES (1, ?, ?)", + [float_vec([1.0] * 8), float_vec([0.0] * 8)], + ) + db.execute( + "INSERT INTO t(rowid, v1, v2) VALUES (2, ?, ?)", + [float_vec([0.0] * 8), float_vec([1.0] * 8)], + ) + + # KNN on v1 (rescore path) + rows = db.execute( + "SELECT rowid FROM t WHERE v1 MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 1 + + # KNN on v2 (normal path) + rows = db.execute( + "SELECT rowid FROM t WHERE v2 MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 2 diff --git a/tests/test-unit.c b/tests/test-unit.c index 9eb8704..b180625 100644 --- a/tests/test-unit.c +++ b/tests/test-unit.c @@ -760,6 +760,202 @@ void test_distance_hamming() { printf(" All distance_hamming tests passed.\n"); } +#ifdef SQLITE_VEC_ENABLE_RESCORE + +void test_rescore_quantize_float_to_bit() { + printf("Starting %s...\n", __func__); + uint8_t dst[16]; + + // All positive -> all bits 1 + { + float src[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + memset(dst, 0, sizeof(dst)); + _test_rescore_quantize_float_to_bit(src, dst, 8); + assert(dst[0] == 0xFF); + } + + // All negative -> all bits 0 + { + float src[8] = {-1.0f, -2.0f, -3.0f, -4.0f, -5.0f, -6.0f, -7.0f, -8.0f}; + memset(dst, 0xFF, sizeof(dst)); + _test_rescore_quantize_float_to_bit(src, dst, 8); + assert(dst[0] == 0x00); + } + + // Alternating positive/negative + { + float src[8] = {1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f}; + _test_rescore_quantize_float_to_bit(src, dst, 8); + // bits 0,2,4,6 set => 0b01010101 = 0x55 + assert(dst[0] == 0x55); + } + + // Zero values -> bit is set (>= 0.0f) + { + float src[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + _test_rescore_quantize_float_to_bit(src, dst, 8); + assert(dst[0] == 0xFF); + } + + // 128 dimensions -> 16 bytes output + { + float src[128]; + for (int i = 0; i < 128; i++) src[i] = (i % 2 == 0) ? 1.0f : -1.0f; + memset(dst, 0, 16); + _test_rescore_quantize_float_to_bit(src, dst, 128); + // Even indices set: bits 0,2,4,6 in each byte => 0x55 + for (int i = 0; i < 16; i++) { + assert(dst[i] == 0x55); + } + } + + printf(" All rescore_quantize_float_to_bit tests passed.\n"); +} + +void test_rescore_quantize_float_to_int8() { + printf("Starting %s...\n", __func__); + int8_t dst[256]; + + // Uniform vector -> all zeros (range=0) + { + float src[8] = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 8); + for (int i = 0; i < 8; i++) { + assert(dst[i] == 0); + } + } + + // [0.0, 1.0] -> should map to [-128, 127] + { + float src[2] = {0.0f, 1.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 2); + assert(dst[0] == -128); + assert(dst[1] == 127); + } + + // [-1.0, 0.0] -> should map to [-128, 127] + { + float src[2] = {-1.0f, 0.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 2); + assert(dst[0] == -128); + assert(dst[1] == 127); + } + + // Single-element: range=0 -> 0 + { + float src[1] = {42.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 1); + assert(dst[0] == 0); + } + + // Verify range: all outputs in [-128, 127], min near -128, max near 127 + { + float src[4] = {-100.0f, 0.0f, 100.0f, 50.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 4); + for (int i = 0; i < 4; i++) { + assert(dst[i] >= -128 && dst[i] <= 127); + } + // Min maps to -128 (exact), max maps to ~127 (may lose 1 to float rounding) + assert(dst[0] == -128); + assert(dst[2] >= 126 && dst[2] <= 127); + // Middle value (50) should be positive + assert(dst[3] > 0); + } + + printf(" All rescore_quantize_float_to_int8 tests passed.\n"); +} + +void test_rescore_quantized_byte_size() { + printf("Starting %s...\n", __func__); + + // Bit quantizer: dims/8 + assert(_test_rescore_quantized_byte_size_bit(128) == 16); + assert(_test_rescore_quantized_byte_size_bit(8) == 1); + assert(_test_rescore_quantized_byte_size_bit(1024) == 128); + + // Int8 quantizer: dims + assert(_test_rescore_quantized_byte_size_int8(128) == 128); + assert(_test_rescore_quantized_byte_size_int8(8) == 8); + assert(_test_rescore_quantized_byte_size_int8(1024) == 1024); + + printf(" All rescore_quantized_byte_size tests passed.\n"); +} + +void test_vec0_parse_vector_column_rescore() { + printf("Starting %s...\n", __func__); + struct VectorColumnDefinition col; + int rc; + + // Basic bit quantizer + { + const char *input = "emb float[128] indexed by rescore(quantizer=bit)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); + assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT); + assert(col.rescore.oversample == 8); // default + assert(col.dimensions == 128); + sqlite3_free(col.name); + } + + // Int8 quantizer + { + const char *input = "emb float[128] indexed by rescore(quantizer=int8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); + assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + // Bit quantizer with oversample + { + const char *input = "emb float[128] indexed by rescore(quantizer=bit, oversample=16)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); + assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT); + assert(col.rescore.oversample == 16); + sqlite3_free(col.name); + } + + // Error: non-float element type + { + const char *input = "emb int8[128] indexed by rescore(quantizer=bit)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: dims not divisible by 8 for bit quantizer + { + const char *input = "emb float[100] indexed by rescore(quantizer=bit)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: missing quantizer + { + const char *input = "emb float[128] indexed by rescore(oversample=8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // With distance_metric=cosine + { + const char *input = "emb float[128] distance_metric=cosine indexed by rescore(quantizer=int8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE); + assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + printf(" All vec0_parse_vector_column_rescore tests passed.\n"); +} + +#endif /* SQLITE_VEC_ENABLE_RESCORE */ + int main() { printf("Starting unit tests...\n"); #ifdef SQLITE_VEC_ENABLE_AVX @@ -768,6 +964,9 @@ int main() { #ifdef SQLITE_VEC_ENABLE_NEON printf("SQLITE_VEC_ENABLE_NEON=1\n"); #endif +#ifdef SQLITE_VEC_ENABLE_RESCORE + printf("SQLITE_VEC_ENABLE_RESCORE=1\n"); +#endif #if !defined(SQLITE_VEC_ENABLE_AVX) && !defined(SQLITE_VEC_ENABLE_NEON) printf("SIMD: none\n"); #endif @@ -778,5 +977,11 @@ int main() { test_distance_l2_sqr_float(); test_distance_cosine_float(); test_distance_hamming(); +#ifdef SQLITE_VEC_ENABLE_RESCORE + test_rescore_quantize_float_to_bit(); + test_rescore_quantize_float_to_int8(); + test_rescore_quantized_byte_size(); + test_vec0_parse_vector_column_rescore(); +#endif printf("All unit tests passed.\n"); }