mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 08:46:49 +02:00
Add rescore index for ANN queries
Add rescore index type: stores full-precision float vectors in a rowid-keyed shadow table, quantizes to int8 for fast initial scan, then rescores top candidates with original vectors. Includes config parser, shadow table management, insert/delete support, KNN integration, compile flag (SQLITE_VEC_ENABLE_RESCORE), fuzz targets, and tests.
This commit is contained in:
parent
bf2455f2ba
commit
ba0db0b6d6
19 changed files with 3378 additions and 8 deletions
5
tests/fuzz/.gitignore
vendored
5
tests/fuzz/.gitignore
vendored
|
|
@ -1,2 +1,7 @@
|
|||
*.dSYM
|
||||
targets/
|
||||
corpus/
|
||||
crash-*
|
||||
leak-*
|
||||
timeout-*
|
||||
*.log
|
||||
|
|
|
|||
|
|
@ -72,10 +72,34 @@ $(TARGET_DIR)/vec_mismatch: vec-mismatch.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
|||
$(TARGET_DIR)/vec0_delete_completeness: vec0-delete-completeness.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_operations: rescore-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_create: rescore-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_quantize: rescore-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_shadow_corrupt: rescore-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_knn_deep: rescore-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_quantize_edge: rescore-quantize-edge.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
FUZZ_TARGETS = vec0_create exec json numpy \
|
||||
shadow_corrupt vec0_operations scalar_functions \
|
||||
vec0_create_full metadata_columns vec_each vec_mismatch \
|
||||
vec0_delete_completeness
|
||||
vec0_delete_completeness \
|
||||
rescore_operations rescore_create rescore_quantize \
|
||||
rescore_shadow_corrupt rescore_knn_deep \
|
||||
rescore_quantize_edge rescore_interleave
|
||||
|
||||
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
|
||||
|
||||
|
|
|
|||
36
tests/fuzz/rescore-create.c
Normal file
36
tests/fuzz/rescore-create.c
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
int rc = SQLITE_OK;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmt;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||
assert(s);
|
||||
sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[128] indexed by rescore(");
|
||||
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||
sqlite3_str_appendall(s, "))");
|
||||
const char *zSql = sqlite3_str_finish(s);
|
||||
assert(zSql);
|
||||
|
||||
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||
sqlite3_free((void *)zSql);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_step(stmt);
|
||||
}
|
||||
sqlite3_finalize(stmt);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
20
tests/fuzz/rescore-create.dict
Normal file
20
tests/fuzz/rescore-create.dict
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"rescore"
|
||||
"quantizer"
|
||||
"bit"
|
||||
"int8"
|
||||
"oversample"
|
||||
"indexed"
|
||||
"by"
|
||||
"float"
|
||||
"("
|
||||
")"
|
||||
","
|
||||
"="
|
||||
"["
|
||||
"]"
|
||||
"1"
|
||||
"8"
|
||||
"16"
|
||||
"128"
|
||||
"256"
|
||||
"1024"
|
||||
151
tests/fuzz/rescore-interleave.c
Normal file
151
tests/fuzz/rescore-interleave.c
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/**
|
||||
* Fuzz target: interleaved insert/update/delete/KNN operations on rescore
|
||||
* tables with BOTH quantizer types, exercising the int8 quantizer path
|
||||
* and the update code path that the existing rescore-operations.c misses.
|
||||
*
|
||||
* Key differences from rescore-operations.c:
|
||||
* - Tests BOTH bit and int8 quantizers (the existing target only tests bit)
|
||||
* - Fuzz-controlled query vectors (not fixed [1,0,0,...])
|
||||
* - Exercises the UPDATE path (line 9080+ in sqlite-vec.c)
|
||||
* - Tests with 16 dimensions (more realistic, exercises more of the
|
||||
* quantization loop)
|
||||
* - Interleaves KNN between mutations to stress the blob_reopen path
|
||||
* when _rescore_vectors rows have been deleted/modified
|
||||
*/
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 8) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_stmt *stmtUpdate = NULL;
|
||||
sqlite3_stmt *stmtDelete = NULL;
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* Use first byte to pick quantizer */
|
||||
int use_int8 = data[0] & 1;
|
||||
data++; size--;
|
||||
|
||||
const char *create_sql = use_int8
|
||||
? "CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=int8))"
|
||||
: "CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=bit))";
|
||||
|
||||
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"UPDATE v SET emb = ? WHERE rowid = ?", -1, &stmtUpdate, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 5", -1, &stmtKnn, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtUpdate || !stmtDelete || !stmtKnn)
|
||||
goto cleanup;
|
||||
|
||||
size_t i = 0;
|
||||
while (i + 2 <= size) {
|
||||
uint8_t op = data[i++] % 5; /* 5 operations now */
|
||||
uint8_t rowid_byte = data[i++];
|
||||
int64_t rowid = (int64_t)(rowid_byte % 24) + 1;
|
||||
|
||||
switch (op) {
|
||||
case 0: {
|
||||
/* INSERT: consume bytes for 16 floats */
|
||||
float vec[16] = {0};
|
||||
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||
vec[j] = (float)((int8_t)data[i]) / 8.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
/* DELETE */
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
/* KNN with fuzz-controlled query vector */
|
||||
float qvec[16] = {0};
|
||||
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||
qvec[j] = (float)((int8_t)data[i]) / 4.0f;
|
||||
}
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {
|
||||
(void)sqlite3_column_int64(stmtKnn, 0);
|
||||
(void)sqlite3_column_double(stmtKnn, 1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
/* UPDATE: modify an existing vector (exercises rescore update path) */
|
||||
float vec[16] = {0};
|
||||
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||
vec[j] = (float)((int8_t)data[i]) / 6.0f;
|
||||
}
|
||||
sqlite3_reset(stmtUpdate);
|
||||
sqlite3_bind_blob(stmtUpdate, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int64(stmtUpdate, 2, rowid);
|
||||
sqlite3_step(stmtUpdate);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
/* INSERT then immediately UPDATE same row (stresses blob lifecycle) */
|
||||
float vec1[16] = {0};
|
||||
float vec2[16] = {0};
|
||||
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||
vec1[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||
vec2[j] = -vec1[j]; /* opposite direction */
|
||||
}
|
||||
/* Insert */
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec1, sizeof(vec1), SQLITE_TRANSIENT);
|
||||
if (sqlite3_step(stmtInsert) == SQLITE_DONE) {
|
||||
/* Only update if insert succeeded (rowid might already exist) */
|
||||
sqlite3_reset(stmtUpdate);
|
||||
sqlite3_bind_blob(stmtUpdate, 1, vec2, sizeof(vec2), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int64(stmtUpdate, 2, rowid);
|
||||
sqlite3_step(stmtUpdate);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Final consistency check: full scan must not crash */
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtUpdate);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
178
tests/fuzz/rescore-knn-deep.c
Normal file
178
tests/fuzz/rescore-knn-deep.c
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/**
|
||||
* Fuzz target: deep exercise of rescore KNN with fuzz-controlled query vectors
|
||||
* and both quantizer types (bit + int8), multiple distance metrics.
|
||||
*
|
||||
* The existing rescore-operations.c only tests bit quantizer with a fixed
|
||||
* query vector. This target:
|
||||
* - Tests both bit and int8 quantizers
|
||||
* - Uses fuzz-controlled query vectors (hits NaN/Inf/denormal paths)
|
||||
* - Tests all distance metrics with int8 (L2, cosine, L1)
|
||||
* - Exercises large LIMIT values (oversample multiplication)
|
||||
* - Tests KNN with rowid IN constraints
|
||||
* - Exercises the insert->query->update->query->delete->query cycle
|
||||
*/
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 20) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* Use first 4 bytes for configuration */
|
||||
uint8_t config = data[0];
|
||||
uint8_t num_inserts = (data[1] % 20) + 3; /* 3..22 inserts */
|
||||
uint8_t limit_val = (data[2] % 50) + 1; /* 1..50 for LIMIT */
|
||||
uint8_t metric_choice = data[3] % 3;
|
||||
data += 4;
|
||||
size -= 4;
|
||||
|
||||
int use_int8 = config & 1;
|
||||
const char *metric_str;
|
||||
switch (metric_choice) {
|
||||
case 0: metric_str = ""; break; /* default L2 */
|
||||
case 1: metric_str = " distance_metric=cosine"; break;
|
||||
case 2: metric_str = " distance_metric=l1"; break;
|
||||
default: metric_str = ""; break;
|
||||
}
|
||||
|
||||
/* Build CREATE TABLE statement */
|
||||
char create_sql[256];
|
||||
if (use_int8) {
|
||||
snprintf(create_sql, sizeof(create_sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=int8)%s)", metric_str);
|
||||
} else {
|
||||
/* bit quantizer ignores distance_metric for the coarse pass (always hamming),
|
||||
but the float rescore phase uses the specified metric */
|
||||
snprintf(create_sql, sizeof(create_sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=bit)%s)", metric_str);
|
||||
}
|
||||
|
||||
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
/* Insert vectors using fuzz data */
|
||||
{
|
||||
sqlite3_stmt *ins = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL);
|
||||
if (!ins) { sqlite3_close(db); return 0; }
|
||||
|
||||
size_t cursor = 0;
|
||||
for (int i = 0; i < num_inserts && cursor + 1 < size; i++) {
|
||||
float vec[16];
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (cursor < size) {
|
||||
/* Map fuzz byte to float -- includes potential for
|
||||
interesting float values via reinterpretation */
|
||||
int8_t sb = (int8_t)data[cursor++];
|
||||
vec[j] = (float)sb / 5.0f;
|
||||
} else {
|
||||
vec[j] = 0.0f;
|
||||
}
|
||||
}
|
||||
sqlite3_reset(ins);
|
||||
sqlite3_bind_int64(ins, 1, (sqlite3_int64)(i + 1));
|
||||
sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(ins);
|
||||
}
|
||||
sqlite3_finalize(ins);
|
||||
}
|
||||
|
||||
/* Build a fuzz-controlled query vector from remaining data */
|
||||
float qvec[16] = {0};
|
||||
{
|
||||
size_t cursor = 0;
|
||||
for (int j = 0; j < 16 && cursor < size; j++) {
|
||||
int8_t sb = (int8_t)data[cursor++];
|
||||
qvec[j] = (float)sb / 3.0f;
|
||||
}
|
||||
}
|
||||
|
||||
/* KNN query with fuzz-controlled vector and LIMIT */
|
||||
{
|
||||
char knn_sql[256];
|
||||
snprintf(knn_sql, sizeof(knn_sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT %d", (int)limit_val);
|
||||
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db, knn_sql, -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {
|
||||
/* Read results to ensure distance computation didn't produce garbage
|
||||
that crashes the cursor iteration */
|
||||
(void)sqlite3_column_int64(knn, 0);
|
||||
(void)sqlite3_column_double(knn, 1);
|
||||
}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
/* Update some vectors, then query again */
|
||||
{
|
||||
float uvec[16];
|
||||
for (int j = 0; j < 16; j++) uvec[j] = qvec[15 - j]; /* reverse of query */
|
||||
sqlite3_stmt *upd = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL);
|
||||
if (upd) {
|
||||
sqlite3_bind_blob(upd, 1, uvec, sizeof(uvec), SQLITE_STATIC);
|
||||
sqlite3_step(upd);
|
||||
sqlite3_finalize(upd);
|
||||
}
|
||||
}
|
||||
|
||||
/* Second KNN after update */
|
||||
{
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 10", -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
/* Delete half the rows, then KNN again */
|
||||
for (int i = 1; i <= num_inserts; i += 2) {
|
||||
char del_sql[64];
|
||||
snprintf(del_sql, sizeof(del_sql),
|
||||
"DELETE FROM v WHERE rowid = %d", i);
|
||||
sqlite3_exec(db, del_sql, NULL, NULL, NULL);
|
||||
}
|
||||
|
||||
/* Third KNN after deletes -- exercises distance computation over
|
||||
zeroed-out slots in the quantized chunk */
|
||||
{
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 5", -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
96
tests/fuzz/rescore-operations.c
Normal file
96
tests/fuzz/rescore-operations.c
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 6) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_stmt *stmtDelete = NULL;
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
sqlite3_stmt *stmtScan = NULL;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
rc = sqlite3_exec(db,
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[8] indexed by rescore(quantizer=bit))",
|
||||
NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? ORDER BY distance LIMIT 3",
|
||||
-1, &stmtKnn, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||
|
||||
size_t i = 0;
|
||||
while (i + 2 <= size) {
|
||||
uint8_t op = data[i++] % 4;
|
||||
uint8_t rowid_byte = data[i++];
|
||||
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||
|
||||
switch (op) {
|
||||
case 0: {
|
||||
// INSERT: consume 32 bytes for 8 floats, or use what's left
|
||||
float vec[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
for (int j = 0; j < 8 && i < size; j++, i++) {
|
||||
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
// DELETE
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
// KNN query with a fixed query vector
|
||||
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
// Full scan
|
||||
sqlite3_reset(stmtScan);
|
||||
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final operations -- must not crash regardless of prior state
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_finalize(stmtScan);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
177
tests/fuzz/rescore-quantize-edge.c
Normal file
177
tests/fuzz/rescore-quantize-edge.c
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* Test wrappers from sqlite-vec-rescore.c (SQLITE_VEC_TEST build) */
|
||||
extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
|
||||
extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
|
||||
extern size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
|
||||
extern size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
|
||||
|
||||
/**
 * Fuzz target: edge cases in the rescore quantization helpers.
 *
 * Byte 0 selects a mode (byte % 5); the remaining bytes are the payload:
 *   0  pass a raw fuzz-controlled size_t to both quantized-byte-size helpers
 *   1  bit-quantize raw reinterpreted floats (NaN, Inf, denormal, -0.0f are
 *      all possible) and verify each output bit against the sign test
 *   2  int8-quantize raw reinterpreted floats
 *   3  int8-quantize a vector whose components are all the same value
 *   4  int8-quantize a 16-float vector whose first two components are two
 *      arbitrary finite floats, padded with byte-derived values
 *
 * Float payloads are copied into properly aligned float buffers before use:
 * the payload starts one byte into the fuzz buffer, so reading it through a
 * casted `const float *` would be a misaligned (undefined) access.
 *
 * @return always 0
 */
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  if (size < 8) return 0;

  uint8_t mode = data[0] % 5;
  data++;
  size--;

  switch (mode) {
  case 0: {
    /* Byte-size helpers with a fuzz-controlled dimension count. They are
       only expected not to crash; the results are not checked. */
    if (size < sizeof(size_t)) return 0;
    size_t dim;
    memcpy(&dim, data, sizeof(dim));

    size_t bit_size = _test_rescore_quantized_byte_size_bit(dim);
    size_t int8_size = _test_rescore_quantized_byte_size_int8(dim);
    (void)bit_size;
    (void)int8_size;
    break;
  }

  case 1: {
    /* Bit quantize with raw reinterpreted floats. */
    size_t num_floats = size / sizeof(float);
    /* Round down to a multiple of 8 for the bit quantizer */
    size_t dim = (num_floats / 8) * 8;
    if (dim == 0) return 0;

    float *src = malloc(dim * sizeof *src);
    uint8_t *dst = malloc(dim / 8);
    if (!src || !dst) {
      free(src);
      free(dst);
      return 0;
    }
    memcpy(src, data, dim * sizeof *src); /* aligned copy of the payload */

    _test_rescore_quantize_float_to_bit(src, dst, dim);

    /* Component i is expected at bit (i % 8) of byte (i / 8):
       set when src[i] >= 0, clear when src[i] < 0. NaN satisfies neither
       comparison, so NaN components are not checked. */
    for (size_t i = 0; i < dim; i++) {
      int bit_set = (dst[i / 8] >> (i % 8)) & 1;
      if (src[i] >= 0.0f) {
        assert(bit_set == 1);
      } else if (src[i] < 0.0f) {
        assert(bit_set == 0);
      }
    }

    free(src);
    free(dst);
    break;
  }

  case 2: {
    /* Int8 quantize with raw reinterpreted floats (identical values, NaN,
       Inf and extreme ranges can all occur). */
    size_t num_floats = size / sizeof(float);
    if (num_floats == 0) return 0;

    float *src = malloc(num_floats * sizeof *src);
    int8_t *dst = malloc(num_floats);
    if (!src || !dst) {
      free(src);
      free(dst);
      return 0;
    }
    memcpy(src, data, num_floats * sizeof *src); /* aligned copy */

    _test_rescore_quantize_float_to_int8(src, dst, num_floats);

    /* Trivially true for int8_t; kept so every output byte is read, which
       lets sanitizers flag uninitialized or out-of-bounds output. */
    for (size_t i = 0; i < num_floats; i++) {
      assert(dst[i] >= -128 && dst[i] <= 127);
    }

    free(src);
    free(dst);
    break;
  }

  case 3: {
    /* Int8 quantize stress: all-same values (range=0 branch) */
    size_t dim = (size < 64) ? size : 64;
    if (dim == 0) return 0;

    float *src = malloc(dim * sizeof *src);
    int8_t *dst = malloc(dim);
    if (!src || !dst) {
      free(src);
      free(dst);
      return 0;
    }

    /* Fill with a single value derived from fuzz data */
    float val = 0.0f;
    memcpy(&val, data, sizeof(float) < size ? sizeof(float) : size);
    for (size_t i = 0; i < dim; i++) src[i] = val;

    _test_rescore_quantize_float_to_int8(src, dst, dim);

    /* All outputs should be 0 when range == 0.
       NOTE(review): val is raw bytes and may be NaN or Inf; whether the
       quantizer still yields 0 in that case cannot be confirmed here. */
    for (size_t i = 0; i < dim; i++) {
      assert(dst[i] == 0);
    }

    free(src);
    free(dst);
    break;
  }

  case 4: {
    /* Int8 quantize with a potentially extreme range: two arbitrary finite
       floats plus byte-derived filler. */
    if (size < 2 * sizeof(float)) return 0;

    float extreme[2];
    memcpy(extreme, data, 2 * sizeof(float));

    /* Only test if both are finite (NaN/Inf are reachable via mode 2) */
    if (!isfinite(extreme[0]) || !isfinite(extreme[1])) return 0;

    size_t dim = 16;
    float src[16];
    src[0] = extreme[0];
    src[1] = extreme[1];
    for (size_t i = 2; i < dim; i++) {
      size_t off = 2 * sizeof(float) + (i - 2);
      if (off < size) {
        src[i] = (float)((int8_t)data[off]) * 1000.0f;
      } else {
        src[i] = 0.0f;
      }
    }

    int8_t dst[16];
    _test_rescore_quantize_float_to_int8(src, dst, dim);

    for (size_t i = 0; i < dim; i++) {
      assert(dst[i] >= -128 && dst[i] <= 127);
    }
    break;
  }

  default:
    /* unreachable: mode is always in 0..4 */
    break;
  }

  return 0;
}
|
||||
54
tests/fuzz/rescore-quantize.c
Normal file
54
tests/fuzz/rescore-quantize.c
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* These are SQLITE_VEC_TEST wrappers defined in sqlite-vec-rescore.c */
|
||||
extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
|
||||
extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
|
||||
|
||||
/**
 * Fuzz target: bit and int8 quantization of raw fuzz-controlled floats.
 *
 * The input is interpreted as an array of floats. The dimension count is
 * rounded down to a multiple of 8 (required by the bit quantizer); inputs
 * holding fewer than 8 floats are ignored. The same vector is fed to both
 * quantizers.
 *
 * @return always 0
 */
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  /* Round down to multiple of 8 for bit quantizer compatibility */
  size_t dim = (size / sizeof(float) / 8) * 8;
  if (dim == 0) return 0;

  /* Copy into a real float array instead of casting `data`: the fuzz buffer
     carries no alignment guarantee for float access. */
  float *src = malloc(dim * sizeof *src);
  uint8_t *bit_dst = malloc(dim / 8);
  int8_t *int8_dst = malloc(dim);
  if (!src || !bit_dst || !int8_dst) {
    free(src);
    free(bit_dst);
    free(int8_dst);
    return 0;
  }
  memcpy(src, data, dim * sizeof *src);

  /* Test bit quantization */
  _test_rescore_quantize_float_to_bit(src, bit_dst, dim);

  /* Test int8 quantization */
  _test_rescore_quantize_float_to_int8(src, int8_dst, dim);

  /* Trivially true for int8_t; kept so every output byte is read, which
     lets sanitizers flag uninitialized or out-of-bounds output. */
  for (size_t i = 0; i < dim; i++) {
    assert(int8_dst[i] >= -128 && int8_dst[i] <= 127);
  }

  free(src);
  free(bit_dst);
  free(int8_dst);
  return 0;
}
|
||||
230
tests/fuzz/rescore-shadow-corrupt.c
Normal file
230
tests/fuzz/rescore-shadow-corrupt.c
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/**
|
||||
* Fuzz target: corrupt rescore shadow tables then exercise KNN/read/write.
|
||||
*
|
||||
* This targets the dangerous code paths in rescore_knn (Phase 1 + 2):
|
||||
* - sqlite3_blob_read into baseVectors with potentially wrong-sized blobs
|
||||
* - distance computation on corrupted/partial quantized data
|
||||
* - blob_reopen on _rescore_vectors with missing/corrupted rows
|
||||
* - insert/delete after corruption (blob_write to wrong offsets)
|
||||
*
|
||||
* The existing shadow-corrupt.c only tests vec0 without rescore.
|
||||
*/
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 4) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* Pick quantizer type from first byte */
|
||||
int use_int8 = data[0] & 1;
|
||||
int target = (data[1] % 8);
|
||||
const uint8_t *payload = data + 2;
|
||||
int payload_size = (int)(size - 2);
|
||||
|
||||
const char *create_sql = use_int8
|
||||
? "CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=int8))"
|
||||
: "CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=bit))";
|
||||
|
||||
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
/* Insert several vectors so there's a full chunk to corrupt */
|
||||
{
|
||||
sqlite3_stmt *ins = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL);
|
||||
if (!ins) { sqlite3_close(db); return 0; }
|
||||
|
||||
for (int i = 1; i <= 8; i++) {
|
||||
float vec[16];
|
||||
for (int j = 0; j < 16; j++) vec[j] = (float)(i * 10 + j) / 100.0f;
|
||||
sqlite3_reset(ins);
|
||||
sqlite3_bind_int64(ins, 1, i);
|
||||
sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(ins);
|
||||
}
|
||||
sqlite3_finalize(ins);
|
||||
}
|
||||
|
||||
/* Now corrupt rescore shadow tables based on fuzz input */
|
||||
sqlite3_stmt *stmt = NULL;
|
||||
|
||||
switch (target) {
|
||||
case 0: {
|
||||
/* Corrupt _rescore_chunks00 vectors blob with fuzz data */
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
/* Corrupt _rescore_vectors00 vector blob for a specific row */
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 3",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
/* Truncate the quantized chunk blob to wrong size */
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_chunks00 SET vectors = X'DEADBEEF' WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
/* Delete rows from _rescore_vectors (orphan the float vectors) */
|
||||
sqlite3_exec(db,
|
||||
"DELETE FROM v_rescore_vectors00 WHERE rowid IN (2, 4, 6)",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
/* Delete the chunk row entirely from _rescore_chunks */
|
||||
sqlite3_exec(db,
|
||||
"DELETE FROM v_rescore_chunks00 WHERE rowid = 1",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 5: {
|
||||
/* Set vectors to NULL in _rescore_chunks */
|
||||
sqlite3_exec(db,
|
||||
"UPDATE v_rescore_chunks00 SET vectors = NULL WHERE rowid = 1",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 6: {
|
||||
/* Set vector to NULL in _rescore_vectors */
|
||||
sqlite3_exec(db,
|
||||
"UPDATE v_rescore_vectors00 SET vector = NULL WHERE rowid = 3",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 7: {
|
||||
/* Corrupt BOTH tables with fuzz data */
|
||||
int half = payload_size / 2;
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, half, SQLITE_STATIC);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload + half,
|
||||
payload_size - half, SQLITE_STATIC);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/* Exercise ALL read/write paths -- NONE should crash */
|
||||
|
||||
/* KNN query (triggers rescore_knn Phase 1 + Phase 2) */
|
||||
{
|
||||
float qvec[16] = {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1};
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 5", -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
/* Full scan (triggers reading from _rescore_vectors) */
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
/* Point lookups */
|
||||
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 3", NULL, NULL, NULL);
|
||||
|
||||
/* Insert after corruption */
|
||||
{
|
||||
float vec[16] = {0};
|
||||
sqlite3_stmt *ins = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (99, ?)", -1, &ins, NULL);
|
||||
if (ins) {
|
||||
sqlite3_bind_blob(ins, 1, vec, sizeof(vec), SQLITE_STATIC);
|
||||
sqlite3_step(ins);
|
||||
sqlite3_finalize(ins);
|
||||
}
|
||||
}
|
||||
|
||||
/* Delete after corruption */
|
||||
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 5", NULL, NULL, NULL);
|
||||
|
||||
/* Update after corruption */
|
||||
{
|
||||
float vec[16] = {1,1,1,1, 1,1,1,1, 1,1,1,1, 1,1,1,1};
|
||||
sqlite3_stmt *upd = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL);
|
||||
if (upd) {
|
||||
sqlite3_bind_blob(upd, 1, vec, sizeof(vec), SQLITE_STATIC);
|
||||
sqlite3_step(upd);
|
||||
sqlite3_finalize(upd);
|
||||
}
|
||||
}
|
||||
|
||||
/* KNN again after modifications to corrupted state */
|
||||
{
|
||||
float qvec[16] = {0,0,0,0, 0,0,0,0, 1,1,1,1, 1,1,1,1};
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 3", -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_exec(db, "DROP TABLE v", NULL, NULL, NULL);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -65,8 +65,23 @@ enum Vec0DistanceMetrics {
|
|||
|
||||
/*
 * Index strategy recorded for a vec0 vector column (stored in
 * VectorColumnDefinition.index_type).
 *
 * VEC0_INDEX_TYPE_RESCORE is only declared when the library is compiled with
 * SQLITE_VEC_ENABLE_RESCORE, so code that refers to it must sit behind the
 * same #ifdef.
 */
enum Vec0IndexType {
  VEC0_INDEX_TYPE_FLAT = 1,
#ifdef SQLITE_VEC_ENABLE_RESCORE
  VEC0_INDEX_TYPE_RESCORE = 2,
#endif
};
|
||||
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
/*
 * Quantizer choices for a rescore index.
 * NOTE(review): presumably these correspond to the `quantizer=bit` and
 * `quantizer=int8` options accepted by `rescore(...)` in CREATE VIRTUAL
 * TABLE -- confirm against the config parser.
 */
enum Vec0RescoreQuantizerType {
  VEC0_RESCORE_QUANTIZER_BIT = 1,
  VEC0_RESCORE_QUANTIZER_INT8 = 2,
};
|
||||
|
||||
/*
 * Per-column settings for a rescore index; embedded in
 * VectorColumnDefinition as the `rescore` member.
 */
struct Vec0RescoreConfig {
  /* Which quantizer the index uses (see enum Vec0RescoreQuantizerType). */
  enum Vec0RescoreQuantizerType quantizer_type;
  /*
   * NOTE(review): presumably the `oversample=N` option of `rescore(...)`,
   * i.e. a multiplier on how many quantized candidates are kept for
   * rescoring -- valid range is enforced elsewhere; confirm in the parser.
   */
  int oversample;
};
|
||||
#endif
|
||||
|
||||
struct VectorColumnDefinition {
|
||||
char *name;
|
||||
int name_length;
|
||||
|
|
@ -74,6 +89,9 @@ struct VectorColumnDefinition {
|
|||
enum VectorElementType element_type;
|
||||
enum Vec0DistanceMetrics distance_metric;
|
||||
enum Vec0IndexType index_type;
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
struct Vec0RescoreConfig rescore;
|
||||
#endif
|
||||
};
|
||||
|
||||
int vec0_parse_vector_column(const char *source, int source_length,
|
||||
|
|
@ -88,6 +106,13 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
|
|||
float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims);
|
||||
float _test_distance_cosine_float(const float *a, const float *b, size_t dims);
|
||||
float _test_distance_hamming(const unsigned char *a, const unsigned char *b, size_t dims);
|
||||
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
|
||||
void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
|
||||
size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
|
||||
size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif /* SQLITE_VEC_INTERNAL_H */
|
||||
|
|
|
|||
470
tests/test-rescore-mutations.py
Normal file
470
tests/test-rescore-mutations.py
Normal file
|
|
@ -0,0 +1,470 @@
|
|||
"""Mutation and edge-case tests for the rescore index feature."""
|
||||
import struct
|
||||
import sqlite3
|
||||
import pytest
|
||||
import math
|
||||
import random
|
||||
|
||||
|
||||
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows are returned as ``sqlite3.Row`` so tests can index columns by name.
    Extension loading is disabled again once ``dist/vec0`` has been loaded.
    The connection is closed when the test finishes (or if setup fails), so
    no connection is left for the garbage collector to clean up.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        conn.close()
|
||||
|
||||
|
||||
def float_vec(values):
    """Serialize a sequence of numbers with struct format ``f`` into one blob.

    The resulting bytes are what the tests bind as a vector parameter.
    """
    layout = "%df" % len(values)
    return struct.pack(layout, *values)
|
||||
|
||||
|
||||
def unpack_float_vec(blob):
    """Decode a packed float blob back into a list of Python floats.

    The element count is taken as ``len(blob) // 4``.
    """
    count = len(blob) // 4
    decoded = struct.unpack("%df" % count, blob)
    return list(decoded)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Error cases: rescore + aux/metadata/partition
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_create_error_with_aux_column(db):
    """Creating a rescore-indexed table with a ``+extra`` auxiliary column must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " +extra text"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError, match="Auxiliary columns"):
        db.execute(ddl)
|
||||
|
||||
|
||||
def test_create_error_with_metadata_column(db):
    """Creating a rescore-indexed table alongside a metadata column must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " genre text"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError, match="Metadata columns"):
        db.execute(ddl)
|
||||
|
||||
|
||||
def test_create_error_with_partition_key(db):
    """Creating a rescore-indexed table with a partition key column must fail."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit),"
        " user_id integer partition key"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError, match="Partition key"):
        db.execute(ddl)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert / batch / delete / update mutations
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_insert_single_verify_knn(db):
    """With one stored row, a KNN query for that same vector finds it at ~0 distance."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    ones = float_vec([1.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [ones])

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [ones],
    ).fetchall()

    assert len(hits) == 1
    best = hits[0]
    assert best["rowid"] == 1
    assert best["distance"] < 0.01
|
||||
|
||||
|
||||
def test_insert_large_batch(db):
    """Insert 200 random 16-dim rows, then check the row count and a 10-NN query.

    NOTE(review): 200 rows is fewer than the chunk_size of 1024 that the
    previous description of this test cited, so this presumably stays inside
    a single chunk -- confirm whether multi-chunk coverage was intended.
    """
    dim = 16
    total = 200
    random.seed(99)

    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        f" embedding float[{dim}] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute(ddl)

    # One gaussian vector per rowid, drawn in rowid order so the seeded
    # sequence stays reproducible.
    for rowid in range(1, total + 1):
        components = [random.gauss(0, 1) for _ in range(dim)]
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec(components)],
        )

    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == total

    # The query vector is drawn after all row vectors.
    probe = float_vec([random.gauss(0, 1) for _ in range(dim)])
    neighbours = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [probe],
    ).fetchall()
    assert len(neighbours) == 10
|
||||
|
||||
|
||||
def test_delete_all_rows(db):
    """After deleting all 20 inserted rows, count(*) is 0 and KNN yields nothing."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    count_sql = "SELECT count(*) as cnt FROM t"
    rowids = range(1, 21)

    for rowid in rowids:
        # Row N stores the constant vector [N - 1] * 8.
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec([float(rowid - 1)] * 8)],
        )
    assert db.execute(count_sql).fetchone()["cnt"] == 20

    for rowid in rowids:
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])
    assert db.execute(count_sql).fetchone()["cnt"] == 0

    leftovers = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
        [float_vec([0.0] * 8)],
    ).fetchall()
    assert len(leftovers) == 0
|
||||
|
||||
|
||||
def test_delete_then_reinsert_same_rowid(db):
    """Re-inserting rowid 1 with a different vector must change the nearest neighbour."""

    def nearest_to_origin():
        # Single nearest neighbour of the all-zero vector.
        return db.execute(
            "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
            [float_vec([0.0] * 8)],
        ).fetchall()

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )

    # rowid 1 starts close to the origin, rowid 2 far from it.
    near_origin = float_vec([0.1] * 8)
    far_away = float_vec([100.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [near_origin])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [far_away])
    assert nearest_to_origin()[0]["rowid"] == 1

    # Replace rowid 1 with a vector even farther out than rowid 2.
    db.execute("DELETE FROM t WHERE rowid = 1")
    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
        [float_vec([200.0] * 8)],
    )
    assert nearest_to_origin()[0]["rowid"] == 2
|
||||
|
||||
|
||||
def test_update_vector(db):
    """An UPDATE of the vector column must be visible to subsequent KNN queries."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    origin = float_vec([0.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [origin])
    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
        [float_vec([10.0] * 8)],
    )

    # Push rowid 1 far from the origin; rowid 2 should now win.
    moved = float_vec([100.0] * 8)
    db.execute("UPDATE t SET embedding = ? WHERE rowid = 1", [moved])

    closest = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [float_vec([0.0] * 8)],
    ).fetchall()
    assert closest[0]["rowid"] == 2
|
||||
|
||||
|
||||
def test_knn_after_delete_all_but_one(db):
    """Of 50 inserted rows only rowid 25 is kept; KNN must return just that row."""
    survivor = 25
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    for rowid in range(1, 51):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec([float(rowid - 1)] * 8)],
        )

    doomed = [rowid for rowid in range(1, 51) if rowid != survivor]
    for rowid in doomed:
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])

    remaining = db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"]
    assert remaining == 1

    found = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [float_vec([0.0] * 8)],
    ).fetchall()
    assert len(found) == 1
    assert found[0]["rowid"] == 25
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge cases
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_single_row_knn(db):
    """A one-row table returns exactly one result for both LIMIT 1 and LIMIT 5."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)])

    queries = (
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
    )
    for sql in queries:
        found = db.execute(sql, [float_vec([1.0] * 8)]).fetchall()
        assert len(found) == 1
|
||||
|
||||
|
||||
def test_knn_with_all_identical_vectors(db):
    """Ten copies of one vector, queried with that vector, all come back at ~0 distance."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    shared = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
    for rowid in range(1, 11):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec(shared)],
        )

    matches = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [float_vec(shared)],
    ).fetchall()

    assert len(matches) == 10
    # Every stored vector equals the query, so each distance should be ~0.
    assert all(match["distance"] < 0.01 for match in matches)
