mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
Add FTS5-style command column and runtime oversample for rescore
Replace the old INSERT INTO t(rowid) VALUES('command') hack with a
proper hidden command column named after the table (FTS5 pattern):
INSERT INTO t(t) VALUES ('oversample=16')
The command column is the first hidden column (before distance and k)
to reserve ability for future table-valued function argument use.
Schema: CREATE TABLE x(rowid, <cols>, "<table>" hidden, distance hidden, k hidden)
For backwards compat, pre-v0.1.10 tables (detected via _info shadow
table version) skip the command column to avoid name conflicts with
user columns that may share the table's name. Verified with legacy
fixture DB generated by sqlite-vec v0.1.6.
Changes:
- Add hidden command column to sqlite3_declare_vtab for new tables
- Version-gate via _info shadow table for existing tables
- Validate at CREATE time that no column name matches table name
- Add rescore_handle_command() with oversample=N support
- rescore_knn() prefers runtime oversample_search over CREATE default
- Remove old rowid-based command dispatch
- Migrate all DiskANN/IVF/fuzz tests and benchmarks to new syntax
- Add legacy DB fixture (v0.1.6) and 9 backwards-compat tests
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
b7fc459be4
commit
6e2c4c6bab
21 changed files with 512 additions and 105 deletions
|
|
@ -159,7 +159,7 @@ INDEX_REGISTRY = {
|
||||||
def _ivf_train(conn):
|
def _ivf_train(conn):
|
||||||
"""Trigger built-in k-means training for IVF."""
|
"""Trigger built-in k-means training for IVF."""
|
||||||
t0 = now_ns()
|
t0 = now_ns()
|
||||||
conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')")
|
conn.execute("INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return ns_to_s(now_ns() - t0)
|
return ns_to_s(now_ns() - t0)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -456,7 +456,7 @@ def _ivf_create_table_sql(params):
|
||||||
def _ivf_post_insert_hook(conn, params):
|
def _ivf_post_insert_hook(conn, params):
|
||||||
print(" Training k-means centroids (built-in)...", flush=True)
|
print(" Training k-means centroids (built-in)...", flush=True)
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')")
|
conn.execute("INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
elapsed = time.perf_counter() - t0
|
elapsed = time.perf_counter() - t0
|
||||||
print(f" Training done in {elapsed:.1f}s", flush=True)
|
print(f" Training done in {elapsed:.1f}s", flush=True)
|
||||||
|
|
@ -514,7 +514,7 @@ def _ivf_faiss_kmeans_hook(conn, params):
|
||||||
|
|
||||||
for cid, blob in centroids:
|
for cid, blob in centroids:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO vec_items(id, embedding) VALUES (?, ?)",
|
"INSERT INTO vec_items(vec_items, embedding) VALUES (?, ?)",
|
||||||
(f"set-centroid:{cid}", blob),
|
(f"set-centroid:{cid}", blob),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
@ -540,7 +540,7 @@ def _ivf_pre_query_hook(conn, params):
|
||||||
nprobe = params.get("nprobe")
|
nprobe = params.get("nprobe")
|
||||||
if nprobe:
|
if nprobe:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO vec_items(id) VALUES (?)",
|
"INSERT INTO vec_items(vec_items) VALUES (?)",
|
||||||
(f"nprobe={nprobe}",),
|
(f"nprobe={nprobe}",),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
@ -572,7 +572,7 @@ INDEX_REGISTRY["ivf"] = {
|
||||||
"insert_sql": None,
|
"insert_sql": None,
|
||||||
"post_insert_hook": _ivf_post_insert_hook,
|
"post_insert_hook": _ivf_post_insert_hook,
|
||||||
"pre_query_hook": _ivf_pre_query_hook,
|
"pre_query_hook": _ivf_pre_query_hook,
|
||||||
"train_sql": lambda _: "INSERT INTO vec_items(id) VALUES ('compute-centroids')",
|
"train_sql": lambda _: "INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')",
|
||||||
"run_query": None,
|
"run_query": None,
|
||||||
"query_sql": None,
|
"query_sql": None,
|
||||||
"describe": _ivf_describe,
|
"describe": _ivf_describe,
|
||||||
|
|
@ -616,7 +616,7 @@ def _diskann_pre_query_hook(conn, params):
|
||||||
L_search = params.get("L_search", 0)
|
L_search = params.get("L_search", 0)
|
||||||
if L_search:
|
if L_search:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO vec_items(id) VALUES (?)",
|
"INSERT INTO vec_items(vec_items) VALUES (?)",
|
||||||
(f"search_list_size_search={L_search}",),
|
(f"search_list_size_search={L_search}",),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
|
||||||
|
|
@ -351,7 +351,9 @@ static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur,
|
||||||
(void)pCur;
|
(void)pCur;
|
||||||
(void)aMetadataIn;
|
(void)aMetadataIn;
|
||||||
int rc = SQLITE_OK;
|
int rc = SQLITE_OK;
|
||||||
int oversample = vector_column->rescore.oversample;
|
int oversample = vector_column->rescore.oversample_search > 0
|
||||||
|
? vector_column->rescore.oversample_search
|
||||||
|
: vector_column->rescore.oversample;
|
||||||
i64 k_oversample = k * oversample;
|
i64 k_oversample = k * oversample;
|
||||||
if (k_oversample > 4096)
|
if (k_oversample > 4096)
|
||||||
k_oversample = 4096;
|
k_oversample = 4096;
|
||||||
|
|
@ -640,6 +642,27 @@ cleanup:
|
||||||
return rc;
|
return rc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle FTS5-style command dispatch for rescore parameters.
|
||||||
|
* Returns SQLITE_OK if handled, SQLITE_EMPTY if not a rescore command.
|
||||||
|
*/
|
||||||
|
static int rescore_handle_command(vec0_vtab *p, const char *command) {
|
||||||
|
if (strncmp(command, "oversample=", 11) == 0) {
|
||||||
|
int val = atoi(command + 11);
|
||||||
|
if (val < 1) {
|
||||||
|
vtab_set_error(&p->base, "oversample must be >= 1");
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) {
|
||||||
|
p->vector_columns[i].rescore.oversample_search = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
return SQLITE_EMPTY;
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef SQLITE_VEC_TEST
|
#ifdef SQLITE_VEC_TEST
|
||||||
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) {
|
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) {
|
||||||
rescore_quantize_float_to_bit(src, dst, dim);
|
rescore_quantize_float_to_bit(src, dst, dim);
|
||||||
|
|
|
||||||
129
sqlite-vec.c
129
sqlite-vec.c
|
|
@ -2588,7 +2588,8 @@ enum Vec0RescoreQuantizerType {
|
||||||
|
|
||||||
struct Vec0RescoreConfig {
|
struct Vec0RescoreConfig {
|
||||||
enum Vec0RescoreQuantizerType quantizer_type;
|
enum Vec0RescoreQuantizerType quantizer_type;
|
||||||
int oversample;
|
int oversample; // CREATE-time default
|
||||||
|
int oversample_search; // runtime override (0 = use default)
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
@ -3399,8 +3400,9 @@ static sqlite3_module vec_eachModule = {
|
||||||
|
|
||||||
#define VEC0_COLUMN_ID 0
|
#define VEC0_COLUMN_ID 0
|
||||||
#define VEC0_COLUMN_USERN_START 1
|
#define VEC0_COLUMN_USERN_START 1
|
||||||
#define VEC0_COLUMN_OFFSET_DISTANCE 1
|
#define VEC0_COLUMN_OFFSET_COMMAND 1
|
||||||
#define VEC0_COLUMN_OFFSET_K 2
|
#define VEC0_COLUMN_OFFSET_DISTANCE 2
|
||||||
|
#define VEC0_COLUMN_OFFSET_K 3
|
||||||
|
|
||||||
#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
|
#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
|
||||||
|
|
||||||
|
|
@ -3498,6 +3500,10 @@ struct vec0_vtab {
|
||||||
// Will change the schema of the _rowids table, and insert/query logic.
|
// Will change the schema of the _rowids table, and insert/query logic.
|
||||||
int pkIsText;
|
int pkIsText;
|
||||||
|
|
||||||
|
// True if the hidden command column (named after the table) exists.
|
||||||
|
// Tables created before v0.1.10 or without _info table don't have it.
|
||||||
|
int hasCommandColumn;
|
||||||
|
|
||||||
// number of defined vector columns.
|
// number of defined vector columns.
|
||||||
int numVectorColumns;
|
int numVectorColumns;
|
||||||
|
|
||||||
|
|
@ -3777,20 +3783,19 @@ int vec0_num_defined_user_columns(vec0_vtab *p) {
|
||||||
* @param p vec0 table
|
* @param p vec0 table
|
||||||
* @return int
|
* @return int
|
||||||
*/
|
*/
|
||||||
int vec0_column_distance_idx(vec0_vtab *p) {
|
int vec0_column_command_idx(vec0_vtab *p) {
|
||||||
return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) +
|
// Command column is the first hidden column (right after user columns)
|
||||||
VEC0_COLUMN_OFFSET_DISTANCE;
|
return VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
int vec0_column_distance_idx(vec0_vtab *p) {
|
||||||
|
int base = VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p);
|
||||||
|
return base + (p->hasCommandColumn ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Returns the index of the k hidden column for the given vec0 table.
|
|
||||||
*
|
|
||||||
* @param p vec0 table
|
|
||||||
* @return int k column index
|
|
||||||
*/
|
|
||||||
int vec0_column_k_idx(vec0_vtab *p) {
|
int vec0_column_k_idx(vec0_vtab *p) {
|
||||||
return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) +
|
int base = VEC0_COLUMN_USERN_START + vec0_num_defined_user_columns(p);
|
||||||
VEC0_COLUMN_OFFSET_K;
|
return base + (p->hasCommandColumn ? 2 : 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -5205,6 +5210,74 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine whether to add the FTS5-style hidden command column.
|
||||||
|
// New tables (isCreate) always get it; existing tables only if created
|
||||||
|
// with v0.1.10+ (which validated no column name == table name).
|
||||||
|
int hasCommandColumn = 0;
|
||||||
|
if (isCreate) {
|
||||||
|
// Validate no user column name conflicts with the table name
|
||||||
|
const char *tblName = argv[2];
|
||||||
|
int tblNameLen = (int)strlen(tblName);
|
||||||
|
for (int i = 0; i < numVectorColumns; i++) {
|
||||||
|
if (pNew->vector_columns[i].name_length == tblNameLen &&
|
||||||
|
sqlite3_strnicmp(pNew->vector_columns[i].name, tblName, tblNameLen) == 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"column name '%s' conflicts with table name (reserved for command column)",
|
||||||
|
tblName);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numPartitionColumns; i++) {
|
||||||
|
if (pNew->paritition_columns[i].name_length == tblNameLen &&
|
||||||
|
sqlite3_strnicmp(pNew->paritition_columns[i].name, tblName, tblNameLen) == 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"column name '%s' conflicts with table name (reserved for command column)",
|
||||||
|
tblName);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numAuxiliaryColumns; i++) {
|
||||||
|
if (pNew->auxiliary_columns[i].name_length == tblNameLen &&
|
||||||
|
sqlite3_strnicmp(pNew->auxiliary_columns[i].name, tblName, tblNameLen) == 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"column name '%s' conflicts with table name (reserved for command column)",
|
||||||
|
tblName);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < numMetadataColumns; i++) {
|
||||||
|
if (pNew->metadata_columns[i].name_length == tblNameLen &&
|
||||||
|
sqlite3_strnicmp(pNew->metadata_columns[i].name, tblName, tblNameLen) == 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(
|
||||||
|
VEC_CONSTRUCTOR_ERROR
|
||||||
|
"column name '%s' conflicts with table name (reserved for command column)",
|
||||||
|
tblName);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hasCommandColumn = 1;
|
||||||
|
} else {
|
||||||
|
// xConnect: check _info shadow table for version
|
||||||
|
sqlite3_stmt *stmtInfo = NULL;
|
||||||
|
char *zInfoSql = sqlite3_mprintf(
|
||||||
|
"SELECT value FROM " VEC0_SHADOW_INFO_NAME " WHERE key = 'CREATE_VERSION_PATCH'",
|
||||||
|
argv[1], argv[2]);
|
||||||
|
if (zInfoSql) {
|
||||||
|
int infoRc = sqlite3_prepare_v2(db, zInfoSql, -1, &stmtInfo, NULL);
|
||||||
|
sqlite3_free(zInfoSql);
|
||||||
|
if (infoRc == SQLITE_OK && sqlite3_step(stmtInfo) == SQLITE_ROW) {
|
||||||
|
int patch = sqlite3_column_int(stmtInfo, 0);
|
||||||
|
hasCommandColumn = (patch >= 10); // v0.1.10+
|
||||||
|
}
|
||||||
|
// If _info doesn't exist or has no version, assume old table
|
||||||
|
sqlite3_finalize(stmtInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pNew->hasCommandColumn = hasCommandColumn;
|
||||||
|
|
||||||
sqlite3_str *createStr = sqlite3_str_new(NULL);
|
sqlite3_str *createStr = sqlite3_str_new(NULL);
|
||||||
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
|
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
|
||||||
if (pkColumnName) {
|
if (pkColumnName) {
|
||||||
|
|
@ -5246,7 +5319,11 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
if (hasCommandColumn) {
|
||||||
|
sqlite3_str_appendf(createStr, " \"%w\" hidden, distance hidden, k hidden) ", argv[2]);
|
||||||
|
} else {
|
||||||
sqlite3_str_appendall(createStr, " distance hidden, k hidden) ");
|
sqlite3_str_appendall(createStr, " distance hidden, k hidden) ");
|
||||||
|
}
|
||||||
if (pkColumnName) {
|
if (pkColumnName) {
|
||||||
sqlite3_str_appendall(createStr, "without rowid ");
|
sqlite3_str_appendall(createStr, "without rowid ");
|
||||||
}
|
}
|
||||||
|
|
@ -10161,25 +10238,31 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
||||||
}
|
}
|
||||||
// INSERT operation
|
// INSERT operation
|
||||||
else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) {
|
else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) {
|
||||||
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE || SQLITE_VEC_ENABLE_DISKANN
|
|
||||||
// Check for command inserts: INSERT INTO t(rowid) VALUES ('command-string')
|
|
||||||
// The id column holds the command string.
|
|
||||||
sqlite3_value *idVal = argv[2 + VEC0_COLUMN_ID];
|
|
||||||
if (sqlite3_value_type(idVal) == SQLITE_TEXT) {
|
|
||||||
const char *cmd = (const char *)sqlite3_value_text(idVal);
|
|
||||||
vec0_vtab *p = (vec0_vtab *)pVTab;
|
vec0_vtab *p = (vec0_vtab *)pVTab;
|
||||||
|
// FTS5-style command dispatch via hidden column named after table
|
||||||
|
if (p->hasCommandColumn) {
|
||||||
|
sqlite3_value *cmdVal = argv[2 + vec0_column_command_idx(p)];
|
||||||
|
if (sqlite3_value_type(cmdVal) == SQLITE_TEXT) {
|
||||||
|
const char *cmd = (const char *)sqlite3_value_text(cmdVal);
|
||||||
int cmdRc = SQLITE_EMPTY;
|
int cmdRc = SQLITE_EMPTY;
|
||||||
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
|
cmdRc = rescore_handle_command(p, cmd);
|
||||||
|
#endif
|
||||||
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
|
||||||
|
if (cmdRc == SQLITE_EMPTY)
|
||||||
cmdRc = ivf_handle_command(p, cmd, argc, argv);
|
cmdRc = ivf_handle_command(p, cmd, argc, argv);
|
||||||
#endif
|
#endif
|
||||||
#if SQLITE_VEC_ENABLE_DISKANN
|
#if SQLITE_VEC_ENABLE_DISKANN
|
||||||
if (cmdRc == SQLITE_EMPTY)
|
if (cmdRc == SQLITE_EMPTY)
|
||||||
cmdRc = diskann_handle_command(p, cmd);
|
cmdRc = diskann_handle_command(p, cmd);
|
||||||
#endif
|
#endif
|
||||||
if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error)
|
if (cmdRc == SQLITE_EMPTY) {
|
||||||
// SQLITE_EMPTY means not a recognized command — fall through to normal insert
|
vtab_set_error(pVTab, "unknown vec0 command: '%s'", cmd);
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
return cmdRc;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
return vec0Update_Insert(pVTab, argc, argv, pRowid);
|
return vec0Update_Insert(pVTab, argc, argv, pRowid);
|
||||||
}
|
}
|
||||||
// UPDATE operation
|
// UPDATE operation
|
||||||
|
|
|
||||||
BIN
tests/fixtures/legacy-v0.1.6.db
vendored
Normal file
BIN
tests/fixtures/legacy-v0.1.6.db
vendored
Normal file
Binary file not shown.
|
|
@ -50,7 +50,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
{
|
{
|
||||||
sqlite3_stmt *stmt;
|
sqlite3_stmt *stmt;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
for (int i = 1; i <= 8; i++) {
|
for (int i = 1; i <= 8; i++) {
|
||||||
float vec[8];
|
float vec[8];
|
||||||
for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||||
|
|
@ -66,11 +66,11 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_stmt *stmtKnn = NULL;
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
|
||||||
/* Commands are dispatched via INSERT INTO t(rowid) VALUES ('cmd_string') */
|
/* Commands are dispatched via INSERT INTO t(t) VALUES ('cmd_string') */
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid) VALUES (?)", -1, &stmtCmd, NULL);
|
"INSERT INTO v(v) VALUES (?)", -1, &stmtCmd, NULL);
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
-1, &stmtKnn, NULL);
|
-1, &stmtKnn, NULL);
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert enough vectors to overflow at least one cell
|
// Insert enough vectors to overflow at least one cell
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -81,7 +81,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Train to assign vectors to centroids (triggers cell building)
|
// Train to assign vectors to centroids (triggers cell building)
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Delete vectors at boundary positions based on fuzz data
|
// Delete vectors at boundary positions based on fuzz data
|
||||||
|
|
@ -102,7 +102,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
{
|
{
|
||||||
sqlite3_stmt *si = NULL;
|
sqlite3_stmt *si = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
if (si) {
|
if (si) {
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
|
@ -140,7 +140,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Test assign-vectors with multi-cell state
|
// Test assign-vectors with multi-cell state
|
||||||
// First clear centroids
|
// First clear centroids
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Set centroids manually, then assign
|
// Set centroids manually, then assign
|
||||||
|
|
@ -151,7 +151,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
char cmd[128];
|
char cmd[128];
|
||||||
snprintf(cmd, sizeof(cmd),
|
snprintf(cmd, sizeof(cmd),
|
||||||
"INSERT INTO v(rowid, emb) VALUES ('set-centroid:%d', ?)", c);
|
"INSERT INTO v(v, emb) VALUES ('set-centroid:%d', ?)", c);
|
||||||
sqlite3_stmt *sc = NULL;
|
sqlite3_stmt *sc = NULL;
|
||||||
sqlite3_prepare_v2(db, cmd, -1, &sc, NULL);
|
sqlite3_prepare_v2(db, cmd, -1, &sc, NULL);
|
||||||
if (sc) {
|
if (sc) {
|
||||||
|
|
@ -163,7 +163,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('assign-vectors')",
|
"INSERT INTO v(v) VALUES ('assign-vectors')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Final query after assign-vectors
|
// Final query after assign-vectors
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert vectors
|
// Insert vectors
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -125,14 +125,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Clear centroids and re-compute to test round-trip
|
// Clear centroids and re-compute to test round-trip
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Insert a few more vectors in untrained state
|
// Insert a few more vectors in untrained state
|
||||||
{
|
{
|
||||||
sqlite3_stmt *si = NULL;
|
sqlite3_stmt *si = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
if (si) {
|
if (si) {
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
|
@ -150,7 +150,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Re-train
|
// Re-train
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Delete some rows after training, then query
|
// Delete some rows after training, then query
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert vectors
|
// Insert vectors
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -134,14 +134,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Train
|
// Train
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Change nprobe at runtime (can exceed nlist -- tests clamping in query)
|
// Change nprobe at runtime (can exceed nlist -- tests clamping in query)
|
||||||
{
|
{
|
||||||
char cmd[64];
|
char cmd[64];
|
||||||
snprintf(cmd, sizeof(cmd),
|
snprintf(cmd, sizeof(cmd),
|
||||||
"INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe_initial);
|
"INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe_initial);
|
||||||
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
|
|
@ -82,14 +82,14 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
case 4: {
|
case 4: {
|
||||||
// compute-centroids command
|
// compute-centroids command
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case 5: {
|
case 5: {
|
||||||
// clear-centroids command
|
// clear-centroids command
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -100,7 +100,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
int nprobe = (n % 4) + 1;
|
int nprobe = (n % 4) + 1;
|
||||||
char buf[64];
|
char buf[64];
|
||||||
snprintf(buf, sizeof(buf),
|
snprintf(buf, sizeof(buf),
|
||||||
"INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe);
|
"INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe);
|
||||||
sqlite3_exec(db, buf, NULL, NULL, NULL);
|
sqlite3_exec(db, buf, NULL, NULL, NULL);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert vectors with fuzz-controlled float values
|
// Insert vectors with fuzz-controlled float values
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -93,7 +93,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Trigger compute-centroids to exercise kmeans + quantization together
|
// Trigger compute-centroids to exercise kmeans + quantization together
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// KNN query with fuzz-derived query vector
|
// KNN query with fuzz-derived query vector
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
// Insert vectors with diverse values
|
// Insert vectors with diverse values
|
||||||
sqlite3_stmt *stmtInsert = NULL;
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
@ -103,7 +103,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Train
|
// Train
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Multiple KNN queries to exercise rescore path
|
// Multiple KNN queries to exercise rescore path
|
||||||
|
|
@ -156,7 +156,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Retrain after deletions
|
// Retrain after deletions
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Query after retrain
|
// Query after retrain
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
{
|
{
|
||||||
sqlite3_stmt *si = NULL;
|
sqlite3_stmt *si = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
if (!si) { sqlite3_close(db); return 0; }
|
if (!si) { sqlite3_close(db); return 0; }
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
float vec[8];
|
float vec[8];
|
||||||
|
|
@ -63,7 +63,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// Train
|
// Train
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// Now corrupt shadow tables based on fuzz input
|
// Now corrupt shadow tables based on fuzz input
|
||||||
|
|
@ -204,7 +204,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
||||||
sqlite3_stmt *si = NULL;
|
sqlite3_stmt *si = NULL;
|
||||||
sqlite3_prepare_v2(db,
|
sqlite3_prepare_v2(db,
|
||||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
if (si) {
|
if (si) {
|
||||||
sqlite3_bind_int64(si, 1, 100);
|
sqlite3_bind_int64(si, 1, 100);
|
||||||
sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC);
|
sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC);
|
||||||
|
|
@ -215,12 +215,12 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
|
||||||
// compute-centroids over corrupted state
|
// compute-centroids over corrupted state
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
// clear-centroids
|
// clear-centroids
|
||||||
sqlite3_exec(db,
|
sqlite3_exec(db,
|
||||||
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
NULL, NULL, NULL);
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
sqlite3_close(db);
|
sqlite3_close(db);
|
||||||
|
|
|
||||||
81
tests/generate_legacy_db.py
Normal file
81
tests/generate_legacy_db.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.10"
|
||||||
|
# dependencies = ["sqlite-vec==0.1.6"]
|
||||||
|
# ///
|
||||||
|
"""Generate a legacy sqlite-vec database for backwards-compat testing.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run --script generate_legacy_db.py
|
||||||
|
|
||||||
|
Creates tests/fixtures/legacy-v0.1.6.db with a vec0 table containing
|
||||||
|
test data that can be read by the current version of sqlite-vec.
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
import sqlite_vec
|
||||||
|
import struct
|
||||||
|
import os
|
||||||
|
|
||||||
|
FIXTURE_DIR = os.path.join(os.path.dirname(__file__), "fixtures")
|
||||||
|
DB_PATH = os.path.join(FIXTURE_DIR, "legacy-v0.1.6.db")
|
||||||
|
|
||||||
|
DIMS = 4
|
||||||
|
N_ROWS = 50
|
||||||
|
|
||||||
|
|
||||||
|
def _f32(vals):
|
||||||
|
return struct.pack(f"{len(vals)}f", *vals)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
os.makedirs(FIXTURE_DIR, exist_ok=True)
|
||||||
|
if os.path.exists(DB_PATH):
|
||||||
|
os.remove(DB_PATH)
|
||||||
|
|
||||||
|
db = sqlite3.connect(DB_PATH)
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
sqlite_vec.load(db)
|
||||||
|
|
||||||
|
# Print version for verification
|
||||||
|
version = db.execute("SELECT vec_version()").fetchone()[0]
|
||||||
|
print(f"sqlite-vec version: {version}")
|
||||||
|
|
||||||
|
# Create a basic vec0 table — flat index, no fancy features
|
||||||
|
db.execute(f"CREATE VIRTUAL TABLE legacy_vectors USING vec0(emb float[{DIMS}])")
|
||||||
|
|
||||||
|
# Insert test data: vectors where element[0] == rowid for easy verification
|
||||||
|
for i in range(1, N_ROWS + 1):
|
||||||
|
vec = [float(i), 0.0, 0.0, 0.0]
|
||||||
|
db.execute("INSERT INTO legacy_vectors(rowid, emb) VALUES (?, ?)", [i, _f32(vec)])
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
count = db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
print(f"Inserted {count} rows")
|
||||||
|
|
||||||
|
# Test KNN works
|
||||||
|
query = _f32([1.0, 0.0, 0.0, 0.0])
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid, distance FROM legacy_vectors WHERE emb MATCH ? AND k = 5",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
print(f"KNN top 5: {[(r[0], round(r[1], 4)) for r in rows]}")
|
||||||
|
assert rows[0][0] == 1 # closest to [1,0,0,0]
|
||||||
|
assert len(rows) == 5
|
||||||
|
|
||||||
|
# Also create a table with name == column name (the conflict case)
|
||||||
|
# This was allowed in old versions — new code must not break on reconnect
|
||||||
|
db.execute("CREATE VIRTUAL TABLE emb USING vec0(emb float[4])")
|
||||||
|
for i in range(1, 11):
|
||||||
|
db.execute("INSERT INTO emb(rowid, emb) VALUES (?, ?)", [i, _f32([float(i), 0, 0, 0])])
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
count2 = db.execute("SELECT count(*) FROM emb").fetchone()[0]
|
||||||
|
print(f"Table 'emb' with column 'emb': {count2} rows (name conflict case)")
|
||||||
|
|
||||||
|
db.close()
|
||||||
|
print(f"\nGenerated: {DB_PATH}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -589,7 +589,7 @@ def test_diskann_command_search_list_size(db):
|
||||||
assert len(results_before) == 5
|
assert len(results_before) == 5
|
||||||
|
|
||||||
# Override search_list_size_search at runtime
|
# Override search_list_size_search at runtime
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_search=256')")
|
db.execute("INSERT INTO t(t) VALUES ('search_list_size_search=256')")
|
||||||
|
|
||||||
# Query should still work
|
# Query should still work
|
||||||
results_after = db.execute(
|
results_after = db.execute(
|
||||||
|
|
@ -598,14 +598,14 @@ def test_diskann_command_search_list_size(db):
|
||||||
assert len(results_after) == 5
|
assert len(results_after) == 5
|
||||||
|
|
||||||
# Override search_list_size_insert at runtime
|
# Override search_list_size_insert at runtime
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_insert=32')")
|
db.execute("INSERT INTO t(t) VALUES ('search_list_size_insert=32')")
|
||||||
|
|
||||||
# Inserts should still work
|
# Inserts should still work
|
||||||
vec = struct.pack("64f", *[random.random() for _ in range(64)])
|
vec = struct.pack("64f", *[random.random() for _ in range(64)])
|
||||||
db.execute("INSERT INTO t(emb) VALUES (?)", [vec])
|
db.execute("INSERT INTO t(emb) VALUES (?)", [vec])
|
||||||
|
|
||||||
# Override unified search_list_size
|
# Override unified search_list_size
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('search_list_size=64')")
|
db.execute("INSERT INTO t(t) VALUES ('search_list_size=64')")
|
||||||
|
|
||||||
results_final = db.execute(
|
results_final = db.execute(
|
||||||
"SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query]
|
"SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query]
|
||||||
|
|
@ -620,9 +620,9 @@ def test_diskann_command_search_list_size_error(db):
|
||||||
emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
|
emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=0')")
|
result = exec(db, "INSERT INTO t(t) VALUES ('search_list_size=0')")
|
||||||
assert "error" in result
|
assert "error" in result
|
||||||
result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=-1')")
|
result = exec(db, "INSERT INTO t(t) VALUES ('search_list_size=-1')")
|
||||||
assert "error" in result
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,3 +27,15 @@ def test_info(db, snapshot):
|
||||||
assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot()
|
assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot()
|
||||||
|
|
||||||
|
|
||||||
|
def test_command_column_name_conflict(db):
|
||||||
|
"""Table name matching a column name should error (command column conflict)."""
|
||||||
|
# This would conflict: hidden command column 'embeddings' vs vector column 'embeddings'
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="conflicts with table name"):
|
||||||
|
db.execute(
|
||||||
|
"create virtual table embeddings using vec0(embeddings float[4])"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Different names should work fine
|
||||||
|
db.execute("create virtual table t using vec0(embeddings float[4])")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ def test_batch_insert_knn_recall(db):
|
||||||
)
|
)
|
||||||
assert ivf_total_vectors(db) == 200
|
assert ivf_total_vectors(db) == 200
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert ivf_assigned_count(db) == 200
|
assert ivf_assigned_count(db) == 200
|
||||||
|
|
||||||
# Query near 100 -- closest should be rowid 100
|
# Query near 100 -- closest should be rowid 100
|
||||||
|
|
@ -107,7 +107,7 @@ def test_delete_rows_gone_from_knn(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Delete rowid 10
|
# Delete rowid 10
|
||||||
db.execute("DELETE FROM t WHERE rowid = 10")
|
db.execute("DELETE FROM t WHERE rowid = 10")
|
||||||
|
|
@ -127,7 +127,7 @@ def test_delete_all_rows_empty_results(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
||||||
|
|
@ -152,7 +152,7 @@ def test_insert_after_delete_reuse_rowid(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Delete rowid 5
|
# Delete rowid 5
|
||||||
db.execute("DELETE FROM t WHERE rowid = 5")
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
|
@ -184,7 +184,7 @@ def test_update_vector_via_delete_insert(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# "Update" rowid 3: delete and re-insert with new vector
|
# "Update" rowid 3: delete and re-insert with new vector
|
||||||
db.execute("DELETE FROM t WHERE rowid = 3")
|
db.execute("DELETE FROM t WHERE rowid = 3")
|
||||||
|
|
@ -316,7 +316,7 @@ def test_single_row_compute_centroids(db):
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])]
|
"INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])]
|
||||||
)
|
)
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert ivf_assigned_count(db) == 1
|
assert ivf_assigned_count(db) == 1
|
||||||
|
|
||||||
results = knn(db, [1, 2, 3, 4], 1)
|
results = knn(db, [1, 2, 3, 4], 1)
|
||||||
|
|
@ -343,10 +343,10 @@ def test_cell_overflow_many_vectors(db):
|
||||||
|
|
||||||
# Set a single centroid so all vectors go there
|
# Set a single centroid so all vectors go there
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
|
"INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)",
|
||||||
[_f32([1.0, 0, 0, 0])],
|
[_f32([1.0, 0, 0, 0])],
|
||||||
)
|
)
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
|
db.execute("INSERT INTO t(t) VALUES ('assign-vectors')")
|
||||||
|
|
||||||
assert ivf_assigned_count(db) == 100
|
assert ivf_assigned_count(db) == 100
|
||||||
|
|
||||||
|
|
@ -377,7 +377,7 @@ def test_large_batch_with_training(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
for i in range(500, 1000):
|
for i in range(500, 1000):
|
||||||
db.execute(
|
db.execute(
|
||||||
|
|
@ -409,7 +409,7 @@ def test_knn_after_interleaved_insert_delete(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Delete rowids 0-9 (closest to query at 5.0)
|
# Delete rowids 0-9 (closest to query at 5.0)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
|
|
@ -434,7 +434,7 @@ def test_knn_empty_centroids_after_deletes(db):
|
||||||
[i, _f32([float(i % 10) * 10, 0, 0, 0])],
|
[i, _f32([float(i % 10) * 10, 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Delete a bunch, potentially emptying some centroids
|
# Delete a bunch, potentially emptying some centroids
|
||||||
for i in range(30):
|
for i in range(30):
|
||||||
|
|
@ -458,7 +458,7 @@ def test_knn_correct_distances(db):
|
||||||
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([3, 0, 0, 0])])
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([3, 0, 0, 0])])
|
||||||
db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([0, 4, 0, 0])])
|
db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([0, 4, 0, 0])])
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
results = knn(db, [0, 0, 0, 0], 3)
|
results = knn(db, [0, 0, 0, 0], 3)
|
||||||
result_map = {r[0]: r[1] for r in results}
|
result_map = {r[0]: r[1] for r in results}
|
||||||
|
|
@ -547,7 +547,7 @@ def test_interleaved_ops_correctness(db):
|
||||||
[i, _f32([float(i), 0, 0, 0])],
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Phase 2: Delete even-numbered rowids
|
# Phase 2: Delete even-numbered rowids
|
||||||
for i in range(0, 50, 2):
|
for i in range(0, 50, 2):
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,7 @@ def test_ivf_int8_insert_and_query(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# Should be able to query
|
# Should be able to query
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
|
|
@ -151,7 +151,7 @@ def test_ivf_binary_insert_and_query(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
|
@ -221,10 +221,10 @@ def test_ivf_int8_oversample_improves_recall(db):
|
||||||
db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v])
|
db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v])
|
||||||
db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v])
|
db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v])
|
||||||
|
|
||||||
db.execute("INSERT INTO t1(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t1(t1) VALUES ('compute-centroids')")
|
||||||
db.execute("INSERT INTO t2(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t2(t2) VALUES ('compute-centroids')")
|
||||||
db.execute("INSERT INTO t1(rowid) VALUES ('nprobe=4')")
|
db.execute("INSERT INTO t1(t1) VALUES ('nprobe=4')")
|
||||||
db.execute("INSERT INTO t2(rowid) VALUES ('nprobe=4')")
|
db.execute("INSERT INTO t2(t2) VALUES ('nprobe=4')")
|
||||||
|
|
||||||
query = _f32([5.0, 1.5, 2.5, 0.5])
|
query = _f32([5.0, 1.5, 2.5, 0.5])
|
||||||
r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall()
|
r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall()
|
||||||
|
|
@ -247,7 +247,7 @@ def test_ivf_quantized_delete(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10
|
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10
|
||||||
|
|
||||||
db.execute("DELETE FROM t WHERE rowid = 5")
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
|
|
||||||
|
|
@ -217,7 +217,7 @@ def test_compute_centroids(db):
|
||||||
|
|
||||||
assert ivf_unassigned_count(db) == 40
|
assert ivf_unassigned_count(db) == 40
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
# After training: unassigned cell should be gone (or empty), vectors in trained cells
|
# After training: unassigned cell should be gone (or empty), vectors in trained cells
|
||||||
assert ivf_unassigned_count(db) == 0
|
assert ivf_unassigned_count(db) == 0
|
||||||
|
|
@ -238,10 +238,10 @@ def test_compute_centroids_recompute(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
assert ivf_assigned_count(db) == 20
|
assert ivf_assigned_count(db) == 20
|
||||||
|
|
||||||
|
|
@ -260,7 +260,7 @@ def test_ivf_insert_after_training(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])]
|
||||||
|
|
@ -290,7 +290,7 @@ def test_ivf_knn_after_training(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
|
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
|
@ -310,7 +310,7 @@ def test_ivf_knn_k_larger_than_n(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
"SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
|
||||||
|
|
@ -334,17 +334,17 @@ def test_set_centroid_and_assign(db):
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
|
"INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)",
|
||||||
[_f32([5, 0, 0, 0])],
|
[_f32([5, 0, 0, 0])],
|
||||||
)
|
)
|
||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO t(rowid, v) VALUES ('set-centroid:1', ?)",
|
"INSERT INTO t(t, v) VALUES ('set-centroid:1', ?)",
|
||||||
[_f32([15, 0, 0, 0])],
|
[_f32([15, 0, 0, 0])],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
|
db.execute("INSERT INTO t(t) VALUES ('assign-vectors')")
|
||||||
|
|
||||||
assert ivf_unassigned_count(db) == 0
|
assert ivf_unassigned_count(db) == 0
|
||||||
assert ivf_assigned_count(db) == 20
|
assert ivf_assigned_count(db) == 20
|
||||||
|
|
@ -364,10 +364,10 @@ def test_clear_centroids(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('clear-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('clear-centroids')")
|
||||||
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0
|
||||||
assert ivf_unassigned_count(db) == 20
|
assert ivf_unassigned_count(db) == 20
|
||||||
trained = db.execute(
|
trained = db.execute(
|
||||||
|
|
@ -390,7 +390,7 @@ def test_ivf_delete_after_training(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
assert ivf_assigned_count(db) == 10
|
assert ivf_assigned_count(db) == 10
|
||||||
|
|
||||||
db.execute("DELETE FROM t WHERE rowid = 5")
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
|
@ -412,7 +412,7 @@ def test_ivf_recall_nprobe_equals_nlist(db):
|
||||||
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
rows = db.execute(
|
rows = db.execute(
|
||||||
"SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
|
||||||
|
|
|
||||||
138
tests/test-legacy-compat.py
Normal file
138
tests/test-legacy-compat.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
"""Backwards compatibility tests: current sqlite-vec reading legacy databases.
|
||||||
|
|
||||||
|
The fixture file tests/fixtures/legacy-v0.1.6.db was generated by
|
||||||
|
tests/generate_legacy_db.py using sqlite-vec v0.1.6. These tests verify
|
||||||
|
that the current version can fully read, query, insert into, and delete
|
||||||
|
from tables created by older versions.
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures", "legacy-v0.1.6.db")
|
||||||
|
|
||||||
|
|
||||||
|
def _f32(vals):
|
||||||
|
return struct.pack(f"{len(vals)}f", *vals)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def legacy_db(tmp_path):
|
||||||
|
"""Copy the legacy fixture to a temp dir so tests can modify it."""
|
||||||
|
if not os.path.exists(FIXTURE_PATH):
|
||||||
|
pytest.skip("Legacy fixture not found — run: uv run --script tests/generate_legacy_db.py")
|
||||||
|
db_path = str(tmp_path / "legacy.db")
|
||||||
|
shutil.copy2(FIXTURE_PATH, db_path)
|
||||||
|
db = sqlite3.connect(db_path)
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension("dist/vec0")
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_select_count(legacy_db):
|
||||||
|
"""Basic SELECT count should return all rows."""
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_point_query(legacy_db):
|
||||||
|
"""Point query by rowid should return correct vector."""
|
||||||
|
row = legacy_db.execute(
|
||||||
|
"SELECT rowid, emb FROM legacy_vectors WHERE rowid = 1"
|
||||||
|
).fetchone()
|
||||||
|
assert row["rowid"] == 1
|
||||||
|
vec = struct.unpack("4f", row["emb"])
|
||||||
|
assert vec[0] == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_knn(legacy_db):
|
||||||
|
"""KNN query on legacy table should return correct results."""
|
||||||
|
query = _f32([1.0, 0.0, 0.0, 0.0])
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid, distance FROM legacy_vectors "
|
||||||
|
"WHERE emb MATCH ? AND k = 5",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
assert rows[0]["rowid"] == 1
|
||||||
|
assert rows[0]["distance"] == pytest.approx(0.0)
|
||||||
|
for i in range(len(rows) - 1):
|
||||||
|
assert rows[i]["distance"] <= rows[i + 1]["distance"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_insert(legacy_db):
|
||||||
|
"""INSERT into legacy table should work."""
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO legacy_vectors(rowid, emb) VALUES (100, ?)",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
)
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 51
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 1",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert rows[0]["rowid"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_delete(legacy_db):
|
||||||
|
"""DELETE from legacy table should work."""
|
||||||
|
legacy_db.execute("DELETE FROM legacy_vectors WHERE rowid = 1")
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 49
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 5",
|
||||||
|
[_f32([1.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert 1 not in [r["rowid"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_fullscan(legacy_db):
|
||||||
|
"""Full scan should work."""
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors ORDER BY rowid LIMIT 5"
|
||||||
|
).fetchall()
|
||||||
|
assert [r["rowid"] for r in rows] == [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_name_conflict_table(legacy_db):
|
||||||
|
"""Legacy table where column name == table name should work.
|
||||||
|
|
||||||
|
The v0.1.6 DB has: CREATE VIRTUAL TABLE emb USING vec0(emb float[4])
|
||||||
|
Current code should NOT add the command column for this table
|
||||||
|
(detected via _info version check), avoiding the name conflict.
|
||||||
|
"""
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0]
|
||||||
|
assert count == 10
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid, distance FROM emb WHERE emb MATCH ? AND k = 3",
|
||||||
|
[_f32([1.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 3
|
||||||
|
assert rows[0]["rowid"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_name_conflict_insert_delete(legacy_db):
|
||||||
|
"""INSERT and DELETE on legacy name-conflict table."""
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO emb(rowid, emb) VALUES (100, ?)",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
)
|
||||||
|
assert legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] == 11
|
||||||
|
|
||||||
|
legacy_db.execute("DELETE FROM emb WHERE rowid = 5")
|
||||||
|
assert legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_no_command_column(legacy_db):
|
||||||
|
"""Legacy tables should NOT have the command column."""
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO legacy_vectors(legacy_vectors) VALUES ('some_command')"
|
||||||
|
)
|
||||||
|
|
@ -655,3 +655,73 @@ def test_rescore_text_pk_insert_knn_delete(db):
|
||||||
ids = [r["id"] for r in rows]
|
ids = [r["id"] for r in rows]
|
||||||
assert "alpha" not in ids
|
assert "alpha" not in ids
|
||||||
assert len(rows) >= 1 # other results still returned
|
assert len(rows) >= 1 # other results still returned
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_oversample(db):
|
||||||
|
"""oversample can be changed at query time via FTS5-style command."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[128] indexed by rescore(quantizer=bit, oversample=2)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
random.seed(200)
|
||||||
|
for i in range(200):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, embedding) VALUES (?, ?)",
|
||||||
|
[i + 1, float_vec([random.gauss(0, 1) for _ in range(128)])],
|
||||||
|
)
|
||||||
|
|
||||||
|
query = float_vec([random.gauss(0, 1) for _ in range(128)])
|
||||||
|
|
||||||
|
# KNN with default oversample=2 (low)
|
||||||
|
rows_low = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows_low) == 10
|
||||||
|
|
||||||
|
# Change oversample at runtime to high value
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('oversample=32')")
|
||||||
|
|
||||||
|
# KNN with oversample=32 (high) — same or better recall
|
||||||
|
rows_high = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows_high) == 10
|
||||||
|
|
||||||
|
# Reset to original
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('oversample=2')")
|
||||||
|
|
||||||
|
rows_reset = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows_reset) == 10
|
||||||
|
# After reset, should match the original low-oversample results
|
||||||
|
assert [r["rowid"] for r in rows_reset] == [r["rowid"] for r in rows_low]
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_oversample_error(db):
|
||||||
|
"""Invalid oversample values should error."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"):
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('oversample=0')")
|
||||||
|
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="oversample must be >= 1"):
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('oversample=-5')")
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_command_errors(db):
|
||||||
|
"""Unknown command strings should produce a clear error."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" embedding float[128] indexed by rescore(quantizer=bit)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="unknown vec0 command"):
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('not_a_real_command')")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue