diff --git a/benchmarks-ann/bench-delete/bench_delete.py b/benchmarks-ann/bench-delete/bench_delete.py index 802f0a4..0ebd2ec 100644 --- a/benchmarks-ann/bench-delete/bench_delete.py +++ b/benchmarks-ann/bench-delete/bench_delete.py @@ -159,7 +159,7 @@ INDEX_REGISTRY = { def _ivf_train(conn): """Trigger built-in k-means training for IVF.""" t0 = now_ns() - conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')") + conn.execute("INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')") conn.commit() return ns_to_s(now_ns() - t0) diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py index a4cbbe4..966c458 100644 --- a/benchmarks-ann/bench.py +++ b/benchmarks-ann/bench.py @@ -456,7 +456,7 @@ def _ivf_create_table_sql(params): def _ivf_post_insert_hook(conn, params): print(" Training k-means centroids (built-in)...", flush=True) t0 = time.perf_counter() - conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')") + conn.execute("INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')") conn.commit() elapsed = time.perf_counter() - t0 print(f" Training done in {elapsed:.1f}s", flush=True) @@ -514,7 +514,7 @@ def _ivf_faiss_kmeans_hook(conn, params): for cid, blob in centroids: conn.execute( - "INSERT INTO vec_items(id, embedding) VALUES (?, ?)", + "INSERT INTO vec_items(vec_items, embedding) VALUES (?, ?)", (f"set-centroid:{cid}", blob), ) conn.commit() @@ -540,7 +540,7 @@ def _ivf_pre_query_hook(conn, params): nprobe = params.get("nprobe") if nprobe: conn.execute( - "INSERT INTO vec_items(id) VALUES (?)", + "INSERT INTO vec_items(vec_items) VALUES (?)", (f"nprobe={nprobe}",), ) conn.commit() @@ -572,7 +572,7 @@ INDEX_REGISTRY["ivf"] = { "insert_sql": None, "post_insert_hook": _ivf_post_insert_hook, "pre_query_hook": _ivf_pre_query_hook, - "train_sql": lambda _: "INSERT INTO vec_items(id) VALUES ('compute-centroids')", + "train_sql": lambda _: "INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')", "run_query": None, "query_sql": None, "describe": _ivf_describe, @@ -616,7 +616,7 @@ def _diskann_pre_query_hook(conn, params): L_search = params.get("L_search", 0) if L_search: conn.execute( - "INSERT INTO vec_items(id) VALUES (?)", + "INSERT INTO vec_items(vec_items) VALUES (?)", (f"search_list_size_search={L_search}",), ) conn.commit() diff --git a/sqlite-vec-rescore.c b/sqlite-vec-rescore.c index 5432612..6a47214 100644 --- a/sqlite-vec-rescore.c +++ b/sqlite-vec-rescore.c @@ -351,7 +351,9 @@ static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur, (void)pCur; (void)aMetadataIn; int rc = SQLITE_OK; - int oversample = vector_column->rescore.oversample; + int oversample = vector_column->rescore.oversample_search > 0 + ? vector_column->rescore.oversample_search + : vector_column->rescore.oversample; i64 k_oversample = k * oversample; if (k_oversample > 4096) k_oversample = 4096; @@ -640,6 +642,27 @@ cleanup: return rc; } +/** + * Handle FTS5-style command dispatch for rescore parameters. + * Returns SQLITE_OK if handled, SQLITE_EMPTY if not a rescore command. + */ +static int rescore_handle_command(vec0_vtab *p, const char *command) { + if (strncmp(command, "oversample=", 11) == 0) { + int val = atoi(command + 11); + if (val < 1) { + vtab_set_error(&p->base, "oversample must be >= 1"); + return SQLITE_ERROR; + } + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) { + p->vector_columns[i].rescore.oversample_search = val; + } + } + return SQLITE_OK; + } + return SQLITE_EMPTY; +} + #ifdef SQLITE_VEC_TEST void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) { rescore_quantize_float_to_bit(src, dst, dim); diff --git a/sqlite-vec.c b/sqlite-vec.c index 16c3b4d..40fe0bf 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -2588,7 +2588,8 @@ enum Vec0RescoreQuantizerType { struct Vec0RescoreConfig { enum Vec0RescoreQuantizerType quantizer_type; - int oversample; + int oversample; // CREATE-time default + int oversample_search; // runtime override (0 = use default) }; #endif @@ -3399,8 +3400,9 @@ static sqlite3_module vec_eachModule = { #define VEC0_COLUMN_ID 0 #define VEC0_COLUMN_USERN_START 1 -#define VEC0_COLUMN_OFFSET_DISTANCE 1 -#define VEC0_COLUMN_OFFSET_K 2 +#define VEC0_COLUMN_OFFSET_COMMAND 1 +#define VEC0_COLUMN_OFFSET_DISTANCE 2 +#define VEC0_COLUMN_OFFSET_K 3 #define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\"" @@ -3498,6 +3500,10 @@ struct vec0_vtab { // Will change the schema of the _rowids table, and insert/query logic. int pkIsText; + // True if the hidden command column (named after the table) exists. + // Tables created before v0.1.10 or without _info table don't have it. + int hasCommandColumn; + // number of defined vector columns. int numVectorColumns; @@ -3777,20 +3783,19 @@ int vec0_num_defined_user_columns(vec0_vtab *p) { * @param p vec0 table * @return int */ -int vec0_column_distance_idx(vec0_vtab *p) { - return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) + - VEC0_COLUMN_OFFSET_DISTANCE; +int vec0_column_command_idx(vec0_vtab *p) { + // Command column is the first hidden column (right after user columns) + return VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p); +} + +int vec0_column_distance_idx(vec0_vtab *p) { + int base = VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p); + return base + (p->hasCommandColumn ? 1 : 0); } -/** - * @brief Returns the index of the k hidden column for the given vec0 table. - * - * @param p vec0 table - * @return int k column index - */ int vec0_column_k_idx(vec0_vtab *p) { - return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) + - VEC0_COLUMN_OFFSET_K; + int base = VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p); + return base + (p->hasCommandColumn ? 2 : 1); } /** @@ -5205,6 +5210,74 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } } + // Determine whether to add the FTS5-style hidden command column. + // New tables (isCreate) always get it; existing tables only if created + // with v0.1.10+ (which validated no column name == table name). + int hasCommandColumn = 0; + if (isCreate) { + // Validate no user column name conflicts with the table name + const char *tblName = argv[2]; + int tblNameLen = (int)strlen(tblName); + for (int i = 0; i < numVectorColumns; i++) { + if (pNew->vector_columns[i].name_length == tblNameLen && + sqlite3_strnicmp(pNew->vector_columns[i].name, tblName, tblNameLen) == 0) { + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "column name '%s' conflicts with table name (reserved for command column)", + tblName); + goto error; + } + } + for (int i = 0; i < numPartitionColumns; i++) { + if (pNew->paritition_columns[i].name_length == tblNameLen && + sqlite3_strnicmp(pNew->paritition_columns[i].name, tblName, tblNameLen) == 0) { + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "column name '%s' conflicts with table name (reserved for command column)", + tblName); + goto error; + } + } + for (int i = 0; i < numAuxiliaryColumns; i++) { + if (pNew->auxiliary_columns[i].name_length == tblNameLen && + sqlite3_strnicmp(pNew->auxiliary_columns[i].name, tblName, tblNameLen) == 0) { + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "column name '%s' conflicts with table name (reserved for command column)", + tblName); + goto error; + } + } + for (int i = 0; i < numMetadataColumns; i++) { + if (pNew->metadata_columns[i].name_length == tblNameLen && + sqlite3_strnicmp(pNew->metadata_columns[i].name, tblName, tblNameLen) == 0) { + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "column name '%s' conflicts with table name (reserved for command column)", + tblName); + goto error; + } + } + hasCommandColumn = 1; + } else { + // xConnect: check _info shadow table for version + sqlite3_stmt *stmtInfo = NULL; + char *zInfoSql = sqlite3_mprintf( + "SELECT value FROM " VEC0_SHADOW_INFO_NAME " WHERE key = 'CREATE_VERSION_PATCH'", + argv[1], argv[2]); + if (zInfoSql) { + int infoRc = sqlite3_prepare_v2(db, zInfoSql, -1, &stmtInfo, NULL); + sqlite3_free(zInfoSql); + if (infoRc == SQLITE_OK && sqlite3_step(stmtInfo) == SQLITE_ROW) { + int patch = sqlite3_column_int(stmtInfo, 0); + hasCommandColumn = (patch >= 10); // v0.1.10+ + } + // If _info doesn't exist or has no version, assume old table + sqlite3_finalize(stmtInfo); + } + } + pNew->hasCommandColumn = hasCommandColumn; + sqlite3_str *createStr = sqlite3_str_new(NULL); sqlite3_str_appendall(createStr, "CREATE TABLE x("); if (pkColumnName) { @@ -5246,7 +5319,11 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } } - sqlite3_str_appendall(createStr, " distance hidden, k hidden) "); + if (hasCommandColumn) { + sqlite3_str_appendf(createStr, " \"%w\" hidden, distance hidden, k hidden) ", argv[2]); + } else { + sqlite3_str_appendall(createStr, " distance hidden, k hidden) "); + } if (pkColumnName) { sqlite3_str_appendall(createStr, "without rowid "); } @@ -10161,25 +10238,31 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, } // INSERT operation else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) { -#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE || SQLITE_VEC_ENABLE_DISKANN - // Check for command inserts: INSERT INTO t(rowid) VALUES ('command-string') - // The id column holds the command string. - sqlite3_value *idVal = argv[2 + VEC0_COLUMN_ID]; - if (sqlite3_value_type(idVal) == SQLITE_TEXT) { - const char *cmd = (const char *)sqlite3_value_text(idVal); - vec0_vtab *p = (vec0_vtab *)pVTab; - int cmdRc = SQLITE_EMPTY; + vec0_vtab *p = (vec0_vtab *)pVTab; + // FTS5-style command dispatch via hidden column named after table + if (p->hasCommandColumn) { + sqlite3_value *cmdVal = argv[2 + vec0_column_command_idx(p)]; + if (sqlite3_value_type(cmdVal) == SQLITE_TEXT) { + const char *cmd = (const char *)sqlite3_value_text(cmdVal); + int cmdRc = SQLITE_EMPTY; +#if SQLITE_VEC_ENABLE_RESCORE + cmdRc = rescore_handle_command(p, cmd); +#endif #if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE - cmdRc = ivf_handle_command(p, cmd, argc, argv); + if (cmdRc == SQLITE_EMPTY) + cmdRc = ivf_handle_command(p, cmd, argc, argv); #endif #if SQLITE_VEC_ENABLE_DISKANN - if (cmdRc == SQLITE_EMPTY) - cmdRc = diskann_handle_command(p, cmd); + if (cmdRc == SQLITE_EMPTY) + cmdRc = diskann_handle_command(p, cmd); #endif - if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error) - // SQLITE_EMPTY means not a recognized command — fall through to normal insert + if (cmdRc == SQLITE_EMPTY) { + vtab_set_error(pVTab, "unknown vec0 command: '%s'", cmd); + return SQLITE_ERROR; + } + return cmdRc; + } } -#endif return vec0Update_Insert(pVTab, argc, argv, pRowid); } // UPDATE operation diff --git a/tests/fixtures/legacy-v0.1.6.db b/tests/fixtures/legacy-v0.1.6.db new file mode 100644 index 0000000..58bd89d Binary files /dev/null and b/tests/fixtures/legacy-v0.1.6.db differ diff --git a/tests/fuzz/diskann-command-inject.c b/tests/fuzz/diskann-command-inject.c index ef62884..22661bf 100644 --- a/tests/fuzz/diskann-command-inject.c +++ b/tests/fuzz/diskann-command-inject.c @@ -50,7 +50,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { { sqlite3_stmt *stmt; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmt, NULL); for (int i = 1; i <= 8; i++) { float vec[8]; for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f; @@ -66,11 +66,11 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { sqlite3_stmt *stmtInsert = NULL; sqlite3_stmt *stmtKnn = NULL; - /* Commands are dispatched via INSERT INTO t(rowid) VALUES ('cmd_string') */ + /* Commands are dispatched via INSERT INTO t(t) VALUES ('cmd_string') */ sqlite3_prepare_v2(db, - "INSERT INTO v(rowid) VALUES (?)", -1, &stmtCmd, NULL); + "INSERT INTO v(v) VALUES (?)", -1, &stmtCmd, NULL); sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); sqlite3_prepare_v2(db, "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?", -1, &stmtKnn, NULL); diff --git a/tests/fuzz/ivf-cell-overflow.c b/tests/fuzz/ivf-cell-overflow.c index 4b18ba2..65ae6b2 100644 --- a/tests/fuzz/ivf-cell-overflow.c +++ b/tests/fuzz/ivf-cell-overflow.c @@ -55,7 +55,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Insert enough vectors to overflow at least one cell sqlite3_stmt *stmtInsert = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); if (!stmtInsert) { sqlite3_close(db); return 0; } size_t offset = 0; @@ -81,7 +81,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Train to assign vectors to centroids (triggers cell building) sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); // Delete vectors at boundary positions based on fuzz data @@ -102,7 +102,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { { sqlite3_stmt *si = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL); if (si) { for (int i = 0; i < 10; i++) { float *vec = sqlite3_malloc(dim * sizeof(float)); @@ -140,7 +140,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Test assign-vectors with multi-cell state // First clear centroids sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('clear-centroids')", + "INSERT INTO v(v) VALUES ('clear-centroids')", NULL, NULL, NULL); // Set centroids manually, then assign @@ -151,7 +151,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { char cmd[128]; snprintf(cmd, sizeof(cmd), - "INSERT INTO v(rowid, emb) VALUES ('set-centroid:%d', ?)", c); + "INSERT INTO v(v, emb) VALUES ('set-centroid:%d', ?)", c); sqlite3_stmt *sc = NULL; sqlite3_prepare_v2(db, cmd, -1, &sc, NULL); if (sc) { @@ -163,7 +163,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { } sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('assign-vectors')", + "INSERT INTO v(v) VALUES ('assign-vectors')", NULL, NULL, NULL); // Final query after assign-vectors diff --git a/tests/fuzz/ivf-kmeans.c b/tests/fuzz/ivf-kmeans.c index 46804d0..1d37184 100644 --- a/tests/fuzz/ivf-kmeans.c +++ b/tests/fuzz/ivf-kmeans.c @@ -64,7 +64,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Insert vectors sqlite3_stmt *stmtInsert = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); if (!stmtInsert) { sqlite3_close(db); return 0; } size_t offset = 0; @@ -125,14 +125,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Clear centroids and re-compute to test round-trip sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('clear-centroids')", + "INSERT INTO v(v) VALUES ('clear-centroids')", NULL, NULL, NULL); // Insert a few more vectors in untrained state { sqlite3_stmt *si = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL); if (si) { for (int i = 0; i < 3; i++) { float *vec = sqlite3_malloc(dim * sizeof(float)); @@ -150,7 +150,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Re-train sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); // Delete some rows after training, then query diff --git a/tests/fuzz/ivf-knn-deep.c b/tests/fuzz/ivf-knn-deep.c index 27d19a1..f5adb1e 100644 --- a/tests/fuzz/ivf-knn-deep.c +++ b/tests/fuzz/ivf-knn-deep.c @@ -92,7 +92,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Insert vectors sqlite3_stmt *stmtInsert = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); if (!stmtInsert) { sqlite3_close(db); return 0; } size_t offset = 0; @@ -134,14 +134,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Train sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); // Change nprobe at runtime (can exceed nlist -- tests clamping in query) { char cmd[64]; snprintf(cmd, sizeof(cmd), - "INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe_initial); + "INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe_initial); sqlite3_exec(db, cmd, NULL, NULL, NULL); } diff --git a/tests/fuzz/ivf-operations.c b/tests/fuzz/ivf-operations.c index a955870..c8d0c01 100644 --- a/tests/fuzz/ivf-operations.c +++ b/tests/fuzz/ivf-operations.c @@ -28,7 +28,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); sqlite3_prepare_v2(db, "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); sqlite3_prepare_v2(db, @@ -82,14 +82,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { case 4: { // compute-centroids command sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); break; } case 5: { // clear-centroids command sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('clear-centroids')", + "INSERT INTO v(v) VALUES ('clear-centroids')", NULL, NULL, NULL); break; } @@ -100,7 +100,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { int nprobe = (n % 4) + 1; char buf[64]; snprintf(buf, sizeof(buf), - "INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe); + "INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe); sqlite3_exec(db, buf, NULL, NULL, NULL); } break; diff --git a/tests/fuzz/ivf-quantize.c b/tests/fuzz/ivf-quantize.c index 22149ee..bc8800b 100644 --- a/tests/fuzz/ivf-quantize.c +++ b/tests/fuzz/ivf-quantize.c @@ -61,7 +61,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Insert vectors with fuzz-controlled float values sqlite3_stmt *stmtInsert = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); if (!stmtInsert) { sqlite3_close(db); return 0; } size_t offset = 0; @@ -93,7 +93,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Trigger compute-centroids to exercise kmeans + quantization together sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); // KNN query with fuzz-derived query vector diff --git a/tests/fuzz/ivf-rescore.c b/tests/fuzz/ivf-rescore.c index 1c3f34a..3cddf88 100644 --- a/tests/fuzz/ivf-rescore.c +++ b/tests/fuzz/ivf-rescore.c @@ -68,7 +68,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Insert vectors with diverse values sqlite3_stmt *stmtInsert = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); if (!stmtInsert) { sqlite3_close(db); return 0; } size_t offset = 0; @@ -103,7 +103,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Train sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); // Multiple KNN queries to exercise rescore path @@ -156,7 +156,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Retrain after deletions sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); // Query after retrain diff --git a/tests/fuzz/ivf-shadow-corrupt.c b/tests/fuzz/ivf-shadow-corrupt.c index 1153ac9..74d72c3 100644 --- a/tests/fuzz/ivf-shadow-corrupt.c +++ b/tests/fuzz/ivf-shadow-corrupt.c @@ -46,7 +46,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { { sqlite3_stmt *si = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL); if (!si) { sqlite3_close(db); return 0; } for (int i = 0; i < 10; i++) { float vec[8]; @@ -63,7 +63,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // Train sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); // Now corrupt shadow tables based on fuzz input @@ -204,7 +204,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; sqlite3_stmt *si = NULL; sqlite3_prepare_v2(db, - "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL); + "INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL); if (si) { sqlite3_bind_int64(si, 1, 100); sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC); @@ -215,12 +215,12 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { // compute-centroids over corrupted state sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('compute-centroids')", + "INSERT INTO v(v) VALUES ('compute-centroids')", NULL, NULL, NULL); // clear-centroids sqlite3_exec(db, - "INSERT INTO v(rowid) VALUES ('clear-centroids')", + "INSERT INTO v(v) VALUES ('clear-centroids')", NULL, NULL, NULL); sqlite3_close(db); diff --git a/tests/generate_legacy_db.py b/tests/generate_legacy_db.py new file mode 100644 index 0000000..4611690 --- /dev/null +++ b/tests/generate_legacy_db.py @@ -0,0 +1,81 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = ["sqlite-vec==0.1.6"] +# /// +"""Generate a legacy sqlite-vec database for backwards-compat testing. + +Usage: + uv run --script generate_legacy_db.py + +Creates tests/fixtures/legacy-v0.1.6.db with a vec0 table containing +test data that can be read by the current version of sqlite-vec. +""" +import sqlite3 +import sqlite_vec +import struct +import os + +FIXTURE_DIR = os.path.join(os.path.dirname(__file__), "fixtures") +DB_PATH = os.path.join(FIXTURE_DIR, "legacy-v0.1.6.db") + +DIMS = 4 +N_ROWS = 50 + + +def _f32(vals): + return struct.pack(f"{len(vals)}f", *vals) + + +def main(): + os.makedirs(FIXTURE_DIR, exist_ok=True) + if os.path.exists(DB_PATH): + os.remove(DB_PATH) + + db = sqlite3.connect(DB_PATH) + db.enable_load_extension(True) + sqlite_vec.load(db) + + # Print version for verification + version = db.execute("SELECT vec_version()").fetchone()[0] + print(f"sqlite-vec version: {version}") + + # Create a basic vec0 table — flat index, no fancy features + db.execute(f"CREATE VIRTUAL TABLE legacy_vectors USING vec0(emb float[{DIMS}])") + + # Insert test data: vectors where element[0] == rowid for easy verification + for i in range(1, N_ROWS + 1): + vec = [float(i), 0.0, 0.0, 0.0] + db.execute("INSERT INTO legacy_vectors(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + db.commit() + + # Verify + count = db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0] + print(f"Inserted {count} rows") + + # Test KNN works + query = _f32([1.0, 0.0, 0.0, 0.0]) + rows = db.execute( + "SELECT rowid, distance FROM legacy_vectors WHERE emb MATCH ? AND k = 5", + [query], + ).fetchall() + print(f"KNN top 5: {[(r[0], round(r[1], 4)) for r in rows]}") + assert rows[0][0] == 1 # closest to [1,0,0,0] + assert len(rows) == 5 + + # Also create a table with name == column name (the conflict case) + # This was allowed in old versions — new code must not break on reconnect + db.execute("CREATE VIRTUAL TABLE emb USING vec0(emb float[4])") + for i in range(1, 11): + db.execute("INSERT INTO emb(rowid, emb) VALUES (?, ?)", [i, _f32([float(i), 0, 0, 0])]) + db.commit() + + count2 = db.execute("SELECT count(*) FROM emb").fetchone()[0] + print(f"Table 'emb' with column 'emb': {count2} rows (name conflict case)") + + db.close() + print(f"\nGenerated: {DB_PATH}") + + +if __name__ == "__main__": + main() diff --git a/tests/test-diskann.py b/tests/test-diskann.py index 4369a8b..4e65160 100644 --- a/tests/test-diskann.py +++ b/tests/test-diskann.py @@ -589,7 +589,7 @@ def test_diskann_command_search_list_size(db): assert len(results_before) == 5 # Override search_list_size_search at runtime - db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_search=256')") + db.execute("INSERT INTO t(t) VALUES ('search_list_size_search=256')") # Query should still work results_after = db.execute( @@ -598,14 +598,14 @@ def test_diskann_command_search_list_size(db): assert len(results_after) == 5 # Override search_list_size_insert at runtime - db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_insert=32')") + db.execute("INSERT INTO t(t) VALUES ('search_list_size_insert=32')") # Inserts should still work vec = struct.pack("64f", *[random.random() for _ in range(64)]) db.execute("INSERT INTO t(emb) VALUES (?)", [vec]) # Override unified search_list_size - db.execute("INSERT INTO t(rowid) VALUES ('search_list_size=64')") + db.execute("INSERT INTO t(t) VALUES ('search_list_size=64')") results_final = db.execute( "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query] @@ -620,9 +620,9 @@ def test_diskann_command_search_list_size_error(db): emb float[64] INDEXED BY diskann(neighbor_quantizer=binary) ) """) - result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=0')") + result = exec(db, "INSERT INTO t(t) VALUES ('search_list_size=0')") assert "error" in result - result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=-1')") + result = exec(db, "INSERT INTO t(t) VALUES ('search_list_size=-1')") assert "error" in result diff --git a/tests/test-general.py b/tests/test-general.py index 9b6de5d..4446641 100644 --- a/tests/test-general.py +++ b/tests/test-general.py @@ -27,3 +27,15 @@ def test_info(db, snapshot): assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot() +def test_command_column_name_conflict(db): + """Table name matching a column name should error (command column conflict).""" + # This would conflict: hidden command column 'embeddings' vs vector column 'embeddings' + with pytest.raises(sqlite3.OperationalError, match="conflicts with table name"): + db.execute( + "create virtual table embeddings using vec0(embeddings float[4])" + ) + + # Different names should work fine + db.execute("create virtual table t using vec0(embeddings float[4])") + + diff --git a/tests/test-ivf-mutations.py b/tests/test-ivf-mutations.py index c20dac3..76c2e1f 100644 --- a/tests/test-ivf-mutations.py +++ b/tests/test-ivf-mutations.py @@ -78,7 +78,7 @@ def test_batch_insert_knn_recall(db): ) assert ivf_total_vectors(db) == 200 - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") assert ivf_assigned_count(db) == 200 # Query near 100 -- closest should be rowid 100 @@ -107,7 +107,7 @@ def test_delete_rows_gone_from_knn(db): [i, _f32([float(i), 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") # Delete rowid 10 db.execute("DELETE FROM t WHERE rowid = 10") @@ -127,7 +127,7 @@ def test_delete_all_rows_empty_results(db): [i, _f32([float(i), 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") for i in range(10): db.execute("DELETE FROM t WHERE rowid = ?", [i]) @@ -152,7 +152,7 @@ def test_insert_after_delete_reuse_rowid(db): [i, _f32([float(i), 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") # Delete rowid 5 db.execute("DELETE FROM t WHERE rowid = 5") @@ -184,7 +184,7 @@ def test_update_vector_via_delete_insert(db): [i, _f32([float(i), 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") # "Update" rowid 3: delete and re-insert with new vector db.execute("DELETE FROM t WHERE rowid = 3") @@ -316,7 +316,7 @@ def test_single_row_compute_centroids(db): db.execute( "INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") assert ivf_assigned_count(db) == 1 results = knn(db, [1, 2, 3, 4], 1) @@ -343,10 +343,10 @@ def test_cell_overflow_many_vectors(db): # Set a single centroid so all vectors go there db.execute( - "INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)", + "INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)", [_f32([1.0, 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')") + db.execute("INSERT INTO t(t) VALUES ('assign-vectors')") assert ivf_assigned_count(db) == 100 @@ -377,7 +377,7 @@ def test_large_batch_with_training(db): [i, _f32([float(i), 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") for i in range(500, 1000): db.execute( @@ -409,7 +409,7 @@ def test_knn_after_interleaved_insert_delete(db): [i, _f32([float(i), 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") # Delete rowids 0-9 (closest to query at 5.0) for i in range(10): @@ -434,7 +434,7 @@ def test_knn_empty_centroids_after_deletes(db): [i, _f32([float(i % 10) * 10, 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") # Delete a bunch, potentially emptying some centroids for i in range(30): @@ -458,7 +458,7 @@ def test_knn_correct_distances(db): db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([3, 0, 0, 0])]) db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([0, 4, 0, 0])]) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") results = knn(db, [0, 0, 0, 0], 3) result_map = {r[0]: r[1] for r in results} @@ -547,7 +547,7 @@ def test_interleaved_ops_correctness(db): [i, _f32([float(i), 0, 0, 0])], ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") # Phase 2: Delete even-numbered rowids for i in range(0, 50, 2): diff --git a/tests/test-ivf-quantization.py b/tests/test-ivf-quantization.py index b4d6ae3..1529dad 100644 --- a/tests/test-ivf-quantization.py +++ b/tests/test-ivf-quantization.py @@ -122,7 +122,7 @@ def test_ivf_int8_insert_and_query(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") # Should be able to query rows = db.execute( @@ -151,7 +151,7 @@ def test_ivf_binary_insert_and_query(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") rows = db.execute( "SELECT rowid FROM t WHERE v MATCH ? AND k = 5", @@ -221,10 +221,10 @@ def test_ivf_int8_oversample_improves_recall(db): db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v]) db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v]) - db.execute("INSERT INTO t1(rowid) VALUES ('compute-centroids')") - db.execute("INSERT INTO t2(rowid) VALUES ('compute-centroids')") - db.execute("INSERT INTO t1(rowid) VALUES ('nprobe=4')") - db.execute("INSERT INTO t2(rowid) VALUES ('nprobe=4')") + db.execute("INSERT INTO t1(t1) VALUES ('compute-centroids')") + db.execute("INSERT INTO t2(t2) VALUES ('compute-centroids')") + db.execute("INSERT INTO t1(t1) VALUES ('nprobe=4')") + db.execute("INSERT INTO t2(t2) VALUES ('nprobe=4')") query = _f32([5.0, 1.5, 2.5, 0.5]) r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall() @@ -247,7 +247,7 @@ def test_ivf_quantized_delete(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10 db.execute("DELETE FROM t WHERE rowid = 5") diff --git a/tests/test-ivf.py b/tests/test-ivf.py index 18a7532..8b7f566 100644 --- a/tests/test-ivf.py +++ b/tests/test-ivf.py @@ -217,7 +217,7 @@ def test_compute_centroids(db): assert ivf_unassigned_count(db) == 40 - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") # After training: unassigned cell should be gone (or empty), vectors in trained cells assert ivf_unassigned_count(db) == 0 @@ -238,10 +238,10 @@ def test_compute_centroids_recompute(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2 - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2 assert ivf_assigned_count(db) == 20 @@ -260,7 +260,7 @@ def test_ivf_insert_after_training(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") db.execute( "INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])] @@ -290,7 +290,7 @@ def test_ivf_knn_after_training(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") rows = db.execute( "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5", @@ -310,7 +310,7 @@ def test_ivf_knn_k_larger_than_n(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") rows = db.execute( "SELECT rowid FROM t WHERE v MATCH ? AND k = 100", @@ -334,17 +334,17 @@ def test_set_centroid_and_assign(db): ) db.execute( - "INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)", + "INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)", [_f32([5, 0, 0, 0])], ) db.execute( - "INSERT INTO t(rowid, v) VALUES ('set-centroid:1', ?)", + "INSERT INTO t(t, v) VALUES ('set-centroid:1', ?)", [_f32([15, 0, 0, 0])], ) assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2 - db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')") + db.execute("INSERT INTO t(t) VALUES ('assign-vectors')") assert ivf_unassigned_count(db) == 0 assert ivf_assigned_count(db) == 20 @@ -364,10 +364,10 @@ def test_clear_centroids(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2 - db.execute("INSERT INTO t(rowid) VALUES ('clear-centroids')") + db.execute("INSERT INTO t(t) VALUES ('clear-centroids')") assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0 assert ivf_unassigned_count(db) == 20 trained = db.execute( @@ -390,7 +390,7 @@ def test_ivf_delete_after_training(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") assert ivf_assigned_count(db) == 10 db.execute("DELETE FROM t WHERE rowid = 5") @@ -412,7 +412,7 @@ def test_ivf_recall_nprobe_equals_nlist(db): "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] ) - db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t(t) VALUES ('compute-centroids')") rows = db.execute( "SELECT rowid FROM t WHERE v MATCH ? AND k = 10", diff --git a/tests/test-legacy-compat.py b/tests/test-legacy-compat.py new file mode 100644 index 0000000..8f94d31 --- /dev/null +++ b/tests/test-legacy-compat.py @@ -0,0 +1,138 @@ +"""Backwards compatibility tests: current sqlite-vec reading legacy databases. + +The fixture file tests/fixtures/legacy-v0.1.6.db was generated by +tests/generate_legacy_db.py using sqlite-vec v0.1.6. These tests verify +that the current version can fully read, query, insert into, and delete +from tables created by older versions. +""" +import sqlite3 +import struct +import os +import shutil +import pytest + +FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures", "legacy-v0.1.6.db") + + +def _f32(vals): + return struct.pack(f"{len(vals)}f", *vals) + + +@pytest.fixture() +def legacy_db(tmp_path): + """Copy the legacy fixture to a temp dir so tests can modify it.""" + if not os.path.exists(FIXTURE_PATH): + pytest.skip("Legacy fixture not found — run: uv run --script tests/generate_legacy_db.py") + db_path = str(tmp_path / "legacy.db") + shutil.copy2(FIXTURE_PATH, db_path) + db = sqlite3.connect(db_path) + db.row_factory = sqlite3.Row + db.enable_load_extension(True) + db.load_extension("dist/vec0") + return db + + +def test_legacy_select_count(legacy_db): + """Basic SELECT count should return all rows.""" + count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0] + assert count == 50 + + +def test_legacy_point_query(legacy_db): + """Point query by rowid should return correct vector.""" + row = legacy_db.execute( + "SELECT rowid, emb FROM legacy_vectors WHERE rowid = 1" + ).fetchone() + assert row["rowid"] == 1 + vec = struct.unpack("4f", row["emb"]) + assert vec[0] == pytest.approx(1.0) + + +def test_legacy_knn(legacy_db): + """KNN query on legacy table should return correct results.""" + query = _f32([1.0, 0.0, 0.0, 0.0]) + rows = legacy_db.execute( + "SELECT rowid, distance FROM legacy_vectors " + "WHERE emb MATCH ? AND k = 5", + [query], + ).fetchall() + assert len(rows) == 5 + assert rows[0]["rowid"] == 1 + assert rows[0]["distance"] == pytest.approx(0.0) + for i in range(len(rows) - 1): + assert rows[i]["distance"] <= rows[i + 1]["distance"] + + +def test_legacy_insert(legacy_db): + """INSERT into legacy table should work.""" + legacy_db.execute( + "INSERT INTO legacy_vectors(rowid, emb) VALUES (100, ?)", + [_f32([100.0, 0.0, 0.0, 0.0])], + ) + count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0] + assert count == 51 + + rows = legacy_db.execute( + "SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 1", + [_f32([100.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert rows[0]["rowid"] == 100 + + +def test_legacy_delete(legacy_db): + """DELETE from legacy table should work.""" + legacy_db.execute("DELETE FROM legacy_vectors WHERE rowid = 1") + count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0] + assert count == 49 + + rows = legacy_db.execute( + "SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 5", + [_f32([1.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert 1 not in [r["rowid"] for r in rows] + + +def test_legacy_fullscan(legacy_db): + """Full scan should work.""" + rows = legacy_db.execute( + "SELECT rowid FROM legacy_vectors ORDER BY rowid LIMIT 5" + ).fetchall() + assert [r["rowid"] for r in rows] == [1, 2, 3, 4, 5] + + +def test_legacy_name_conflict_table(legacy_db): + """Legacy table where column name == table name should work. + + The v0.1.6 DB has: CREATE VIRTUAL TABLE emb USING vec0(emb float[4]) + Current code should NOT add the command column for this table + (detected via _info version check), avoiding the name conflict. + """ + count = legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] + assert count == 10 + + rows = legacy_db.execute( + "SELECT rowid, distance FROM emb WHERE emb MATCH ? AND k = 3", + [_f32([1.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert len(rows) == 3 + assert rows[0]["rowid"] == 1 + + +def test_legacy_name_conflict_insert_delete(legacy_db): + """INSERT and DELETE on legacy name-conflict table.""" + legacy_db.execute( + "INSERT INTO emb(rowid, emb) VALUES (100, ?)", + [_f32([100.0, 0.0, 0.0, 0.0])], + ) + assert legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] == 11 + + legacy_db.execute("DELETE FROM emb WHERE rowid = 5") + assert legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] == 10 + + +def test_legacy_no_command_column(legacy_db): + """Legacy tables should NOT have the command column.""" + with pytest.raises(sqlite3.OperationalError): + legacy_db.execute( + "INSERT INTO legacy_vectors(legacy_vectors) VALUES ('some_command')" + ) diff --git a/tests/test-rescore.py b/tests/test-rescore.py index aa8586e..9fda08a 100644 --- a/tests/test-rescore.py +++ b/tests/test-rescore.py @@ -655,3 +655,73 @@ def test_rescore_text_pk_insert_knn_delete(db): ids = [r["id"] for r in rows] assert "alpha" not in ids assert len(rows) >= 1 # other results still returned + + +def test_runtime_oversample(db): + """oversample can be changed at query time via FTS5-style command.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit, oversample=2)" + ")" + ) + random.seed(200) + for i in range(200): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([random.gauss(0, 1) for _ in range(128)])], + ) + + query = float_vec([random.gauss(0, 1) for _ in range(128)]) + + # KNN with default oversample=2 (low) + rows_low = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [query], + ).fetchall() + assert len(rows_low) == 10 + + # Change oversample at runtime to high value + db.execute("INSERT INTO t(t) VALUES ('oversample=32')") + + # KNN with oversample=32 (high) — same or better recall + rows_high = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [query], + ).fetchall() + assert len(rows_high) == 10 + + # Reset to original + db.execute("INSERT INTO t(t) VALUES ('oversample=2')") + + rows_reset = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [query], + ).fetchall() + assert len(rows_reset) == 10 + # After reset, should match the original low-oversample results + assert [r["rowid"] for r in rows_reset] == [r["rowid"] for r in rows_low] + + +def test_runtime_oversample_error(db): + """Invalid oversample values should error.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"): + db.execute("INSERT INTO t(t) VALUES ('oversample=0')") + + with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"): + db.execute("INSERT INTO t(t) VALUES ('oversample=-5')") + + +def test_unknown_command_errors(db): + """Unknown command strings should produce a clear error.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + with pytest.raises(sqlite3.OperationalError, match="unknown vec0 command"): + db.execute("INSERT INTO t(t) VALUES ('not_a_real_command')")