|
||||
|
||||
|
||||
def test_zero_vector_insert(db):
    """The all-zero vector must be insertable under both the bit and int8 quantizers."""
    # (create, insert, count) statements: table t uses bit, table t2 uses int8.
    scenarios = (
        (
            "CREATE VIRTUAL TABLE t USING vec0("
            " embedding float[8] indexed by rescore(quantizer=bit)"
            ")",
            "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
            "SELECT count(*) as cnt FROM t",
        ),
        (
            "CREATE VIRTUAL TABLE t2 USING vec0("
            " embedding float[8] indexed by rescore(quantizer=int8)"
            ")",
            "INSERT INTO t2(rowid, embedding) VALUES (1, ?)",
            "SELECT count(*) as cnt FROM t2",
        ),
    )
    for create_sql, insert_sql, count_sql in scenarios:
        db.execute(create_sql)
        db.execute(insert_sql, [float_vec([0.0] * 8)])
        counted = db.execute(count_sql).fetchone()
        assert counted["cnt"] == 1
|
||||
|
||||
|
||||
def test_very_large_values(db):
    """Vectors with components of magnitude 1e30 must insert under the int8 quantizer."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    huge = 1e30
    all_positive = [huge] * 8
    alternating = [huge, -huge] * 4  # +1e30, -1e30, +1e30, ...

    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
        [float_vec(all_positive)],
    )
    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
        [float_vec(alternating)],
    )

    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == 2
|
||||
|
||||
|
||||
def test_negative_values(db):
    """All-negative vectors insert fine and KNN still ranks the exact match first.

    NOTE(review): the earlier description said bit quantization maps every
    component of these vectors to 0 -- not verifiable from this test alone.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    large_magnitude = [-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0]
    small_magnitude = [-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8]

    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
        [float_vec(large_magnitude)],
    )
    db.execute(
        "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
        [float_vec(small_magnitude)],
    )
    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == 2

    # Query with rowid 2's own vector: both rows come back, rowid 2 first.
    ranked = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
        [float_vec(small_magnitude)],
    ).fetchall()
    assert len(ranked) == 2
    assert ranked[0]["rowid"] == 2
|
||||
|
||||
|
||||
def test_single_dimension(db):
    """Nearest-neighbour sanity check on constant vectors with the int8 quantizer.

    NOTE(review): despite the test name, the column is declared ``float[8]``
    and the vectors are 8-dimensional constants, so a true 1-dimensional
    vector is not exercised here -- confirm the intended dimension.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    all_ones = float_vec([1.0] * 8)
    all_fives = float_vec([5.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [all_ones])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [all_fives])

    closest = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [float_vec([1.0] * 8)],
    ).fetchall()
    assert closest[0]["rowid"] == 1
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# vec_debug() verification
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_vec_debug_contains_rescore(db):
    """The text returned by vec_debug() must mention 'rescore'."""
    debug_text = db.execute("SELECT vec_debug() as d").fetchone()["d"]
    assert "rescore" in debug_text
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert batch recall test
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_insert_batch_recall(db):
    """Compare the rescore index against a plain vec0 table on 150 rows.

    The top-10 rowids from the rescore table (int8, oversample=16) must
    overlap the top-10 from the un-indexed table by at least 60%.
    """
    dim = 16
    n = 150
    k = 10
    random.seed(77)

    db.execute(
        "CREATE VIRTUAL TABLE t_rescore USING vec0("
        f" embedding float[{dim}] indexed by rescore(quantizer=int8, oversample=16)"
        ")"
    )
    db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")

    # Draw every row vector first, then the query, so the seeded sequence is fixed.
    vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
    for rowid, components in enumerate(vectors, start=1):
        packed = float_vec(components)
        db.execute(
            "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [rowid, packed]
        )
        db.execute(
            "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [rowid, packed]
        )

    query = float_vec([random.gauss(0, 1) for _ in range(dim)])

    def top_ids(sql):
        # Set of rowids returned by a top-k query.
        return {row["rowid"] for row in db.execute(sql, [query, k]).fetchall()}

    rescore_ids = top_ids(
        "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?"
    )
    flat_ids = top_ids(
        "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?"
    )

    recall = len(rescore_ids & flat_ids) / k
    assert recall >= 0.6, f"Recall too low: {recall}"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Distance metric variants
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_knn_int8_cosine(db):
    """With distance_metric=cosine and the int8 quantizer, the exact match ranks first."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=int8)"
        ")"
    )
    along_x = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    along_y = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    nearly_x = [1.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(along_x)])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(along_y)])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (3, ?)", [float_vec(nearly_x)])

    ranked = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
        [float_vec(along_x)],
    ).fetchall()
    top = ranked[0]
    assert top["rowid"] == 1
    assert top["distance"] < 0.01
|
||||
568
tests/test-rescore.py
Normal file
568
tests/test-rescore.py
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
"""Tests for the rescore index feature in sqlite-vec."""
|
||||
import struct
|
||||
import sqlite3
|
||||
import pytest
|
||||
import math
|
||||
import random
|
||||
|
||||
|
||||
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows are returned as ``sqlite3.Row`` so tests can index columns by name.
    Extension loading is disabled again once ``dist/vec0`` has been loaded.
    The connection is closed when the test finishes (or if setup fails), so
    no connection is left for the garbage collector to clean up.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        conn.close()
|
||||
|
||||
|
||||
def float_vec(values):
    """Serialize a sequence of numbers with struct format ``f`` into one blob.

    The resulting bytes are what the tests bind as a vector parameter.
    """
    layout = "%df" % len(values)
    return struct.pack(layout, *values)
|
||||
|
||||
|
||||
def unpack_float_vec(blob):
    """Decode a packed float blob back into a list of Python floats.

    The element count is taken as ``len(blob) // 4``.
    """
    count = len(blob) // 4
    decoded = struct.unpack("%df" % count, blob)
    return list(decoded)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Creation tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_create_bit(db):
    """Creating a rescore table with quantizer=bit registers entries in sqlite_master."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)
    # At least one schema entry whose name matches the 't_%' pattern must exist.
    matching = db.execute(
        "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchone()["cnt"]
    assert matching > 0
|
||||
|
||||
|
||||
def test_create_int8(db):
    """Creating a rescore table with quantizer=int8 registers entries in sqlite_master."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=int8)"
        ")"
    )
    db.execute(ddl)
    matching = db.execute(
        "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchone()["cnt"]
    assert matching > 0
|
||||
|
||||
|
||||
def test_create_with_oversample(db):
    """An explicit oversample=16 option is accepted at table creation."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit, oversample=16)"
        ")"
    )
    db.execute(ddl)
    matching = db.execute(
        "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchone()["cnt"]
    assert matching > 0
|
||||
|
||||
|
||||
def test_create_with_distance_metric(db):
    """distance_metric=cosine can be combined with a rescore index at creation."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] distance_metric=cosine indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute(ddl)
    matching = db.execute(
        "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchone()["cnt"]
    assert matching > 0
|
||||
|
||||
|
||||
def test_create_error_missing_quantizer(db):
    """rescore(...) without a quantizer option must be rejected."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(oversample=8)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
||||
|
||||
|
||||
def test_create_error_invalid_quantizer(db):
    """quantizer=float is not an accepted quantizer and must be rejected."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=float)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
||||
|
||||
|
||||
def test_create_error_on_bit_column(db):
    """A rescore index on a bit[] column must be rejected."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding bit[1024] indexed by rescore(quantizer=bit)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
||||
|
||||
|
||||
def test_create_error_on_int8_column(db):
    """A rescore index on an int8[] column must be rejected."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding int8[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
||||
|
||||
|
||||
def test_create_error_bad_oversample_zero(db):
    """oversample=0 must be rejected at table creation."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit, oversample=0)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
||||
|
||||
|
||||
def test_create_error_bad_oversample_too_large(db):
    """oversample=999 must be rejected at table creation."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit, oversample=999)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
||||
|
||||
|
||||
def test_create_error_bit_dim_not_divisible_by_8(db):
    """quantizer=bit on a float[100] column (100 is not a multiple of 8) must be rejected."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[100] indexed by rescore(quantizer=bit)"
        ")"
    )
    with pytest.raises(sqlite3.OperationalError):
        db.execute(ddl)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Shadow table tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_shadow_tables_exist(db):
    """A rescore column creates its two rescore shadow tables and no _vector_chunks table."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    listing = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name"
    ).fetchall()
    tables = [entry[0] for entry in listing]

    for expected in ("t_rescore_chunks00", "t_rescore_vectors00"):
        assert expected in tables
    # The regular vector-chunk shadow table must be absent for rescore columns.
    assert "t_vector_chunks00" not in tables
|
||||
|
||||
|
||||
def test_drop_cleans_up(db):
    """DROP TABLE leaves no tables matching the 't_%' pattern behind."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[128] indexed by rescore(quantizer=bit)"
        ")"
    )
    db.execute("DROP TABLE t")

    listing = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%'"
    ).fetchall()
    tables = [entry[0] for entry in listing]
    assert len(tables) == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_insert_single(db):
    """One insert into a bit-quantized rescore table yields count(*) == 1."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    payload = float_vec([1.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [payload])

    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == 1
|
||||
|
||||
|
||||
def test_insert_multiple(db):
    """Ten inserts into an int8-quantized rescore table yield count(*) == 10."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    for rowid in range(1, 11):
        # Row N stores the constant vector [N - 1] * 8.
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec([float(rowid - 1)] * 8)],
        )

    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == 10
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Delete tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_delete_single(db):
    """Deleting the only row brings count(*) back to 0."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    payload = float_vec([1.0] * 8)
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [payload])
    db.execute("DELETE FROM t WHERE rowid = 1")

    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == 0
|
||||
|
||||
|
||||
def test_delete_and_reinsert(db):
    """Delete rowid 1, insert rowid 2, and expect exactly one row to remain."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    first = float_vec([1.0] * 8)
    second = float_vec([2.0] * 8)

    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [first])
    db.execute("DELETE FROM t WHERE rowid = 1")
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [second])

    counted = db.execute("SELECT count(*) as cnt FROM t").fetchone()
    assert counted["cnt"] == 1
|
||||
|
||||
|
||||
def test_point_query_returns_float(db):
    """A rowid lookup must hand back the stored float values (within 1e-6)."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    expected = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(expected)])

    fetched = db.execute("SELECT embedding FROM t WHERE rowid = 1").fetchone()
    actual = unpack_float_vec(fetched["embedding"])

    # Element-wise comparison with a tolerance for float32 rounding.
    for got, want in zip(actual, expected):
        assert abs(got - want) < 1e-6
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# KNN tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_knn_basic_bit(db):
    """With three unit vectors stored (bit quantizer), querying the first returns rowid 1."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    # Each row is a unit vector along a different axis; rowid 1 equals the query.
    axis_rows = (
        (
            "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "INSERT INTO t(rowid, embedding) VALUES (3, ?)",
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
    )
    for insert_sql, components in axis_rows:
        db.execute(insert_sql, [float_vec(components)])

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
    ).fetchall()
    assert len(hits) == 1
    assert hits[0]["rowid"] == 1
|
||||
|
||||
|
||||
def test_knn_basic_int8(db):
    """With three unit vectors stored (int8 quantizer), querying the first returns rowid 1."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    # Each row is a unit vector along a different axis; rowid 1 equals the query.
    axis_rows = (
        (
            "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "INSERT INTO t(rowid, embedding) VALUES (3, ?)",
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
    )
    for insert_sql, components in axis_rows:
        db.execute(insert_sql, [float_vec(components)])

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
    ).fetchall()
    assert len(hits) == 1
    assert hits[0]["rowid"] == 1
|
||||
|
||||
|
||||
def test_knn_returns_float_distances(db):
    """KNN should return float-precision distances, not quantized distances.

    Two rows are inserted and the query equals the first one, so the reported
    distances must be ~0 for rowid 1 and the exact L2 distance (1.5) for the
    second result.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    v1 = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    v2 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(v1)])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(v2)])

    query = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    rows = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2",
        [float_vec(query)],
    ).fetchall()

    # Both rows are indexed below; fail with a clear assertion instead of an
    # IndexError if the query returns fewer than two results.
    assert len(rows) == 2

    # First result should be exact match with distance ~0
    assert rows[0]["rowid"] == 1
    assert rows[0]["distance"] < 0.01

    # Second result should have a float distance
    # sqrt((1-0)^2 + (0.5-0)^2 + (0-1)^2) = sqrt(2.25) = 1.5
    assert abs(rows[1]["distance"] - 1.5) < 0.01
|
||||
|
||||
|
||||
def test_knn_recall(db):
    """With enough vectors, rescore should reach at least 0.7 recall@k.

    Ground truth comes from an un-indexed (flat) vec0 table loaded with the
    same vectors; recall is the overlap of the two top-k rowid sets divided
    by k.
    """
    dim = 32
    n = 1000
    k = 10
    # Fixed seed keeps the data set, and therefore the measured recall, stable.
    random.seed(42)

    db.execute(
        "CREATE VIRTUAL TABLE t_rescore USING vec0("
        f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample=16)"
        ")"
    )
    db.execute(
        f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])"
    )

    vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
    for i, v in enumerate(vectors):
        blob = float_vec(v)
        db.execute(
            "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob]
        )
        db.execute(
            "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob]
        )

    # Drawn after the data set so the random stream order is vectors, then query.
    query = float_vec([random.gauss(0, 1) for _ in range(dim)])

    rescore_rows = db.execute(
        "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        [query, k],
    ).fetchall()
    flat_rows = db.execute(
        "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
        [query, k],
    ).fetchall()

    rescore_ids = {r["rowid"] for r in rescore_rows}
    flat_ids = {r["rowid"] for r in flat_rows}
    recall = len(rescore_ids & flat_ids) / k
    assert recall >= 0.7, f"Recall too low: {recall}"
|
||||
|
||||
|
||||
def test_knn_cosine(db):
    """With distance_metric=cosine, an identical vector ranks first at distance ~0."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=bit)"
        ")"
    )
    fixtures = [
        (
            "INSERT INTO t(rowid, embedding) VALUES (1, ?)",
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "INSERT INTO t(rowid, embedding) VALUES (2, ?)",
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
    ]
    for insert_sql, components in fixtures:
        db.execute(insert_sql, [float_vec(components)])

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1",
        [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
    ).fetchall()
    best = hits[0]
    assert best["rowid"] == 1
    # Query and rowid 1 are the same vector, so the cosine distance is ~0.
    assert best["distance"] < 0.01
|
||||
|
||||
|
||||
def test_knn_empty_table(db):
    """A KNN query against a freshly created, empty rescore table yields no rows."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    all_ones = float_vec([1.0] * 8)
    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
        [all_ones],
    ).fetchall()
    assert len(hits) == 0
|
||||
|
||||
|
||||
def test_knn_k_larger_than_n(db):
    """Asking for 10 neighbours when only 2 rows exist returns exactly those 2."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    fixtures = [
        ("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [1.0] * 8),
        ("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [2.0] * 8),
    ]
    for insert_sql, components in fixtures:
        db.execute(insert_sql, [float_vec(components)])

    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
        [float_vec([1.0] * 8)],
    ).fetchall()
    assert len(hits) == 2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration / edge case tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_knn_with_rowid_in(db):
    """A `rowid IN (...)` constraint keeps KNN results inside the listed rowids."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    # rowid r stores the constant vector [r - 1] * 8, for r = 1..5.
    for rowid in range(1, 6):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec([float(rowid - 1)] * 8)],
        )

    # Only search within rowids 1, 3, 5
    hits = db.execute(
        "SELECT rowid FROM t WHERE embedding MATCH ? AND rowid IN (1, 3, 5) ORDER BY distance LIMIT 3",
        [float_vec([0.0] * 8)],
    ).fetchall()
    returned = {hit["rowid"] for hit in hits}
    # Subset check only: fewer than three hits (even none) still passes.
    assert returned <= {1, 3, 5}
|
||||
|
||||
|
||||
def test_knn_after_deletes(db):
    """After deleting the best match, the next-closest row takes first place."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=int8)"
        ")"
    )
    # rowid r stores the constant vector [r - 1] * 8, for r = 1..10.
    for rowid in range(1, 11):
        db.execute(
            "INSERT INTO t(rowid, embedding) VALUES (?, ?)",
            [rowid, float_vec([float(rowid - 1)] * 8)],
        )

    # rowid 1 is the all-zero vector, i.e. an exact match for the query below.
    db.execute("DELETE FROM t WHERE rowid = 1")

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5",
        [float_vec([0.0] * 8)],
    ).fetchall()

    # Results must be ordered by distance, starting with rowid 2 ([1]*8).
    assert len(hits) >= 2
    assert hits[0]["distance"] <= hits[1]["distance"]
    # rowid 2 = [1,1,...] -> L2 = sqrt(8) ~ 2.83; rowid 3 = [2,2,...] -> L2 = sqrt(32) ~ 5.66
    nearest = hits[0]
    assert nearest["rowid"] == 2, f"Expected rowid 2, got {nearest['rowid']} with dist={nearest['distance']}"
|
||||
|
||||
|
||||
def test_oversample_effect(db):
    """Higher oversample should give equal or better recall."""
    dim, n, k = 32, 500, 10
    random.seed(123)

    vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
    query = float_vec([random.gauss(0, 1) for _ in range(dim)])

    def load(table):
        # Insert every vector under rowids 1..n.
        for rowid, vec in enumerate(vectors, start=1):
            db.execute(
                f"INSERT INTO {table}(rowid, embedding) VALUES (?, ?)",
                [rowid, float_vec(vec)],
            )

    def top_k_ids(table):
        # Set of rowids returned by a top-k KNN query on `table`.
        found = db.execute(
            f"SELECT rowid FROM {table} WHERE embedding MATCH ? ORDER BY distance LIMIT ?",
            [query, k],
        ).fetchall()
        return {row["rowid"] for row in found}

    # One rescore table per oversample setting, in increasing order.
    hit_sets = []
    for oversample in [2, 16]:
        table = f"t_os{oversample}"
        db.execute(
            f"CREATE VIRTUAL TABLE {table} USING vec0("
            f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample={oversample})"
            ")"
        )
        load(table)
        hit_sets.append(top_k_ids(table))

    # Ground truth from a flat (un-indexed) table holding the same vectors.
    db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])")
    load("t_flat")
    truth = top_k_ids("t_flat")

    low_hits, high_hits = hit_sets
    recall_low = len(low_hits & truth) / k
    recall_high = len(high_hits & truth) / k
    assert recall_high >= recall_low
|
||||
|
||||
|
||||
def test_multiple_vector_columns(db):
    """One column with rescore, one without."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " v1 float[8] indexed by rescore(quantizer=bit),"
        " v2 float[8]"
        ")"
    )
    ones = [1.0] * 8
    zeros = [0.0] * 8
    # rowid 1 is all-ones in v1 and all-zeros in v2; rowid 2 is the mirror image.
    fixtures = [
        ("INSERT INTO t(rowid, v1, v2) VALUES (1, ?, ?)", ones, zeros),
        ("INSERT INTO t(rowid, v1, v2) VALUES (2, ?, ?)", zeros, ones),
    ]
    for insert_sql, first, second in fixtures:
        db.execute(insert_sql, [float_vec(first), float_vec(second)])

    # KNN on v1 (rescore path)
    via_rescore = db.execute(
        "SELECT rowid FROM t WHERE v1 MATCH ? ORDER BY distance LIMIT 1",
        [float_vec([1.0] * 8)],
    ).fetchall()
    assert via_rescore[0]["rowid"] == 1

    # KNN on v2 (normal path)
    via_flat = db.execute(
        "SELECT rowid FROM t WHERE v2 MATCH ? ORDER BY distance LIMIT 1",
        [float_vec([1.0] * 8)],
    ).fetchall()
    assert via_flat[0]["rowid"] == 2
|
||||
|
|
@ -760,6 +760,202 @@ void test_distance_hamming() {
|
|||
printf(" All distance_hamming tests passed.\n");
|
||||
}
|
||||
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
|
||||
/*
 * Checks the float -> 1-bit quantizer through its test hook
 * _test_rescore_quantize_float_to_bit(src, dst, dims).
 *
 * The expected bytes pin the layout: element i controls bit (i % 8) of
 * byte (i / 8), and the bit is set for values >= 0.0f (zero included).
 *
 * Every case pre-fills dst with a pattern that differs from the expected
 * output, so a stale byte left over from an earlier case cannot satisfy
 * the assertion.
 */
void test_rescore_quantize_float_to_bit() {
  printf("Starting %s...\n", __func__);
  uint8_t dst[16];

  // All positive -> all bits 1
  {
    float src[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
    memset(dst, 0, sizeof(dst));
    _test_rescore_quantize_float_to_bit(src, dst, 8);
    assert(dst[0] == 0xFF);
  }

  // All negative -> all bits 0
  {
    float src[8] = {-1.0f, -2.0f, -3.0f, -4.0f, -5.0f, -6.0f, -7.0f, -8.0f};
    memset(dst, 0xFF, sizeof(dst));
    _test_rescore_quantize_float_to_bit(src, dst, 8);
    assert(dst[0] == 0x00);
  }

  // Alternating positive/negative
  {
    float src[8] = {1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f};
    // Start from the exact complement so every bit has to be rewritten.
    memset(dst, 0xAA, sizeof(dst));
    _test_rescore_quantize_float_to_bit(src, dst, 8);
    // bits 0,2,4,6 set => 0b01010101 = 0x55
    assert(dst[0] == 0x55);
  }

  // Zero values -> bit is set (>= 0.0f)
  {
    float src[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
    // Cleared first: 0xFF can then only come from the quantizer.
    memset(dst, 0, sizeof(dst));
    _test_rescore_quantize_float_to_bit(src, dst, 8);
    assert(dst[0] == 0xFF);
  }

  // 128 dimensions -> 16 bytes output
  {
    float src[128];
    for (int i = 0; i < 128; i++) {
      src[i] = (i % 2 == 0) ? 1.0f : -1.0f;
    }
    memset(dst, 0, 16);
    _test_rescore_quantize_float_to_bit(src, dst, 128);
    // Even indices set: bits 0,2,4,6 in each byte => 0x55
    for (int i = 0; i < 16; i++) {
      assert(dst[i] == 0x55);
    }
  }

  printf("  All rescore_quantize_float_to_bit tests passed.\n");
}
|
||||
|
||||
/*
 * Checks the float -> int8 quantizer through its test hook
 * _test_rescore_quantize_float_to_int8(src, dst, dims).
 *
 * The assertions pin a per-vector mapping in which the smallest input
 * becomes -128, the largest becomes 127 (or 126 after rounding), and a
 * vector with zero range quantizes to all zeros.
 */
void test_rescore_quantize_float_to_int8() {
  printf("Starting %s...\n", __func__);
  int8_t out[256];

  // Constant input has range 0, so every element must come out as 0.
  {
    float flat[8] = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f};
    _test_rescore_quantize_float_to_int8(flat, out, 8);
    int k = 0;
    while (k < 8) {
      assert(out[k] == 0);
      k++;
    }
  }

  // {0.0, 1.0}: the endpoints land on the int8 extremes.
  {
    float unit[2] = {0.0f, 1.0f};
    _test_rescore_quantize_float_to_int8(unit, out, 2);
    assert(out[0] == -128);
    assert(out[1] == 127);
  }

  // {-1.0, 0.0}: same extremes, only the offset differs.
  {
    float shifted[2] = {-1.0f, 0.0f};
    _test_rescore_quantize_float_to_int8(shifted, out, 2);
    assert(out[0] == -128);
    assert(out[1] == 127);
  }

  // One element is also a zero-range vector -> 0.
  {
    float lone[1] = {42.0f};
    _test_rescore_quantize_float_to_int8(lone, out, 1);
    assert(out[0] == 0);
  }

  // Mixed values: min at -128, max at ~127, and 50 in the positive half.
  {
    float mixed[4] = {-100.0f, 0.0f, 100.0f, 50.0f};
    _test_rescore_quantize_float_to_int8(mixed, out, 4);
    // Range check kept from the original; always true for an int8_t value.
    for (int k = 0; k < 4; k++) {
      assert(out[k] >= -128 && out[k] <= 127);
    }
    // Min is exact; max may lose 1 to float rounding.
    assert(out[0] == -128);
    assert(out[2] >= 126 && out[2] <= 127);
    // 50 lies above the midpoint of [-100, 100], so it must be positive.
    assert(out[3] > 0);
  }

  printf("  All rescore_quantize_float_to_int8 tests passed.\n");
}
|
||||
|
||||
/*
 * Checks the quantized-vector byte sizes reported by the test hooks:
 * dims / 8 bytes for the bit quantizer, dims bytes for the int8 quantizer.
 */
void test_rescore_quantized_byte_size() {
  printf("Starting %s...\n", __func__);

  // Dimension counts paired with the expected size for each quantizer.
  static const struct {
    int dims;
    int bit_bytes;
    int int8_bytes;
  } cases[] = {
      {128, 16, 128},
      {8, 1, 8},
      {1024, 128, 1024},
  };
  const int ncases = (int)(sizeof(cases) / sizeof(cases[0]));

  // Bit quantizer: dims/8
  for (int c = 0; c < ncases; c++) {
    assert(_test_rescore_quantized_byte_size_bit(cases[c].dims) == cases[c].bit_bytes);
  }

  // Int8 quantizer: dims
  for (int c = 0; c < ncases; c++) {
    assert(_test_rescore_quantized_byte_size_int8(cases[c].dims) == cases[c].int8_bytes);
  }

  printf("  All rescore_quantized_byte_size tests passed.\n");
}
|
||||
|
||||
void test_vec0_parse_vector_column_rescore() {
|
||||
printf("Starting %s...\n", __func__);
|
||||
struct VectorColumnDefinition col;
|
||||
int rc;
|
||||
|
||||
// Basic bit quantizer
|
||||
{
|
||||
const char *input = "emb float[128] indexed by rescore(quantizer=bit)";
|
||||
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||
assert(rc == SQLITE_OK);
|
||||
assert(col.index_type == VEC0_INDEX_TYPE_RESCORE);
|
||||
assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT);
|
||||
assert(col.rescore.oversample == 8); // default
|
||||
assert(col.dimensions == 128);
|
||||
sqlite3_free(col.name);
|
||||
}
|
||||
|
||||
// Int8 quantizer
|
||||
{
|
||||
const char *input = "emb float[128] indexed by rescore(quantizer=int8)";
|
||||
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||
assert(rc == SQLITE_OK);
|
||||
assert(col.index_type == VEC0_INDEX_TYPE_RESCORE);
|
||||
assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_INT8);
|
||||
sqlite3_free(col.name);
|
||||
}
|
||||
|
||||
// Bit quantizer with oversample
|
||||
{
|
||||
const char *input = "emb float[128] indexed by rescore(quantizer=bit, oversample=16)";
|
||||
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||
assert(rc == SQLITE_OK);
|
||||
assert(col.index_type == VEC0_INDEX_TYPE_RESCORE);
|
||||
assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT);
|
||||
assert(col.rescore.oversample == 16);
|
||||
sqlite3_free(col.name);
|
||||
}
|
||||
|
||||
// Error: non-float element type
|
||||
{
|
||||
const char *input = "emb int8[128] indexed by rescore(quantizer=bit)";
|
||||
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||
assert(rc == SQLITE_ERROR);
|
||||
}
|
||||
|
||||
// Error: dims not divisible by 8 for bit quantizer
|
||||
{
|
||||
const char *input = "emb float[100] indexed by rescore(quantizer=bit)";
|
||||
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||
assert(rc == SQLITE_ERROR);
|
||||
}
|
||||
|
||||
// Error: missing quantizer
|
||||
{
|
||||
const char *input = "emb float[128] indexed by rescore(oversample=8)";
|
||||
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||
assert(rc == SQLITE_ERROR);
|
||||
}
|
||||
|
||||
// With distance_metric=cosine
|
||||
{
|
||||
const char *input = "emb float[128] distance_metric=cosine indexed by rescore(quantizer=int8)";
|
||||
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||
assert(rc == SQLITE_OK);
|
||||
assert(col.index_type == VEC0_INDEX_TYPE_RESCORE);
|
||||
assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
|
||||
assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_INT8);
|
||||
sqlite3_free(col.name);
|
||||
}
|
||||
|
||||
printf(" All vec0_parse_vector_column_rescore tests passed.\n");
|
||||
}
|
||||
|
||||
#endif /* SQLITE_VEC_ENABLE_RESCORE */
|
||||
|
||||
int main() {
|
||||
printf("Starting unit tests...\n");
|
||||
#ifdef SQLITE_VEC_ENABLE_AVX
|
||||
|
|
@ -768,6 +964,9 @@ int main() {
|
|||
#ifdef SQLITE_VEC_ENABLE_NEON
|
||||
printf("SQLITE_VEC_ENABLE_NEON=1\n");
|
||||
#endif
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
printf("SQLITE_VEC_ENABLE_RESCORE=1\n");
|
||||
#endif
|
||||
#if !defined(SQLITE_VEC_ENABLE_AVX) && !defined(SQLITE_VEC_ENABLE_NEON)
|
||||
printf("SIMD: none\n");
|
||||
#endif
|
||||
|
|
@ -778,5 +977,11 @@ int main() {
|
|||
test_distance_l2_sqr_float();
|
||||
test_distance_cosine_float();
|
||||
test_distance_hamming();
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
test_rescore_quantize_float_to_bit();
|
||||
test_rescore_quantize_float_to_int8();
|
||||
test_rescore_quantized_byte_size();
|
||||
test_vec0_parse_vector_column_rescore();
|
||||
#endif
|
||||
printf("All unit tests passed.\n");
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue