diff --git a/Makefile b/Makefile index b50751b..2758ee5 100644 --- a/Makefile +++ b/Makefile @@ -206,6 +206,21 @@ test-loadable-watch: test-unit: $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit +# Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking, +# profiling (has debug symbols), and scripting without .load_extension. +# make cli +# dist/sqlite3 :memory: "SELECT vec_version()" +# dist/sqlite3 < script.sql +cli: sqlite-vec.h $(prefix) + $(CC) -O2 -g \ + -DSQLITE_CORE \ + -DSQLITE_EXTRA_INIT=core_init \ + -DSQLITE_THREADSAFE=0 \ + -Ivendor/ -I./ \ + $(CFLAGS) \ + vendor/sqlite3.c vendor/shell.c sqlite-vec.c examples/sqlite3-cli/core_init.c \ + -ldl -lm -o $(prefix)/sqlite3 + fuzz-build: $(MAKE) -C tests/fuzz all diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile index 0789d38..6081457 100644 --- a/benchmarks-ann/Makefile +++ b/benchmarks-ann/Makefile @@ -8,27 +8,20 @@ BASELINES = \ "brute-int8:type=baseline,variant=int8" \ "brute-bit:type=baseline,variant=bit" -# --- Index-specific configs --- -# Each index branch should add its own configs here. Example: -# -# DISKANN_CONFIGS = \ -# "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \ -# "diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8" -# -# IVF_CONFIGS = \ -# "ivf-n128-p16:type=ivf,nlist=128,nprobe=16" -# -# ANNOY_CONFIGS = \ -# "annoy-t50:type=annoy,n_trees=50" +# --- IVF configs --- +IVF_CONFIGS = \ + "ivf-n32-p8:type=ivf,nlist=32,nprobe=8" \ + "ivf-n128-p16:type=ivf,nlist=128,nprobe=16" \ + "ivf-n512-p32:type=ivf,nlist=512,nprobe=32" RESCORE_CONFIGS = \ "rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" \ "rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \ "rescore-int8-os8:type=rescore,quantizer=int8,oversample=8" -ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) +ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) -.PHONY: seed ground-truth bench-smoke bench-rescore bench-10k bench-50k bench-100k bench-all \ +.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-10k bench-50k bench-100k bench-all \ report clean # --- Data preparation --- @@ -43,7 +36,8 @@ ground-truth: seed # --- Quick smoke test --- bench-smoke: seed $(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \ - $(BASELINES) + "brute-float:type=baseline,variant=float" \ + "ivf-quick:type=ivf,nlist=16,nprobe=4" bench-rescore: seed $(BENCH) --subset-size 10000 -k 10 -o runs/rescore \ @@ -62,6 +56,12 @@ bench-100k: seed bench-all: bench-10k bench-50k bench-100k +# --- IVF across sizes --- +bench-ivf: seed + $(BENCH) --subset-size 10000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS) + $(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS) + $(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS) + # --- Report --- report: @echo "Use: sqlite3 runs//results.db 'SELECT * FROM bench_results ORDER BY recall DESC'" diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py index c1179d6..c640628 100644 --- a/benchmarks-ann/bench.py +++ b/benchmarks-ann/bench.py @@ -173,6 +173,48 @@ INDEX_REGISTRY["rescore"] = { } +# ============================================================================ +# IVF implementation +# ============================================================================ + + +def _ivf_create_table_sql(params): + return ( + f"CREATE VIRTUAL TABLE vec_items USING vec0(" + f" id integer primary key," + f" embedding float[768] distance_metric=cosine" + f" indexed by ivf(" + f" nlist={params['nlist']}," + f" nprobe={params['nprobe']}" + f" )" + f")" + ) + + +def _ivf_post_insert_hook(conn, params): + print(" Training k-means centroids...", flush=True) + t0 = time.perf_counter() + conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')") + conn.commit() + elapsed = time.perf_counter() - t0 + print(f" Training done in {elapsed:.1f}s", flush=True) + return elapsed + + +def _ivf_describe(params): + return f"ivf nlist={params['nlist']:<4} nprobe={params['nprobe']}" + + +INDEX_REGISTRY["ivf"] = { + "defaults": {"nlist": 128, "nprobe": 16}, + "create_table_sql": _ivf_create_table_sql, + "insert_sql": None, + "post_insert_hook": _ivf_post_insert_hook, + "run_query": None, + "describe": _ivf_describe, +} + + # ============================================================================ # Config parsing # ============================================================================ diff --git a/sqlite-vec-ivf-kmeans.c b/sqlite-vec-ivf-kmeans.c new file mode 100644 index 0000000..0faa803 --- /dev/null +++ b/sqlite-vec-ivf-kmeans.c @@ -0,0 +1,214 @@ +/** + * sqlite-vec-ivf-kmeans.c — Pure k-means clustering algorithm. + * + * No SQLite dependency. Operates on float arrays in memory. + * #include'd into sqlite-vec.c after struct definitions. + */ + +#ifndef SQLITE_VEC_IVF_KMEANS_C +#define SQLITE_VEC_IVF_KMEANS_C + +// When opened standalone in an editor, pull in types so the LSP is happy. +// When #include'd from sqlite-vec.c, SQLITE_VEC_H is already defined. +#ifndef SQLITE_VEC_H +#include "sqlite-vec.c" // IWYU pragma: keep +#endif + +#include +#include + +#define VEC0_IVF_KMEANS_MAX_ITER 25 +#define VEC0_IVF_KMEANS_DEFAULT_SEED 0 + +// Simple xorshift32 PRNG +static uint32_t ivf_xorshift32(uint32_t *state) { + uint32_t x = *state; + x ^= x << 13; + x ^= x >> 17; + x ^= x << 5; + *state = x; + return x; +} + +// L2 squared distance between two float vectors +static float ivf_l2_dist(const float *a, const float *b, int D) { + float sum = 0.0f; + for (int d = 0; d < D; d++) { + float diff = a[d] - b[d]; + sum += diff * diff; + } + return sum; +} + +// Find nearest centroid for a single vector. Returns centroid index. +static int ivf_nearest_centroid(const float *vec, const float *centroids, + int D, int k) { + float min_dist = FLT_MAX; + int best = 0; + for (int c = 0; c < k; c++) { + float dist = ivf_l2_dist(vec, ¢roids[c * D], D); + if (dist < min_dist) { + min_dist = dist; + best = c; + } + } + return best; +} + +/** + * K-means++ initialization. + * Picks k initial centroids from the data with probability proportional + * to squared distance from nearest existing centroid. + */ +static int ivf_kmeans_init_plusplus(const float *vectors, int N, int D, + int k, uint32_t seed, float *centroids) { + if (N <= 0 || k <= 0 || D <= 0) + return -1; + if (seed == 0) + seed = 42; + + // Pick first centroid randomly + int first = ivf_xorshift32(&seed) % N; + memcpy(centroids, &vectors[first * D], D * sizeof(float)); + + if (k == 1) + return 0; + + // Allocate distance array + float *dists = sqlite3_malloc64((i64)N * sizeof(float)); + if (!dists) + return -1; + + for (int c = 1; c < k; c++) { + // Compute D(x) = distance to nearest existing centroid + double total = 0.0; + for (int i = 0; i < N; i++) { + float d = ivf_l2_dist(&vectors[i * D], ¢roids[(c - 1) * D], D); + if (c == 1 || d < dists[i]) { + dists[i] = d; + } + total += dists[i]; + } + + // Weighted random selection + if (total <= 0.0) { + // All distances zero — pick randomly + int pick = ivf_xorshift32(&seed) % N; + memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float)); + } else { + double threshold = ((double)ivf_xorshift32(&seed) / (double)0xFFFFFFFF) * total; + double cumulative = 0.0; + int pick = N - 1; + for (int i = 0; i < N; i++) { + cumulative += dists[i]; + if (cumulative >= threshold) { + pick = i; + break; + } + } + memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float)); + } + } + + sqlite3_free(dists); + return 0; +} + +/** + * Lloyd's k-means algorithm. + * + * @param vectors N*D float array (row-major) + * @param N number of vectors + * @param D dimensionality + * @param k number of clusters + * @param max_iter maximum iterations + * @param seed PRNG seed for initialization + * @param out_centroids output: k*D float array (caller-allocated) + * @return 0 on success, -1 on error + */ +static int ivf_kmeans(const float *vectors, int N, int D, int k, + int max_iter, uint32_t seed, float *out_centroids) { + if (N <= 0 || D <= 0 || k <= 0) + return -1; + + // Clamp k to N + if (k > N) + k = N; + + // Allocate working memory + int *assignments = sqlite3_malloc64((i64)N * sizeof(int)); + float *new_centroids = sqlite3_malloc64((i64)k * D * sizeof(float)); + int *counts = sqlite3_malloc64((i64)k * sizeof(int)); + + if (!assignments || !new_centroids || !counts) { + sqlite3_free(assignments); + sqlite3_free(new_centroids); + sqlite3_free(counts); + return -1; + } + + memset(assignments, -1, N * sizeof(int)); + + // Initialize centroids via k-means++ + if (ivf_kmeans_init_plusplus(vectors, N, D, k, seed, out_centroids) != 0) { + sqlite3_free(assignments); + sqlite3_free(new_centroids); + sqlite3_free(counts); + return -1; + } + + for (int iter = 0; iter < max_iter; iter++) { + // Assignment step + int changed = 0; + for (int i = 0; i < N; i++) { + int nearest = ivf_nearest_centroid(&vectors[i * D], out_centroids, D, k); + if (nearest != assignments[i]) { + assignments[i] = nearest; + changed++; + } + } + if (changed == 0) + break; + + // Update step + memset(new_centroids, 0, (size_t)k * D * sizeof(float)); + memset(counts, 0, k * sizeof(int)); + + for (int i = 0; i < N; i++) { + int c = assignments[i]; + counts[c]++; + for (int d = 0; d < D; d++) { + new_centroids[c * D + d] += vectors[i * D + d]; + } + } + + for (int c = 0; c < k; c++) { + if (counts[c] == 0) { + // Empty cluster: reassign to farthest point from its nearest centroid + float max_dist = -1.0f; + int farthest = 0; + for (int i = 0; i < N; i++) { + float d = ivf_l2_dist(&vectors[i * D], + &out_centroids[assignments[i] * D], D); + if (d > max_dist) { + max_dist = d; + farthest = i; + } + } + memcpy(&out_centroids[c * D], &vectors[farthest * D], + D * sizeof(float)); + } else { + for (int d = 0; d < D; d++) { + out_centroids[c * D + d] = new_centroids[c * D + d] / counts[c]; + } + } + } + } + + sqlite3_free(assignments); + sqlite3_free(new_centroids); + sqlite3_free(counts); + return 0; +} + +#endif /* SQLITE_VEC_IVF_KMEANS_C */ diff --git a/sqlite-vec-ivf.c b/sqlite-vec-ivf.c new file mode 100644 index 0000000..5bc8edb --- /dev/null +++ b/sqlite-vec-ivf.c @@ -0,0 +1,1445 @@ +/** + * sqlite-vec-ivf.c — IVF (Inverted File Index) for sqlite-vec. + * + * #include'd into sqlite-vec.c after struct definitions and before vec0_init(). + * + * Storage: fixed-size packed blob cells (capped at IVF_CELL_MAX_VECTORS). + * Multiple cell rows per centroid. cell_id is auto-increment rowid, + * centroid_id is indexed for lookup. This keeps blobs small (~200KB) + * and avoids expensive overflow page traversal on insert. + */ + +#ifndef SQLITE_VEC_IVF_C +#define SQLITE_VEC_IVF_C + +#ifdef SQLITE_VEC_TEST +#define IVF_STATIC +#else +#define IVF_STATIC static +#endif + +// When opened standalone in an editor, pull in sqlite-vec.c so the LSP +// can resolve all types (vec0_vtab, VectorColumnDefinition, etc.). +// When #include'd from sqlite-vec.c, SQLITE_VEC_H is already defined. +#ifndef SQLITE_VEC_H +#include "sqlite-vec.c" // IWYU pragma: keep +#endif + +#define VEC0_IVF_DEFAULT_NLIST 128 +#define VEC0_IVF_DEFAULT_NPROBE 10 +#define VEC0_IVF_MAX_NLIST 65536 +#define VEC0_IVF_CELL_MAX_VECTORS 64 // ~200KB per cell at 768-dim f32 +#define VEC0_IVF_UNASSIGNED_CENTROID_ID (-1) + +#define VEC0_SHADOW_IVF_CENTROIDS_NAME "\"%w\".\"%w_ivf_centroids%02d\"" +#define VEC0_SHADOW_IVF_CELLS_NAME "\"%w\".\"%w_ivf_cells%02d\"" +#define VEC0_SHADOW_IVF_ROWID_MAP_NAME "\"%w\".\"%w_ivf_rowid_map%02d\"" +#define VEC0_SHADOW_IVF_VECTORS_NAME "\"%w\".\"%w_ivf_vectors%02d\"" + +// ============================================================================ +// Parser +// ============================================================================ + +static int vec0_parse_ivf_options(struct Vec0Scanner *scanner, + struct Vec0IvfConfig *config) { + struct Vec0Token token; + int rc; + config->nlist = VEC0_IVF_DEFAULT_NLIST; + config->nprobe = -1; + config->quantizer = VEC0_IVF_QUANTIZER_NONE; + config->oversample = 1; + int nprobe_explicit = 0; + + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_LPAREN) + return SQLITE_ERROR; + + rc = vec0_scanner_next(scanner, &token); + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) { + config->nprobe = VEC0_IVF_DEFAULT_NPROBE; + return SQLITE_OK; + } + + while (1) { + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_IDENTIFIER) + return SQLITE_ERROR; + char *key = token.start; + int keyLength = token.end - token.start; + + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ) + return SQLITE_ERROR; + + // Read value — can be digit or identifier + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME) return SQLITE_ERROR; + if (token.token_type != TOKEN_TYPE_DIGIT && + token.token_type != TOKEN_TYPE_IDENTIFIER) + return SQLITE_ERROR; + + char *val = token.start; + int valLength = token.end - token.start; + + if (sqlite3_strnicmp(key, "nlist", keyLength) == 0) { + if (token.token_type != TOKEN_TYPE_DIGIT) return SQLITE_ERROR; + int v = atoi(val); + if (v < 0 || v > VEC0_IVF_MAX_NLIST) return SQLITE_ERROR; + config->nlist = v; + } else if (sqlite3_strnicmp(key, "nprobe", keyLength) == 0) { + if (token.token_type != TOKEN_TYPE_DIGIT) return SQLITE_ERROR; + int v = atoi(val); + if (v < 1 || v > VEC0_IVF_MAX_NLIST) return SQLITE_ERROR; + config->nprobe = v; + nprobe_explicit = 1; + } else if (sqlite3_strnicmp(key, "quantizer", keyLength) == 0) { + if (token.token_type != TOKEN_TYPE_IDENTIFIER) return SQLITE_ERROR; + if (sqlite3_strnicmp(val, "none", valLength) == 0) { + config->quantizer = VEC0_IVF_QUANTIZER_NONE; + } else if (sqlite3_strnicmp(val, "int8", valLength) == 0) { + config->quantizer = VEC0_IVF_QUANTIZER_INT8; + } else if (sqlite3_strnicmp(val, "binary", valLength) == 0) { + config->quantizer = VEC0_IVF_QUANTIZER_BINARY; + } else { + return SQLITE_ERROR; + } + } else if (sqlite3_strnicmp(key, "oversample", keyLength) == 0) { + if (token.token_type != TOKEN_TYPE_DIGIT) return SQLITE_ERROR; + int v = atoi(val); + if (v < 1) return SQLITE_ERROR; + config->oversample = v; + } else { + return SQLITE_ERROR; + } + + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME) return SQLITE_ERROR; + if (token.token_type == TOKEN_TYPE_RPAREN) break; + if (token.token_type != TOKEN_TYPE_COMMA) return SQLITE_ERROR; + rc = vec0_scanner_next(scanner, &token); + } + + if (config->nprobe < 0) config->nprobe = VEC0_IVF_DEFAULT_NPROBE; + if (config->nlist > 0 && config->nprobe > config->nlist) { + if (nprobe_explicit) return SQLITE_ERROR; + config->nprobe = config->nlist; + } + + // Validation: oversample > 1 only makes sense with quantization + if (config->oversample > 1 && config->quantizer == VEC0_IVF_QUANTIZER_NONE) { + return SQLITE_ERROR; + } + + return SQLITE_OK; +} + +// ============================================================================ +// Helpers +// ============================================================================ + +/** + * Size of a stored vector in bytes, accounting for quantization. + */ +static int ivf_vec_size(vec0_vtab *p, int col_idx) { + int D = (int)p->vector_columns[col_idx].dimensions; + switch (p->vector_columns[col_idx].ivf.quantizer) { + case VEC0_IVF_QUANTIZER_INT8: return D; + case VEC0_IVF_QUANTIZER_BINARY: return D / 8; + default: return D * (int)sizeof(float); + } +} + +/** + * Size of the full-precision vector in bytes (always float32). + */ +static int ivf_full_vec_size(vec0_vtab *p, int col_idx) { + return (int)(p->vector_columns[col_idx].dimensions * sizeof(float)); +} + +/** + * Quantize float32 vector to int8. + * Uses unit normalization: clamp to [-1,1], scale to [-127,127]. + */ +IVF_STATIC void ivf_quantize_int8(const float *src, int8_t *dst, int D) { + for (int i = 0; i < D; i++) { + float v = src[i]; + if (v > 1.0f) v = 1.0f; + if (v < -1.0f) v = -1.0f; + dst[i] = (int8_t)(v * 127.0f); + } +} + +/** + * Quantize float32 vector to binary (sign-bit quantization). + * Each bit = 1 if src[i] > 0, else 0. + */ +IVF_STATIC void ivf_quantize_binary(const float *src, uint8_t *dst, int D) { + memset(dst, 0, D / 8); + for (int i = 0; i < D; i++) { + if (src[i] > 0.0f) { + dst[i / 8] |= (1 << (i % 8)); + } + } +} + +/** + * Quantize a float32 vector to the target type based on config. + * dst must be pre-allocated to ivf_vec_size() bytes. + * If quantizer=none, copies src as-is. + */ +static void ivf_quantize(vec0_vtab *p, int col_idx, + const float *src, void *dst) { + int D = (int)p->vector_columns[col_idx].dimensions; + switch (p->vector_columns[col_idx].ivf.quantizer) { + case VEC0_IVF_QUANTIZER_INT8: + ivf_quantize_int8(src, (int8_t *)dst, D); + break; + case VEC0_IVF_QUANTIZER_BINARY: + ivf_quantize_binary(src, (uint8_t *)dst, D); + break; + default: + memcpy(dst, src, D * sizeof(float)); + break; + } +} + +// Forward declaration +static float ivf_distance(vec0_vtab *p, int col_idx, const void *a, const void *b); + +/** + * Find nearest centroid. Works with quantized or float centroids. + * vec and centroids must be in the same representation (both quantized or both float). + * vecSize = size of one vector in bytes. + */ +static int ivf_find_nearest_centroid(vec0_vtab *p, int col_idx, + const void *vec, const void *centroids, + int vecSize, int k) { + float min_dist = FLT_MAX; + int best = 0; + const unsigned char *cdata = (const unsigned char *)centroids; + for (int c = 0; c < k; c++) { + float dist = ivf_distance(p, col_idx, vec, cdata + c * vecSize); + if (dist < min_dist) { min_dist = dist; best = c; } + } + return best; +} + +/** + * Compute distance between two vectors using the column's distance_metric. + * Dispatches to SIMD-optimized functions (NEON/AVX) via distance_*_float(). + * For float32 (non-quantized) vectors. + */ +static float ivf_distance_float(vec0_vtab *p, int col_idx, + const float *a, const float *b) { + size_t dims = p->vector_columns[col_idx].dimensions; + switch (p->vector_columns[col_idx].distance_metric) { + case VEC0_DISTANCE_METRIC_COSINE: + return distance_cosine_float(a, b, &dims); + case VEC0_DISTANCE_METRIC_L1: + return (float)distance_l1_f32(a, b, &dims); + case VEC0_DISTANCE_METRIC_L2: + default: + return distance_l2_sqr_float(a, b, &dims); + } +} + +/** + * Compute distance between two quantized vectors. + * For int8: uses L2 or cosine on int8. + * For binary: uses hamming distance. + * For none: delegates to ivf_distance_float. + */ +static float ivf_distance(vec0_vtab *p, int col_idx, + const void *a, const void *b) { + size_t dims = p->vector_columns[col_idx].dimensions; + switch (p->vector_columns[col_idx].ivf.quantizer) { + case VEC0_IVF_QUANTIZER_INT8: + return distance_l2_sqr_int8(a, b, &dims); + case VEC0_IVF_QUANTIZER_BINARY: + return distance_hamming(a, b, &dims); + default: + return ivf_distance_float(p, col_idx, (const float *)a, (const float *)b); + } +} + +static int ivf_ensure_stmt(vec0_vtab *p, sqlite3_stmt **pStmt, const char *fmt, + int col_idx) { + if (*pStmt) return SQLITE_OK; + char *zSql = sqlite3_mprintf(fmt, p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, pStmt, NULL); + sqlite3_free(zSql); + return rc; +} + +static int ivf_exec(vec0_vtab *p, const char *fmt, int col_idx) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf(fmt, p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc == SQLITE_OK) { sqlite3_step(stmt); sqlite3_finalize(stmt); } + return SQLITE_OK; +} + +static int ivf_is_trained(vec0_vtab *p, int col_idx) { + if (p->ivfTrainedCache[col_idx] >= 0) return p->ivfTrainedCache[col_idx]; + sqlite3_stmt *stmt = NULL; + int trained = 0; + char *zSql = sqlite3_mprintf( + "SELECT value FROM " VEC0_SHADOW_INFO_NAME " WHERE key = 'ivf_trained_%d'", + p->schemaName, p->tableName, col_idx); + if (!zSql) return 0; + if (sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL) == SQLITE_OK) { + if (sqlite3_step(stmt) == SQLITE_ROW) + trained = (sqlite3_column_int(stmt, 0) == 1); + } + sqlite3_free(zSql); + sqlite3_finalize(stmt); + p->ivfTrainedCache[col_idx] = trained; + return trained; +} + +// ============================================================================ +// Cell operations — fixed-size cells, multiple rows per centroid +// ============================================================================ + +/** + * Create a new cell row. Returns the new cell_id (rowid) via *out_cell_id. + */ +static int ivf_cell_create(vec0_vtab *p, int col_idx, i64 centroid_id, + i64 *out_cell_id) { + sqlite3_stmt *stmt = NULL; + int rc; + int cap = VEC0_IVF_CELL_MAX_VECTORS; + int vecSize = ivf_vec_size(p, col_idx); + char *zSql = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_IVF_CELLS_NAME + " (centroid_id, n_vectors, validity, rowids, vectors) VALUES (?, 0, ?, ?, ?)", + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + sqlite3_bind_int64(stmt, 1, centroid_id); + sqlite3_bind_zeroblob(stmt, 2, cap / 8); + sqlite3_bind_zeroblob(stmt, 3, cap * (int)sizeof(i64)); + sqlite3_bind_zeroblob(stmt, 4, cap * vecSize); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) return SQLITE_ERROR; + if (out_cell_id) *out_cell_id = sqlite3_last_insert_rowid(p->db); + return SQLITE_OK; +} + +/** + * Find a cell with space for the given centroid, or create one. + * Returns cell_id (rowid) and current n_vectors. + */ +static int ivf_cell_find_or_create(vec0_vtab *p, int col_idx, i64 centroid_id, + i64 *out_cell_id, int *out_n) { + int rc; + // Find existing cell with space + rc = ivf_ensure_stmt(p, &p->stmtIvfCellMeta[col_idx], + "SELECT rowid, n_vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME + " WHERE centroid_id = ? AND n_vectors < %d LIMIT 1", + col_idx); + // The %d in the format won't work with ivf_ensure_stmt since it only has 3 + // format args. Use a direct approach instead. + sqlite3_finalize(p->stmtIvfCellMeta[col_idx]); + p->stmtIvfCellMeta[col_idx] = NULL; + + char *zSql = sqlite3_mprintf( + "SELECT rowid, n_vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME + " WHERE centroid_id = ? AND n_vectors < %d LIMIT 1", + p->schemaName, p->tableName, col_idx, VEC0_IVF_CELL_MAX_VECTORS); + if (!zSql) return SQLITE_NOMEM; + // Cache this manually + if (!p->stmtIvfCellMeta[col_idx]) { + rc = sqlite3_prepare_v2(p->db, zSql, -1, &p->stmtIvfCellMeta[col_idx], NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + } else { + sqlite3_free(zSql); + } + + sqlite3_stmt *stmt = p->stmtIvfCellMeta[col_idx]; + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, centroid_id); + + if (sqlite3_step(stmt) == SQLITE_ROW) { + *out_cell_id = sqlite3_column_int64(stmt, 0); + *out_n = sqlite3_column_int(stmt, 1); + return SQLITE_OK; + } + + // No cell with space — create new one + rc = ivf_cell_create(p, col_idx, centroid_id, out_cell_id); + *out_n = 0; + return rc; +} + +/** + * Insert vector into cell at slot = n_vectors (append). + * Cell must have space (n_vectors < VEC0_IVF_CELL_MAX_VECTORS). + */ +static int ivf_cell_insert(vec0_vtab *p, int col_idx, i64 centroid_id, + i64 rowid, const void *vectorData, int vectorSize) { + int rc; + i64 cell_id; + int n_vectors; + + rc = ivf_cell_find_or_create(p, col_idx, centroid_id, &cell_id, &n_vectors); + if (rc != SQLITE_OK) return rc; + + int slot = n_vectors; + char *cellsTable = p->shadowIvfCellsNames[col_idx]; + + // Set validity bit + sqlite3_blob *blob = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, cellsTable, "validity", + cell_id, 1, &blob); + if (rc != SQLITE_OK) return rc; + unsigned char bx; + sqlite3_blob_read(blob, &bx, 1, slot / 8); + bx |= (1 << (slot % 8)); + sqlite3_blob_write(blob, &bx, 1, slot / 8); + sqlite3_blob_close(blob); + + // Write rowid + rc = sqlite3_blob_open(p->db, p->schemaName, cellsTable, "rowids", + cell_id, 1, &blob); + if (rc == SQLITE_OK) { + sqlite3_blob_write(blob, &rowid, sizeof(i64), slot * (int)sizeof(i64)); + sqlite3_blob_close(blob); + } + + // Write vector + rc = sqlite3_blob_open(p->db, p->schemaName, cellsTable, "vectors", + cell_id, 1, &blob); + if (rc == SQLITE_OK) { + sqlite3_blob_write(blob, vectorData, vectorSize, slot * vectorSize); + sqlite3_blob_close(blob); + } + + // Increment n_vectors (cached) + ivf_ensure_stmt(p, &p->stmtIvfCellUpdateN[col_idx], + "UPDATE " VEC0_SHADOW_IVF_CELLS_NAME + " SET n_vectors = n_vectors + 1 WHERE rowid = ?", col_idx); + if (p->stmtIvfCellUpdateN[col_idx]) { + sqlite3_stmt *s = p->stmtIvfCellUpdateN[col_idx]; + sqlite3_reset(s); + sqlite3_bind_int64(s, 1, cell_id); + sqlite3_step(s); + } + + // Insert rowid_map (cached) + ivf_ensure_stmt(p, &p->stmtIvfRowidMapInsert[col_idx], + "INSERT INTO " VEC0_SHADOW_IVF_ROWID_MAP_NAME + " (rowid, cell_id, slot) VALUES (?, ?, ?)", col_idx); + if (p->stmtIvfRowidMapInsert[col_idx]) { + sqlite3_stmt *s = p->stmtIvfRowidMapInsert[col_idx]; + sqlite3_reset(s); + sqlite3_bind_int64(s, 1, rowid); + sqlite3_bind_int64(s, 2, cell_id); + sqlite3_bind_int(s, 3, slot); + sqlite3_step(s); + } + + return SQLITE_OK; +} + +// ============================================================================ +// Shadow tables +// ============================================================================ + +static int ivf_create_shadow_tables(vec0_vtab *p, int col_idx) { + sqlite3_stmt *stmt = NULL; + int rc; + char *zSql; + + zSql = sqlite3_mprintf( + "CREATE TABLE " VEC0_SHADOW_IVF_CENTROIDS_NAME + " (centroid_id INTEGER PRIMARY KEY, centroid BLOB NOT NULL)", + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; } + sqlite3_finalize(stmt); + + // cell_id is rowid (auto-increment), centroid_id is indexed + zSql = sqlite3_mprintf( + "CREATE TABLE " VEC0_SHADOW_IVF_CELLS_NAME + " (centroid_id INTEGER NOT NULL," + " n_vectors INTEGER NOT NULL DEFAULT 0," + " validity BLOB NOT NULL," + " rowids BLOB NOT NULL," + " vectors BLOB NOT NULL)", + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; } + sqlite3_finalize(stmt); + + // Index on centroid_id for cell lookup + zSql = sqlite3_mprintf( + "CREATE INDEX \"%w_ivf_cells%02d_centroid\" ON \"%w_ivf_cells%02d\" (centroid_id)", + p->tableName, col_idx, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; } + sqlite3_finalize(stmt); + + zSql = sqlite3_mprintf( + "CREATE TABLE " VEC0_SHADOW_IVF_ROWID_MAP_NAME + " (rowid INTEGER PRIMARY KEY, cell_id INTEGER NOT NULL, slot INTEGER NOT NULL)", + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; } + sqlite3_finalize(stmt); + + // _ivf_vectors — full-precision KV store (only when quantizer != none) + if (p->vector_columns[col_idx].ivf.quantizer != VEC0_IVF_QUANTIZER_NONE) { + zSql = sqlite3_mprintf( + "CREATE TABLE " VEC0_SHADOW_IVF_VECTORS_NAME + " (rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL)", + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; } + sqlite3_finalize(stmt); + } + + zSql = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_INFO_NAME " (key, value) VALUES ('ivf_trained_%d', '0')", + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; } + sqlite3_finalize(stmt); + + return SQLITE_OK; +} + +static int ivf_drop_shadow_tables(vec0_vtab *p, int col_idx) { + ivf_exec(p, "DROP TABLE IF EXISTS " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx); + ivf_exec(p, "DROP TABLE IF EXISTS " VEC0_SHADOW_IVF_CELLS_NAME, col_idx); + ivf_exec(p, "DROP TABLE IF EXISTS " VEC0_SHADOW_IVF_ROWID_MAP_NAME, col_idx); + ivf_exec(p, "DROP TABLE IF EXISTS " VEC0_SHADOW_IVF_VECTORS_NAME, col_idx); + return SQLITE_OK; +} + +// ============================================================================ +// Insert / Delete +// ============================================================================ + +static int ivf_insert(vec0_vtab *p, int col_idx, i64 rowid, + const void *vectorData, int vectorSize) { + UNUSED_PARAMETER(vectorSize); + int quantizer = p->vector_columns[col_idx].ivf.quantizer; + int qvecSize = ivf_vec_size(p, col_idx); + int rc; + + // Quantize the input vector (or copy as-is if no quantization) + void *qvec = sqlite3_malloc(qvecSize); + if (!qvec) return SQLITE_NOMEM; + ivf_quantize(p, col_idx, (const float *)vectorData, qvec); + + if (!ivf_is_trained(p, col_idx)) { + rc = ivf_cell_insert(p, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID, + rowid, qvec, qvecSize); + } else { + // Find nearest centroid using quantized distance + int best_centroid = -1; + float min_dist = FLT_MAX; + + rc = ivf_ensure_stmt(p, &p->stmtIvfCentroidsAll[col_idx], + "SELECT centroid_id, centroid FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx); + if (rc != SQLITE_OK) { sqlite3_free(qvec); return rc; } + sqlite3_stmt *stmt = p->stmtIvfCentroidsAll[col_idx]; + sqlite3_reset(stmt); + while (sqlite3_step(stmt) == SQLITE_ROW) { + int cid = sqlite3_column_int(stmt, 0); + const void *c = sqlite3_column_blob(stmt, 1); + int cBytes = sqlite3_column_bytes(stmt, 1); + if (!c || cBytes != qvecSize) continue; + float dist = ivf_distance(p, col_idx, qvec, c); + if (dist < min_dist) { min_dist = dist; best_centroid = cid; } + } + if (best_centroid < 0) { sqlite3_free(qvec); return SQLITE_ERROR; } + + rc = ivf_cell_insert(p, col_idx, best_centroid, rowid, qvec, qvecSize); + } + + sqlite3_free(qvec); + if (rc != SQLITE_OK) return rc; + + // Store full-precision vector in KV table when quantized + if (quantizer != VEC0_IVF_QUANTIZER_NONE) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_IVF_VECTORS_NAME " (rowid, vector) VALUES (?, ?)", + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + sqlite3_bind_int64(stmt, 1, rowid); + sqlite3_bind_blob(stmt, 2, vectorData, ivf_full_vec_size(p, col_idx), SQLITE_STATIC); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) return SQLITE_ERROR; + } + + return SQLITE_OK; +} + +static int ivf_delete(vec0_vtab *p, int col_idx, i64 rowid) { + int rc; + i64 cell_id = 0; + int slot = -1; + + rc = ivf_ensure_stmt(p, &p->stmtIvfRowidMapLookup[col_idx], + "SELECT cell_id, slot FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME + " WHERE rowid = ?", col_idx); + if (rc != SQLITE_OK) return rc; + sqlite3_stmt *s = p->stmtIvfRowidMapLookup[col_idx]; + sqlite3_reset(s); + sqlite3_bind_int64(s, 1, rowid); + if (sqlite3_step(s) == SQLITE_ROW) { + cell_id = sqlite3_column_int64(s, 0); + slot = sqlite3_column_int(s, 1); + } + if (slot < 0) return SQLITE_OK; + + // Clear validity bit + char *cellsTable = p->shadowIvfCellsNames[col_idx]; + sqlite3_blob *blob = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, cellsTable, "validity", + cell_id, 1, &blob); + if (rc == SQLITE_OK) { + unsigned char bx; + sqlite3_blob_read(blob, &bx, 1, slot / 8); + bx &= ~(1 << (slot % 8)); + sqlite3_blob_write(blob, &bx, 1, slot / 8); + sqlite3_blob_close(blob); + } + + // Decrement n_vectors + if (p->stmtIvfCellUpdateN[col_idx]) { + // This stmt does +1, but we want -1. Use a different cached stmt. + } + // Just use inline for decrement (not hot path) + { + sqlite3_stmt *stmtDec = NULL; + char *zSql = sqlite3_mprintf( + "UPDATE " VEC0_SHADOW_IVF_CELLS_NAME + " SET n_vectors = n_vectors - 1 WHERE rowid = ?", + p->schemaName, p->tableName, col_idx); + if (zSql) { + sqlite3_prepare_v2(p->db, zSql, -1, &stmtDec, NULL); sqlite3_free(zSql); + if (stmtDec) { sqlite3_bind_int64(stmtDec, 1, cell_id); sqlite3_step(stmtDec); sqlite3_finalize(stmtDec); } + } + } + + // Delete from rowid_map + ivf_ensure_stmt(p, &p->stmtIvfRowidMapDelete[col_idx], + "DELETE FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME " WHERE rowid = ?", col_idx); + if (p->stmtIvfRowidMapDelete[col_idx]) { + sqlite3_stmt *sd = p->stmtIvfRowidMapDelete[col_idx]; + sqlite3_reset(sd); + sqlite3_bind_int64(sd, 1, rowid); + sqlite3_step(sd); + } + + // Delete from _ivf_vectors (full-precision KV) when quantized + if (p->vector_columns[col_idx].ivf.quantizer != VEC0_IVF_QUANTIZER_NONE) { + sqlite3_stmt *stmtDelVec = NULL; + char *zSql = sqlite3_mprintf( + "DELETE FROM " VEC0_SHADOW_IVF_VECTORS_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, col_idx); + if (zSql) { + sqlite3_prepare_v2(p->db, zSql, -1, &stmtDelVec, NULL); sqlite3_free(zSql); + if (stmtDelVec) { sqlite3_bind_int64(stmtDelVec, 1, rowid); sqlite3_step(stmtDelVec); sqlite3_finalize(stmtDelVec); } + } + } + + return SQLITE_OK; +} + +// ============================================================================ +// Point query +// ============================================================================ + +static int ivf_get_vector_data(vec0_vtab *p, i64 rowid, int col_idx, + void **outVector, int *outVectorSize) { + int rc; + int vecSize = ivf_vec_size(p, col_idx); + i64 cell_id = 0; + int slot = -1; + + rc = ivf_ensure_stmt(p, &p->stmtIvfRowidMapLookup[col_idx], + "SELECT cell_id, slot FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME + " WHERE rowid = ?", col_idx); + if (rc != SQLITE_OK) return rc; + sqlite3_stmt *s = p->stmtIvfRowidMapLookup[col_idx]; + sqlite3_reset(s); + sqlite3_bind_int64(s, 1, rowid); + if (sqlite3_step(s) != SQLITE_ROW) return SQLITE_EMPTY; + cell_id = sqlite3_column_int64(s, 0); + slot = sqlite3_column_int(s, 1); + + void *buf = sqlite3_malloc(vecSize); + if (!buf) return SQLITE_NOMEM; + + sqlite3_blob *blob = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowIvfCellsNames[col_idx], + "vectors", cell_id, 0, &blob); + if (rc != SQLITE_OK) { sqlite3_free(buf); return rc; } + rc = sqlite3_blob_read(blob, buf, vecSize, slot * vecSize); + sqlite3_blob_close(blob); + if (rc != SQLITE_OK) { sqlite3_free(buf); return rc; } + + *outVector = buf; + if (outVectorSize) *outVectorSize = vecSize; + return SQLITE_OK; +} + +// ============================================================================ +// Centroid commands +// ============================================================================ + +static int ivf_load_all_vectors(vec0_vtab *p, int col_idx, + float **out_vectors, i64 **out_rowids, int *out_N) { + sqlite3_stmt *stmt = NULL; + int rc; + int D = (int)p->vector_columns[col_idx].dimensions; + int vecSize = D * (int)sizeof(float); + int quantizer = p->vector_columns[col_idx].ivf.quantizer; + + // When quantized, load full-precision vectors from _ivf_vectors KV table + if (quantizer != VEC0_IVF_QUANTIZER_NONE) { + int total = 0; + char *zSql = sqlite3_mprintf( + "SELECT count(*) FROM " VEC0_SHADOW_IVF_VECTORS_NAME, + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) total = sqlite3_column_int(stmt, 0); + sqlite3_finalize(stmt); + if (total == 0) { *out_vectors = NULL; *out_rowids = NULL; *out_N = 0; return SQLITE_OK; } + + float *vectors = sqlite3_malloc64((i64)total * D * sizeof(float)); + i64 *rowids = sqlite3_malloc64((i64)total * sizeof(i64)); + if (!vectors || !rowids) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; } + + int idx = 0; + zSql = sqlite3_mprintf( + "SELECT rowid, vector FROM " VEC0_SHADOW_IVF_VECTORS_NAME, + p->schemaName, p->tableName, col_idx); + if (!zSql) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc == SQLITE_OK) { + while (sqlite3_step(stmt) == SQLITE_ROW && idx < total) { + rowids[idx] = sqlite3_column_int64(stmt, 0); + const void *blob = sqlite3_column_blob(stmt, 1); + int blobBytes = sqlite3_column_bytes(stmt, 1); + if (blob && blobBytes == vecSize) { + memcpy(&vectors[idx * D], blob, vecSize); + idx++; + } + } + } + sqlite3_finalize(stmt); + *out_vectors = vectors; *out_rowids = rowids; *out_N = idx; + return SQLITE_OK; + } + + // Non-quantized: load from cells (existing path) + + // Count total + int total = 0; + char *zSql = sqlite3_mprintf( + "SELECT COALESCE(SUM(n_vectors),0) FROM " VEC0_SHADOW_IVF_CELLS_NAME, + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) total = sqlite3_column_int(stmt, 0); + sqlite3_finalize(stmt); + + if (total == 0) { *out_vectors = NULL; *out_rowids = NULL; *out_N = 0; return SQLITE_OK; } + + float *vectors = sqlite3_malloc64((i64)total * D * sizeof(float)); + i64 *rowids = sqlite3_malloc64((i64)total * sizeof(i64)); + if (!vectors || !rowids) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; } + + int idx = 0; + zSql = sqlite3_mprintf( + "SELECT n_vectors, validity, rowids, vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME, + p->schemaName, p->tableName, col_idx); + if (!zSql) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK) { sqlite3_free(vectors); sqlite3_free(rowids); return rc; } + + while (sqlite3_step(stmt) == SQLITE_ROW) { + int n = sqlite3_column_int(stmt, 0); + if (n == 0) continue; + const unsigned char *val = (const unsigned char *)sqlite3_column_blob(stmt, 1); + const i64 *rids = (const i64 *)sqlite3_column_blob(stmt, 2); + const float *vecs = (const float *)sqlite3_column_blob(stmt, 3); + int valBytes = sqlite3_column_bytes(stmt, 1); + int ridsBytes = sqlite3_column_bytes(stmt, 2); + int vecsBytes = sqlite3_column_bytes(stmt, 3); + if (!val || !rids || !vecs) continue; + int cap = valBytes * 8; + // Clamp cap to the number of entries actually backed by the rowids and vectors blobs + if (ridsBytes / (int)sizeof(i64) < cap) cap = ridsBytes / (int)sizeof(i64); + if (vecsBytes / vecSize < cap) cap = vecsBytes / vecSize; + for (int i = 0; i < cap && idx < total; i++) { + if (val[i / 8] & (1 << (i % 8))) { + rowids[idx] = rids[i]; + memcpy(&vectors[idx * D], &vecs[i * D], vecSize); + idx++; + } + } + } + sqlite3_finalize(stmt); + *out_vectors = vectors; *out_rowids = rowids; *out_N = idx; + return SQLITE_OK; +} + +static void ivf_invalidate_cached(vec0_vtab *p, int col_idx) { + sqlite3_finalize(p->stmtIvfCellMeta[col_idx]); p->stmtIvfCellMeta[col_idx] = NULL; + sqlite3_finalize(p->stmtIvfCentroidsAll[col_idx]); p->stmtIvfCentroidsAll[col_idx] = NULL; + sqlite3_finalize(p->stmtIvfCellUpdateN[col_idx]); p->stmtIvfCellUpdateN[col_idx] = NULL; + sqlite3_finalize(p->stmtIvfRowidMapInsert[col_idx]); p->stmtIvfRowidMapInsert[col_idx] = NULL; +} + +static int ivf_cmd_compute_centroids(vec0_vtab *p, int col_idx, int nlist_override, + int max_iter, uint32_t seed) { + int rc; + int D = (int)p->vector_columns[col_idx].dimensions; + int vecSize = D * (int)sizeof(float); + int quantizer = p->vector_columns[col_idx].ivf.quantizer; + int nlist = nlist_override > 0 ? nlist_override : p->vector_columns[col_idx].ivf.nlist; + if (nlist <= 0) { vtab_set_error(&p->base, "nlist must be specified"); return SQLITE_ERROR; } + + float *vectors = NULL; i64 *rowids = NULL; int N = 0; + rc = ivf_load_all_vectors(p, col_idx, &vectors, &rowids, &N); + if (rc != SQLITE_OK) return rc; + if (N == 0) { vtab_set_error(&p->base, "No vectors"); sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_ERROR; } + if (nlist > N) nlist = N; + + float *centroids = sqlite3_malloc64((i64)nlist * D * sizeof(float)); + if (!centroids) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; } + if (ivf_kmeans(vectors, N, D, nlist, max_iter, seed, centroids) != 0) { + sqlite3_free(vectors); sqlite3_free(rowids); sqlite3_free(centroids); return SQLITE_ERROR; + } + + // Compute assignments + int *assignments = sqlite3_malloc64((i64)N * sizeof(int)); + if (!assignments) { sqlite3_free(vectors); sqlite3_free(rowids); sqlite3_free(centroids); return SQLITE_NOMEM; } + // Assignment uses float32 distances (k-means operates in float32 space) + for (int i = 0; i < N; i++) { + float min_d = FLT_MAX; + int best = 0; + for (int c = 0; c < nlist; c++) { + float d = ivf_distance_float(p, col_idx, &vectors[i * D], ¢roids[c * D]); + if (d < min_d) { min_d = d; best = c; } + } + assignments[i] = best; + } + + // Invalidate all cached stmts before dropping/recreating tables + ivf_invalidate_cached(p, col_idx); + + sqlite3_exec(p->db, "SAVEPOINT ivf_train", NULL, NULL, NULL); + sqlite3_stmt *stmt = NULL; + char *zSql; + + // Clear all data + ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx); + ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_CELLS_NAME, col_idx); + ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME, col_idx); + + // Write centroids (quantized if quantizer is set) + int qvecSize = ivf_vec_size(p, col_idx); + void *qbuf = sqlite3_malloc(qvecSize > vecSize ? qvecSize : vecSize); + if (!qbuf) { rc = SQLITE_NOMEM; goto train_error; } + + zSql = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_IVF_CENTROIDS_NAME " (centroid_id, centroid) VALUES (?, ?)", + p->schemaName, p->tableName, col_idx); + if (!zSql) { sqlite3_free(qbuf); rc = SQLITE_NOMEM; goto train_error; } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK) { sqlite3_free(qbuf); goto train_error; } + for (int i = 0; i < nlist; i++) { + ivf_quantize(p, col_idx, ¢roids[i * D], qbuf); + sqlite3_reset(stmt); + sqlite3_bind_int(stmt, 1, i); + sqlite3_bind_blob(stmt, 2, qbuf, qvecSize, SQLITE_TRANSIENT); + if (sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); sqlite3_free(qbuf); rc = SQLITE_ERROR; goto train_error; } + } + sqlite3_finalize(stmt); + + // Build cells: group vectors by centroid, create fixed-size cells + { + // Prepare INSERT statements + sqlite3_stmt *stmtCell = NULL; + zSql = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_IVF_CELLS_NAME + " (centroid_id, n_vectors, validity, rowids, vectors) VALUES (?, ?, ?, ?, ?)", + p->schemaName, p->tableName, col_idx); + if (!zSql) { rc = SQLITE_NOMEM; goto train_error; } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtCell, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK) goto train_error; + + sqlite3_stmt *stmtMap = NULL; + zSql = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_IVF_ROWID_MAP_NAME " (rowid, cell_id, slot) VALUES (?, ?, ?)", + p->schemaName, p->tableName, col_idx); + if (!zSql) { sqlite3_finalize(stmtCell); rc = SQLITE_NOMEM; goto train_error; } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtMap, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK) { sqlite3_finalize(stmtCell); goto train_error; } + + int cap = VEC0_IVF_CELL_MAX_VECTORS; + unsigned char *val = sqlite3_malloc(cap / 8); + i64 *rids = sqlite3_malloc64((i64)cap * sizeof(i64)); + unsigned char *vecs = sqlite3_malloc64((i64)cap * qvecSize); // quantized size + if (!val || !rids || !vecs) { + sqlite3_free(val); sqlite3_free(rids); sqlite3_free(vecs); + sqlite3_finalize(stmtCell); sqlite3_finalize(stmtMap); + sqlite3_free(qbuf); + rc = SQLITE_NOMEM; goto train_error; + } + + // Process one centroid at a time + for (int c = 0; c < nlist; c++) { + int slot = 0; + memset(val, 0, cap / 8); + memset(rids, 0, cap * sizeof(i64)); + + for (int i = 0; i < N; i++) { + if (assignments[i] != c) continue; + + if (slot >= cap) { + // Flush current cell + sqlite3_reset(stmtCell); + sqlite3_bind_int(stmtCell, 1, c); + sqlite3_bind_int(stmtCell, 2, slot); + sqlite3_bind_blob(stmtCell, 3, val, cap / 8, SQLITE_TRANSIENT); + sqlite3_bind_blob(stmtCell, 4, rids, cap * (int)sizeof(i64), SQLITE_TRANSIENT); + sqlite3_bind_blob(stmtCell, 5, vecs, cap * qvecSize, SQLITE_TRANSIENT); + sqlite3_step(stmtCell); + i64 flushed_cell_id = sqlite3_last_insert_rowid(p->db); + + for (int s = 0; s < slot; s++) { + sqlite3_reset(stmtMap); + sqlite3_bind_int64(stmtMap, 1, rids[s]); + sqlite3_bind_int64(stmtMap, 2, flushed_cell_id); + sqlite3_bind_int(stmtMap, 3, s); + sqlite3_step(stmtMap); + } + + slot = 0; + memset(val, 0, cap / 8); + memset(rids, 0, cap * sizeof(i64)); + } + + val[slot / 8] |= (1 << (slot % 8)); + rids[slot] = rowids[i]; + // Quantize float32 vector into cell blob + ivf_quantize(p, col_idx, &vectors[i * D], &vecs[slot * qvecSize]); + slot++; + } + + // Flush remaining + if (slot > 0) { + sqlite3_reset(stmtCell); + sqlite3_bind_int(stmtCell, 1, c); + sqlite3_bind_int(stmtCell, 2, slot); + sqlite3_bind_blob(stmtCell, 3, val, cap / 8, SQLITE_TRANSIENT); + sqlite3_bind_blob(stmtCell, 4, rids, cap * (int)sizeof(i64), SQLITE_TRANSIENT); + sqlite3_bind_blob(stmtCell, 5, vecs, cap * qvecSize, SQLITE_TRANSIENT); + sqlite3_step(stmtCell); + i64 flushed_cell_id = sqlite3_last_insert_rowid(p->db); + + for (int s = 0; s < slot; s++) { + sqlite3_reset(stmtMap); + sqlite3_bind_int64(stmtMap, 1, rids[s]); + sqlite3_bind_int64(stmtMap, 2, flushed_cell_id); + sqlite3_bind_int(stmtMap, 3, s); + sqlite3_step(stmtMap); + } + } + } + + sqlite3_free(val); sqlite3_free(rids); sqlite3_free(vecs); + sqlite3_finalize(stmtCell); sqlite3_finalize(stmtMap); + } + + sqlite3_free(qbuf); + + // Store full-precision vectors in _ivf_vectors when quantized + if (quantizer != VEC0_IVF_QUANTIZER_NONE) { + ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_VECTORS_NAME, col_idx); + zSql = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_IVF_VECTORS_NAME " (rowid, vector) VALUES (?, ?)", + p->schemaName, p->tableName, col_idx); + if (!zSql) { rc = SQLITE_NOMEM; goto train_error; } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK) goto train_error; + for (int i = 0; i < N; i++) { + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, rowids[i]); + sqlite3_bind_blob(stmt, 2, &vectors[i * D], vecSize, SQLITE_STATIC); + sqlite3_step(stmt); + } + sqlite3_finalize(stmt); + } + + // Set trained = 1 + { + zSql = sqlite3_mprintf( + "INSERT OR REPLACE INTO " VEC0_SHADOW_INFO_NAME " (key, value) VALUES ('ivf_trained_%d', '1')", + p->schemaName, p->tableName, col_idx); + if (zSql) { sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + sqlite3_step(stmt); sqlite3_finalize(stmt); } + } + p->ivfTrainedCache[col_idx] = 1; + + sqlite3_exec(p->db, "RELEASE ivf_train", NULL, NULL, NULL); + sqlite3_free(vectors); sqlite3_free(rowids); sqlite3_free(centroids); sqlite3_free(assignments); + return SQLITE_OK; + +train_error: + sqlite3_exec(p->db, "ROLLBACK TO ivf_train", NULL, NULL, NULL); + sqlite3_exec(p->db, "RELEASE ivf_train", NULL, NULL, NULL); + sqlite3_free(vectors); sqlite3_free(rowids); sqlite3_free(centroids); sqlite3_free(assignments); + return rc; +} + +static int ivf_cmd_set_centroid(vec0_vtab *p, int col_idx, int centroid_id, + const void *vectorData, int vectorSize) { + sqlite3_stmt *stmt = NULL; + int rc; + int D = (int)p->vector_columns[col_idx].dimensions; + if (vectorSize != (int)(D * sizeof(float))) { vtab_set_error(&p->base, "Dimension mismatch"); return SQLITE_ERROR; } + + char *zSql = sqlite3_mprintf( + "INSERT OR REPLACE INTO " VEC0_SHADOW_IVF_CENTROIDS_NAME " (centroid_id, centroid) VALUES (?, ?)", + p->schemaName, p->tableName, col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + sqlite3_bind_int(stmt, 1, centroid_id); + sqlite3_bind_blob(stmt, 2, vectorData, vectorSize, SQLITE_STATIC); + rc = sqlite3_step(stmt); sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) return SQLITE_ERROR; + + zSql = sqlite3_mprintf( + "INSERT OR REPLACE INTO " VEC0_SHADOW_INFO_NAME " (key, value) VALUES ('ivf_trained_%d', '1')", + p->schemaName, p->tableName, col_idx); + if (zSql) { sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + sqlite3_step(stmt); sqlite3_finalize(stmt); } + p->ivfTrainedCache[col_idx] = 1; + sqlite3_finalize(p->stmtIvfCentroidsAll[col_idx]); p->stmtIvfCentroidsAll[col_idx] = NULL; + return SQLITE_OK; +} + +static int ivf_cmd_assign_vectors(vec0_vtab *p, int col_idx) { + if (!ivf_is_trained(p, col_idx)) { vtab_set_error(&p->base, "No centroids"); return SQLITE_ERROR; } + + int D = (int)p->vector_columns[col_idx].dimensions; + int vecSize = D * (int)sizeof(float); + int rc; + sqlite3_stmt *stmt = NULL; + char *zSql; + + // Load centroids + int nlist = 0; + float *centroids = NULL; + zSql = sqlite3_mprintf("SELECT count(*) FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, + p->schemaName, p->tableName, col_idx); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) nlist = sqlite3_column_int(stmt, 0); + sqlite3_finalize(stmt); + if (nlist == 0) { vtab_set_error(&p->base, "No centroids"); return SQLITE_ERROR; } + + centroids = sqlite3_malloc64((i64)nlist * D * sizeof(float)); + if (!centroids) return SQLITE_NOMEM; + zSql = sqlite3_mprintf("SELECT centroid_id, centroid FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME " ORDER BY centroid_id", + p->schemaName, p->tableName, col_idx); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + { int ci = 0; while (sqlite3_step(stmt) == SQLITE_ROW && ci < nlist) { + const void *b = sqlite3_column_blob(stmt, 1); + int bBytes = sqlite3_column_bytes(stmt, 1); + if (b && bBytes == vecSize) memcpy(¢roids[ci * D], b, vecSize); + ci++; + }} + sqlite3_finalize(stmt); + + // Read unassigned cells, re-insert into trained cells + zSql = sqlite3_mprintf( + "SELECT rowid, n_vectors, validity, rowids, vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME + " WHERE centroid_id = %d", + p->schemaName, p->tableName, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID); + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + + // Invalidate cached stmts since we'll be modifying cells + ivf_invalidate_cached(p, col_idx); + + while (sqlite3_step(stmt) == SQLITE_ROW) { + int n = sqlite3_column_int(stmt, 1); + const unsigned char *val = (const unsigned char *)sqlite3_column_blob(stmt, 2); + const i64 *rids = (const i64 *)sqlite3_column_blob(stmt, 3); + const float *vecs = (const float *)sqlite3_column_blob(stmt, 4); + int valBytes = sqlite3_column_bytes(stmt, 2); + int ridsBytes = sqlite3_column_bytes(stmt, 3); + int vecsBytes = sqlite3_column_bytes(stmt, 4); + if (!val || !rids || !vecs) continue; + int cap = valBytes * 8; + if (ridsBytes / (int)sizeof(i64) < cap) cap = ridsBytes / (int)sizeof(i64); + if (vecsBytes / vecSize < cap) cap = vecsBytes / vecSize; + + for (int i = 0; i < cap && n > 0; i++) { + if (!(val[i / 8] & (1 << (i % 8)))) continue; + n--; + int cid = ivf_find_nearest_centroid(p, col_idx, &vecs[i * D], centroids, D, nlist); + + // Delete old rowid_map entry + sqlite3_stmt *sd = NULL; + char *zd = sqlite3_mprintf("DELETE FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, col_idx); + if (zd) { sqlite3_prepare_v2(p->db, zd, -1, &sd, NULL); sqlite3_free(zd); + sqlite3_bind_int64(sd, 1, rids[i]); sqlite3_step(sd); sqlite3_finalize(sd); } + + ivf_cell_insert(p, col_idx, cid, rids[i], &vecs[i * D], vecSize); + } + } + sqlite3_finalize(stmt); + + // Delete unassigned cells + zSql = sqlite3_mprintf( + "DELETE FROM " VEC0_SHADOW_IVF_CELLS_NAME " WHERE centroid_id = %d", + p->schemaName, p->tableName, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID); + if (zSql) { sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + sqlite3_step(stmt); sqlite3_finalize(stmt); } + + sqlite3_free(centroids); + return SQLITE_OK; +} + +static int ivf_cmd_clear_centroids(vec0_vtab *p, int col_idx) { + float *vectors = NULL; i64 *rowids = NULL; int N = 0; + int vecSize = ivf_vec_size(p, col_idx); + int D = (int)p->vector_columns[col_idx].dimensions; + int rc; + sqlite3_stmt *stmt = NULL; + char *zSql; + + rc = ivf_load_all_vectors(p, col_idx, &vectors, &rowids, &N); + if (rc != SQLITE_OK) return rc; + + ivf_invalidate_cached(p, col_idx); + + ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx); + ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_CELLS_NAME, col_idx); + ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME, col_idx); + + // Re-insert all vectors into unassigned cells + for (int i = 0; i < N; i++) { + ivf_cell_insert(p, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID, + rowids[i], &vectors[i * D], vecSize); + } + + zSql = sqlite3_mprintf( + "INSERT OR REPLACE INTO " VEC0_SHADOW_INFO_NAME " (key, value) VALUES ('ivf_trained_%d', '0')", + p->schemaName, p->tableName, col_idx); + if (zSql) { sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql); + sqlite3_step(stmt); sqlite3_finalize(stmt); } + p->ivfTrainedCache[col_idx] = 0; + + sqlite3_free(vectors); sqlite3_free(rowids); + return SQLITE_OK; +} + +// ============================================================================ +// KNN Query — scan all cells for probed centroids +// ============================================================================ + +struct IvfCentroidDist { int id; float dist; }; +struct IvfCandidate { i64 rowid; float distance; }; + +static int ivf_candidate_cmp(const void *a, const void *b) { + float da = ((const struct IvfCandidate *)a)->distance; + float db = ((const struct IvfCandidate *)b)->distance; + if (da < db) return -1; + if (da > db) return 1; + return 0; +} + +/** + * Scan cell rows from a prepared statement, computing distances in-memory. + * The statement must return (n_vectors, validity, rowids, vectors) columns. + * queryVecQ is the quantized query (same type as cell vectors). + * qvecSize is the size of one quantized vector in bytes. + */ +static int ivf_scan_cells_from_stmt(vec0_vtab *p, int col_idx, + sqlite3_stmt *stmt, + const void *queryVecQ, int qvecSize, + struct IvfCandidate **candidates, + int *nCandidates, int *cap) { + while (sqlite3_step(stmt) == SQLITE_ROW) { + int n = sqlite3_column_int(stmt, 0); + if (n == 0) continue; + const unsigned char *validity = (const unsigned char *)sqlite3_column_blob(stmt, 1); + const i64 *rowids = (const i64 *)sqlite3_column_blob(stmt, 2); + const unsigned char *vectors = (const unsigned char *)sqlite3_column_blob(stmt, 3); + int valBytes = sqlite3_column_bytes(stmt, 1); + int ridsBytes = sqlite3_column_bytes(stmt, 2); + int vecsBytes = sqlite3_column_bytes(stmt, 3); + if (!validity || !rowids || !vectors) continue; + int cell_cap = valBytes * 8; + if (ridsBytes / (int)sizeof(i64) < cell_cap) cell_cap = ridsBytes / (int)sizeof(i64); + if (vecsBytes / qvecSize < cell_cap) cell_cap = vecsBytes / qvecSize; + + int found = 0; + for (int i = 0; i < cell_cap && found < n; i++) { + if (!(validity[i / 8] & (1 << (i % 8)))) continue; + found++; + if (*nCandidates >= *cap) { + *cap *= 2; + struct IvfCandidate *tmp = sqlite3_realloc64(*candidates, (i64)*cap * sizeof(struct IvfCandidate)); + if (!tmp) return SQLITE_NOMEM; + *candidates = tmp; + } + (*candidates)[*nCandidates].rowid = rowids[i]; + (*candidates)[*nCandidates].distance = ivf_distance(p, col_idx, + queryVecQ, &vectors[i * qvecSize]); + (*nCandidates)++; + } + } + return SQLITE_OK; +} + +static int ivf_query_knn(vec0_vtab *p, int col_idx, + const void *queryVector, int queryVectorSize, + i64 k, struct vec0_query_knn_data *knn_data) { + UNUSED_PARAMETER(queryVectorSize); + int rc; + int nprobe = p->vector_columns[col_idx].ivf.nprobe; + int trained = ivf_is_trained(p, col_idx); + int quantizer = p->vector_columns[col_idx].ivf.quantizer; + int oversample = p->vector_columns[col_idx].ivf.oversample; + int qvecSize = ivf_vec_size(p, col_idx); + + // Quantize query vector for scanning + void *queryQ = sqlite3_malloc(qvecSize); + if (!queryQ) return SQLITE_NOMEM; + ivf_quantize(p, col_idx, (const float *)queryVector, queryQ); + + // With oversample, collect more candidates for re-ranking + i64 collect_k = (oversample > 1) ? k * oversample : k; + + int cap = (collect_k < 1024) ? 1024 : (int)collect_k * 2; + int nCandidates = 0; + struct IvfCandidate *candidates = sqlite3_malloc64((i64)cap * sizeof(struct IvfCandidate)); + if (!candidates) { sqlite3_free(queryQ); return SQLITE_NOMEM; } + + if (trained) { + // Find top nprobe centroids using quantized distance + int nlist = 0; + rc = ivf_ensure_stmt(p, &p->stmtIvfCentroidsAll[col_idx], + "SELECT centroid_id, centroid FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx); + if (rc != SQLITE_OK) { sqlite3_free(queryQ); sqlite3_free(candidates); return rc; } + sqlite3_stmt *stmt = p->stmtIvfCentroidsAll[col_idx]; + sqlite3_reset(stmt); + + int centroid_cap = 64; + struct IvfCentroidDist *cd = sqlite3_malloc64(centroid_cap * sizeof(*cd)); + if (!cd) { sqlite3_free(queryQ); sqlite3_free(candidates); return SQLITE_NOMEM; } + + while (sqlite3_step(stmt) == SQLITE_ROW) { + if (nlist >= centroid_cap) { + centroid_cap *= 2; + struct IvfCentroidDist *tmp = sqlite3_realloc64(cd, centroid_cap * sizeof(*cd)); + if (!tmp) { sqlite3_free(cd); sqlite3_free(queryQ); sqlite3_free(candidates); return SQLITE_NOMEM; } + cd = tmp; + } + cd[nlist].id = sqlite3_column_int(stmt, 0); + const void *c = sqlite3_column_blob(stmt, 1); + int cBytes = sqlite3_column_bytes(stmt, 1); + // Compare quantized query with quantized centroid + cd[nlist].dist = (c && cBytes == qvecSize) ? ivf_distance(p, col_idx, queryQ, c) : FLT_MAX; + nlist++; + } + + int actual_nprobe = nprobe < nlist ? nprobe : nlist; + for (int i = 0; i < actual_nprobe; i++) { + int min_j = i; + for (int j = i + 1; j < nlist; j++) { + if (cd[j].dist < cd[min_j].dist) min_j = j; + } + if (min_j != i) { struct IvfCentroidDist tmp = cd[i]; cd[i] = cd[min_j]; cd[min_j] = tmp; } + } + + // Scan probed cells + unassigned with quantized distance + { + sqlite3_str *s = sqlite3_str_new(NULL); + sqlite3_str_appendf(s, + "SELECT n_vectors, validity, rowids, vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME + " WHERE centroid_id IN (", + p->schemaName, p->tableName, col_idx); + for (int i = 0; i < actual_nprobe; i++) { + if (i > 0) sqlite3_str_appendall(s, ","); + sqlite3_str_appendf(s, "%d", cd[i].id); + } + sqlite3_str_appendf(s, ",%d)", VEC0_IVF_UNASSIGNED_CENTROID_ID); + char *zSql = sqlite3_str_finish(s); + if (!zSql) { sqlite3_free(cd); sqlite3_free(queryQ); sqlite3_free(candidates); return SQLITE_NOMEM; } + + sqlite3_stmt *stmtScan = NULL; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtScan, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { sqlite3_free(cd); sqlite3_free(queryQ); sqlite3_free(candidates); return rc; } + + rc = ivf_scan_cells_from_stmt(p, col_idx, stmtScan, queryQ, qvecSize, + &candidates, &nCandidates, &cap); + sqlite3_finalize(stmtScan); + if (rc != SQLITE_OK) { sqlite3_free(cd); sqlite3_free(queryQ); sqlite3_free(candidates); return rc; } + } + + sqlite3_free(cd); + } else { + // Flat mode: scan only unassigned cells + sqlite3_stmt *stmtScan = NULL; + char *zSql = sqlite3_mprintf( + "SELECT n_vectors, validity, rowids, vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME + " WHERE centroid_id = %d", + p->schemaName, p->tableName, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID); + if (!zSql) { sqlite3_free(queryQ); sqlite3_free(candidates); return SQLITE_NOMEM; } + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtScan, NULL); sqlite3_free(zSql); + if (rc == SQLITE_OK) { + rc = ivf_scan_cells_from_stmt(p, col_idx, stmtScan, queryQ, qvecSize, + &candidates, &nCandidates, &cap); + sqlite3_finalize(stmtScan); + if (rc != SQLITE_OK) { sqlite3_free(queryQ); sqlite3_free(candidates); return rc; } + } + } + + sqlite3_free(queryQ); + + // Sort candidates by quantized distance + qsort(candidates, nCandidates, sizeof(struct IvfCandidate), ivf_candidate_cmp); + + // Oversample re-ranking: re-score top (oversample*k) with full-precision vectors + if (oversample > 1 && quantizer != VEC0_IVF_QUANTIZER_NONE && nCandidates > 0) { + i64 rescore_n = collect_k < nCandidates ? collect_k : nCandidates; + sqlite3_stmt *stmtVec = NULL; + char *zSql = sqlite3_mprintf( + "SELECT vector FROM " VEC0_SHADOW_IVF_VECTORS_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, col_idx); + if (zSql) { + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtVec, NULL); sqlite3_free(zSql); + if (rc == SQLITE_OK) { + for (i64 i = 0; i < rescore_n; i++) { + sqlite3_reset(stmtVec); + sqlite3_bind_int64(stmtVec, 1, candidates[i].rowid); + if (sqlite3_step(stmtVec) == SQLITE_ROW) { + const float *fullVec = (const float *)sqlite3_column_blob(stmtVec, 0); + int fullVecBytes = sqlite3_column_bytes(stmtVec, 0); + if (fullVec && fullVecBytes == (int)p->vector_columns[col_idx].dimensions * (int)sizeof(float)) { + candidates[i].distance = ivf_distance_float(p, col_idx, + (const float *)queryVector, fullVec); + } + } + } + sqlite3_finalize(stmtVec); + } + } + // Re-sort after re-scoring + qsort(candidates, (size_t)rescore_n, sizeof(struct IvfCandidate), ivf_candidate_cmp); + nCandidates = (int)rescore_n; + } + + qsort(candidates, nCandidates, sizeof(struct IvfCandidate), ivf_candidate_cmp); + i64 nResults = nCandidates < k ? nCandidates : k; + + if (nResults == 0) { + knn_data->rowids = NULL; knn_data->distances = NULL; + knn_data->k = k; knn_data->k_used = 0; knn_data->current_idx = 0; + sqlite3_free(candidates); return SQLITE_OK; + } + + knn_data->rowids = sqlite3_malloc64(nResults * sizeof(i64)); + knn_data->distances = sqlite3_malloc64(nResults * sizeof(f32)); + if (!knn_data->rowids || !knn_data->distances) { + sqlite3_free(knn_data->rowids); sqlite3_free(knn_data->distances); + sqlite3_free(candidates); return SQLITE_NOMEM; + } + for (i64 i = 0; i < nResults; i++) { + knn_data->rowids[i] = candidates[i].rowid; + knn_data->distances[i] = candidates[i].distance; + } + knn_data->k = k; knn_data->k_used = nResults; knn_data->current_idx = 0; + sqlite3_free(candidates); + return SQLITE_OK; +} + +// ============================================================================ +// Command dispatch +// ============================================================================ + +static int ivf_handle_command(vec0_vtab *p, const char *command, + int argc, sqlite3_value **argv) { + UNUSED_PARAMETER(argc); + int col_idx = -1; + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF) { col_idx = i; break; } + } + if (col_idx < 0) return SQLITE_EMPTY; + + // nprobe=N — change nprobe at runtime without rebuilding + if (strncmp(command, "nprobe=", 7) == 0) { + int new_nprobe = atoi(command + 7); + if (new_nprobe < 1) { + vtab_set_error(&p->base, "nprobe must be >= 1"); + return SQLITE_ERROR; + } + p->vector_columns[col_idx].ivf.nprobe = new_nprobe; + return SQLITE_OK; + } + + if (strcmp(command, "compute-centroids") == 0) + return ivf_cmd_compute_centroids(p, col_idx, 0, VEC0_IVF_KMEANS_MAX_ITER, VEC0_IVF_KMEANS_DEFAULT_SEED); + + if (strncmp(command, "compute-centroids:", 18) == 0) { + const char *json = command + 18; + int nlist = 0, max_iter = VEC0_IVF_KMEANS_MAX_ITER; + uint32_t seed = VEC0_IVF_KMEANS_DEFAULT_SEED; + const char *pn = strstr(json, "\"nlist\":"); if (pn) nlist = atoi(pn + 8); + const char *pi = strstr(json, "\"max_iterations\":"); if (pi) max_iter = atoi(pi + 17); + const char *ps = strstr(json, "\"seed\":"); if (ps) seed = (uint32_t)atoi(ps + 7); + return ivf_cmd_compute_centroids(p, col_idx, nlist, max_iter, seed); + } + + if (strncmp(command, "set-centroid:", 13) == 0) { + int centroid_id = atoi(command + 13); + for (int i = 0; i < (int)(p->numVectorColumns + p->numPartitionColumns + + p->numAuxiliaryColumns + p->numMetadataColumns); i++) { + if (p->user_column_kinds[i] == SQLITE_VEC0_USER_COLUMN_KIND_VECTOR && + p->user_column_idxs[i] == col_idx) { + sqlite3_value *v = argv[2 + VEC0_COLUMN_USERN_START + i]; + if (sqlite3_value_type(v) == SQLITE_NULL) { vtab_set_error(&p->base, "set-centroid requires vector"); return SQLITE_ERROR; } + return ivf_cmd_set_centroid(p, col_idx, centroid_id, sqlite3_value_blob(v), sqlite3_value_bytes(v)); + } + } + return SQLITE_ERROR; + } + + if (strcmp(command, "assign-vectors") == 0) return ivf_cmd_assign_vectors(p, col_idx); + if (strcmp(command, "clear-centroids") == 0) return ivf_cmd_clear_centroids(p, col_idx); + return SQLITE_EMPTY; +} + +#endif /* SQLITE_VEC_IVF_C */ diff --git a/sqlite-vec.c b/sqlite-vec.c index 7079f7e..015792b 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -93,6 +93,10 @@ typedef size_t usize; #define COMPILER_SUPPORTS_VTAB_IN 1 #endif +#ifndef SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE +#define SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE 0 +#endif + #ifndef SQLITE_SUBTYPE #define SQLITE_SUBTYPE 0x000100000 #endif @@ -2539,6 +2543,7 @@ enum Vec0IndexType { #if SQLITE_VEC_ENABLE_RESCORE VEC0_INDEX_TYPE_RESCORE = 2, #endif + VEC0_INDEX_TYPE_IVF = 3, }; #if SQLITE_VEC_ENABLE_RESCORE @@ -2553,6 +2558,22 @@ struct Vec0RescoreConfig { }; #endif +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE +enum Vec0IvfQuantizer { + VEC0_IVF_QUANTIZER_NONE = 0, + VEC0_IVF_QUANTIZER_INT8 = 1, + VEC0_IVF_QUANTIZER_BINARY = 2, +}; + +struct Vec0IvfConfig { + int nlist; // number of centroids (0 = deferred) + int nprobe; // cells to probe at query time + int quantizer; // VEC0_IVF_QUANTIZER_NONE / INT8 / BINARY + int oversample; // >= 1 (1 = no oversampling) +}; +#else +struct Vec0IvfConfig { char _unused; }; +#endif struct VectorColumnDefinition { char *name; @@ -2564,6 +2585,7 @@ struct VectorColumnDefinition { #if SQLITE_VEC_ENABLE_RESCORE struct Vec0RescoreConfig rescore; #endif + struct Vec0IvfConfig ivf; }; struct Vec0PartitionColumnDefinition { @@ -2715,6 +2737,12 @@ static int vec0_parse_rescore_options(struct Vec0Scanner *scanner, * @return int SQLITE_OK on success, SQLITE_EMPTY is it's not a vector column * definition, SQLITE_ERROR on error. */ +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE +// Forward declaration — defined in sqlite-vec-ivf.c +static int vec0_parse_ivf_options(struct Vec0Scanner *scanner, + struct Vec0IvfConfig *config); +#endif + int vec0_parse_vector_column(const char *source, int source_length, struct VectorColumnDefinition *outColumn) { // parses a vector column definition like so: @@ -2733,6 +2761,8 @@ int vec0_parse_vector_column(const char *source, int source_length, struct Vec0RescoreConfig rescoreConfig; memset(&rescoreConfig, 0, sizeof(rescoreConfig)); #endif + struct Vec0IvfConfig ivfConfig; + memset(&ivfConfig, 0, sizeof(ivfConfig)); int dimensions; vec0_scanner_init(&scanner, source, source_length); @@ -2891,7 +2921,18 @@ int vec0_parse_vector_column(const char *source, int source_length, } } #endif - else { + else if (sqlite3_strnicmp(token.start, "ivf", indexNameLen) == 0) { +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + indexType = VEC0_INDEX_TYPE_IVF; + memset(&ivfConfig, 0, sizeof(ivfConfig)); + rc = vec0_parse_ivf_options(&scanner, &ivfConfig); + if (rc != SQLITE_OK) { + return SQLITE_ERROR; + } +#else + return SQLITE_ERROR; // IVF not compiled in +#endif + } else { // unknown index type return SQLITE_ERROR; } @@ -2914,6 +2955,7 @@ int vec0_parse_vector_column(const char *source, int source_length, #if SQLITE_VEC_ENABLE_RESCORE outColumn->rescore = rescoreConfig; #endif + outColumn->ivf = ivfConfig; return SQLITE_OK; } @@ -3279,6 +3321,18 @@ struct vec0_vtab { int chunk_size; +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + // IVF cached state per vector column + char *shadowIvfCellsNames[VEC0_MAX_VECTOR_COLUMNS]; // table name for blob_open + int ivfTrainedCache[VEC0_MAX_VECTOR_COLUMNS]; // -1=unknown, 0=no, 1=yes + sqlite3_stmt *stmtIvfCellMeta[VEC0_MAX_VECTOR_COLUMNS]; // SELECT n_vectors, length(validity)*8 FROM cells WHERE cell_id=? + sqlite3_stmt *stmtIvfCellUpdateN[VEC0_MAX_VECTOR_COLUMNS]; // UPDATE cells SET n_vectors=n_vectors+? WHERE cell_id=? + sqlite3_stmt *stmtIvfRowidMapInsert[VEC0_MAX_VECTOR_COLUMNS]; // INSERT INTO rowid_map(rowid,cell_id,slot) VALUES(?,?,?) + sqlite3_stmt *stmtIvfRowidMapLookup[VEC0_MAX_VECTOR_COLUMNS]; // SELECT cell_id,slot FROM rowid_map WHERE rowid=? + sqlite3_stmt *stmtIvfRowidMapDelete[VEC0_MAX_VECTOR_COLUMNS]; // DELETE FROM rowid_map WHERE rowid=? + sqlite3_stmt *stmtIvfCentroidsAll[VEC0_MAX_VECTOR_COLUMNS]; // SELECT centroid_id,centroid FROM centroids +#endif + // select latest chunk from _chunks, getting chunk_id sqlite3_stmt *stmtLatestChunk; @@ -3364,6 +3418,17 @@ void vec0_free_resources(vec0_vtab *p) { p->stmtRowidsUpdatePosition = NULL; sqlite3_finalize(p->stmtRowidsGetChunkPosition); p->stmtRowidsGetChunkPosition = NULL; + +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + for (int i = 0; i < VEC0_MAX_VECTOR_COLUMNS; i++) { + sqlite3_finalize(p->stmtIvfCellMeta[i]); p->stmtIvfCellMeta[i] = NULL; + sqlite3_finalize(p->stmtIvfCellUpdateN[i]); p->stmtIvfCellUpdateN[i] = NULL; + sqlite3_finalize(p->stmtIvfRowidMapInsert[i]); p->stmtIvfRowidMapInsert[i] = NULL; + sqlite3_finalize(p->stmtIvfRowidMapLookup[i]); p->stmtIvfRowidMapLookup[i] = NULL; + sqlite3_finalize(p->stmtIvfRowidMapDelete[i]); p->stmtIvfRowidMapDelete[i] = NULL; + sqlite3_finalize(p->stmtIvfCentroidsAll[i]); p->stmtIvfCentroidsAll[i] = NULL; + } +#endif } /** @@ -3386,6 +3451,10 @@ void vec0_free(vec0_vtab *p) { for (int i = 0; i < p->numVectorColumns; i++) { sqlite3_free(p->shadowVectorChunksNames[i]); p->shadowVectorChunksNames[i] = NULL; +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + sqlite3_free(p->shadowIvfCellsNames[i]); + p->shadowIvfCellsNames[i] = NULL; +#endif #if SQLITE_VEC_ENABLE_RESCORE sqlite3_free(p->shadowRescoreChunksNames[i]); @@ -3674,12 +3743,25 @@ int vec0_result_id(vec0_vtab *p, sqlite3_context *context, i64 rowid) { * will be stored. * @return int SQLITE_OK on success. */ +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE +// Forward declaration — defined in sqlite-vec-ivf.c (included later) +static int ivf_get_vector_data(vec0_vtab *p, i64 rowid, int col_idx, + void **outVector, int *outVectorSize); +#endif + int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx, void **outVector, int *outVectorSize) { vec0_vtab *p = pVtab; int rc, brc; i64 chunk_id; i64 chunk_offset; + +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + // IVF-indexed columns store vectors in _ivf_cells, not _vector_chunks + if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_IVF) { + return ivf_get_vector_data(p, rowid, vector_column_idx, outVector, outVectorSize); + } +#endif size_t size; void *buf = NULL; int blobOffset; @@ -4327,8 +4409,12 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_value ** partitionKeyValues, i64 *chunk int vector_column_idx = p->user_column_idxs[i]; #if SQLITE_VEC_ENABLE_RESCORE - // Rescore columns don't use _vector_chunks for float storage - if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE) { + // Rescore and IVF columns don't use _vector_chunks for float storage + if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + || p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_IVF +#endif + ) { continue; } #endif @@ -4500,6 +4586,12 @@ void vec0_cursor_clear(vec0_cursor *pCur) { } } +// IVF index implementation — #include'd here after all struct/helper definitions +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE +#include "sqlite-vec-ivf-kmeans.c" +#include "sqlite-vec-ivf.c" +#endif + #define VEC_CONSTRUCTOR_ERROR "vec0 constructor error: " static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_vtab **ppVtab, char **pzErr, bool isCreate) { @@ -4761,6 +4853,34 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } #endif + // IVF indexes do not support auxiliary, metadata, or partition key columns. + { + int has_ivf = 0; + for (int i = 0; i < numVectorColumns; i++) { + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF) { + has_ivf = 1; + break; + } + } + if (has_ivf) { + if (numPartitionColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "partition key columns are not supported with IVF indexes"); + goto error; + } + if (numAuxiliaryColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "auxiliary columns are not supported with IVF indexes"); + goto error; + } + if (numMetadataColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "metadata columns are not supported with IVF indexes"); + goto error; + } + } + } + sqlite3_str *createStr = sqlite3_str_new(NULL); sqlite3_str_appendall(createStr, "CREATE TABLE x("); if (pkColumnName) { @@ -4866,6 +4986,15 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } #endif } +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + for (int i = 0; i < pNew->numVectorColumns; i++) { + if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue; + pNew->shadowIvfCellsNames[i] = + sqlite3_mprintf("%s_ivf_cells%02d", tableName, i); + if (!pNew->shadowIvfCellsNames[i]) goto error; + pNew->ivfTrainedCache[i] = -1; // unknown + } +#endif for (int i = 0; i < pNew->numMetadataColumns; i++) { pNew->shadowMetadataChunksNames[i] = sqlite3_mprintf("%s_metadatachunks%02d", tableName, i); @@ -4989,8 +5118,8 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, for (int i = 0; i < pNew->numVectorColumns; i++) { #if SQLITE_VEC_ENABLE_RESCORE - // Rescore columns don't use _vector_chunks - if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + // Rescore and IVF columns don't use _vector_chunks + if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT) continue; #endif char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE, @@ -5018,6 +5147,18 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } #endif +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + // Create IVF shadow tables for IVF-indexed vector columns + for (int i = 0; i < pNew->numVectorColumns; i++) { + if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue; + rc = ivf_create_shadow_tables(pNew, i); + if (rc != SQLITE_OK) { + *pzErr = sqlite3_mprintf("Could not create IVF shadow tables for column %d", i); + goto error; + } + } +#endif + // See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY" // without INTEGER type issue applies here. for (int i = 0; i < pNew->numMetadataColumns; i++) { @@ -5153,7 +5294,7 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { for (int i = 0; i < p->numVectorColumns; i++) { #if SQLITE_VEC_ENABLE_RESCORE - if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT) continue; #endif zSql = sqlite3_mprintf("DROP TABLE \"%w\".\"%w\"", p->schemaName, @@ -5174,6 +5315,14 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { } #endif +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + // Drop IVF shadow tables + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue; + ivf_drop_shadow_tables(p, i); + } +#endif + if(p->numAuxiliaryColumns > 0) { zSql = sqlite3_mprintf("DROP TABLE " VEC0_SHADOW_AUXILIARY_NAME, p->schemaName, p->tableName); rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); @@ -7186,6 +7335,21 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, } #endif +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + // IVF dispatch: if vector column has IVF, use IVF query instead of chunk scan + if (vector_column->index_type == VEC0_INDEX_TYPE_IVF) { + rc = ivf_query_knn(p, vectorColumnIdx, queryVector, + (int)vector_column_byte_size(*vector_column), k, knn_data); + if (rc != SQLITE_OK) { + goto cleanup; + } + pCur->knn_data = knn_data; + pCur->query_plan = VEC0_QUERY_PLAN_KNN; + rc = SQLITE_OK; + goto cleanup; + } +#endif + rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks); if (rc != SQLITE_OK) { // IMP: V06942_23781 @@ -8011,8 +8175,12 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid, // Go insert the vector data into the vector chunk shadow tables for (int i = 0; i < p->numVectorColumns; i++) { #if SQLITE_VEC_ENABLE_RESCORE - // Rescore columns store float vectors in _rescore_vectors instead - if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + // Rescore and IVF columns don't use _vector_chunks + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + || p->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF +#endif + ) continue; #endif @@ -8425,6 +8593,18 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, } #endif +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + // Step #4: IVF index insert (if any vector column uses IVF) + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue; + int vecSize = (int)vector_column_byte_size(p->vector_columns[i]); + rc = ivf_insert(p, i, rowid, vectorDatas[i], vecSize); + if (rc != SQLITE_OK) { + goto cleanup; + } + } +#endif + if(p->numAuxiliaryColumns > 0) { sqlite3_stmt *stmt; sqlite3_str * s = sqlite3_str_new(NULL); @@ -8616,8 +8796,8 @@ int vec0Update_Delete_ClearVectors(vec0_vtab *p, i64 chunk_id, int rc, brc; for (int i = 0; i < p->numVectorColumns; i++) { #if SQLITE_VEC_ENABLE_RESCORE - // Rescore columns don't use _vector_chunks - if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + // Non-FLAT columns don't use _vector_chunks + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT) continue; #endif sqlite3_blob *blobVectors = NULL; @@ -8732,7 +8912,7 @@ int vec0Update_Delete_DeleteChunkIfEmpty(vec0_vtab *p, i64 chunk_id, // Delete from each _vector_chunksNN for (int i = 0; i < p->numVectorColumns; i++) { #if SQLITE_VEC_ENABLE_RESCORE - if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT) continue; #endif zSql = sqlite3_mprintf( @@ -9009,6 +9189,15 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) { } } +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + // 7. delete from IVF index + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue; + rc = ivf_delete(p, i, rowid); + if (rc != SQLITE_OK) return rc; + } +#endif + return SQLITE_OK; } @@ -9284,6 +9473,18 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, } // INSERT operation else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) { +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE + // Check for IVF command inserts: INSERT INTO t(rowid) VALUES ('compute-centroids') + // The id column holds the command string. + sqlite3_value *idVal = argv[2 + VEC0_COLUMN_ID]; + if (sqlite3_value_type(idVal) == SQLITE_TEXT) { + const char *cmd = (const char *)sqlite3_value_text(idVal); + vec0_vtab *p = (vec0_vtab *)pVTab; + int cmdRc = ivf_handle_command(p, cmd, argc, argv); + if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error) + // SQLITE_EMPTY means not an IVF command — fall through to normal insert + } +#endif return vec0Update_Insert(pVTab, argc, argv, pRowid); } // UPDATE operation @@ -9431,9 +9632,15 @@ static sqlite3_module vec0Module = { #define SQLITE_VEC_DEBUG_BUILD_RESCORE "" #endif +#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE +#define SQLITE_VEC_DEBUG_BUILD_IVF "ivf" +#else +#define SQLITE_VEC_DEBUG_BUILD_IVF "" +#endif + #define SQLITE_VEC_DEBUG_BUILD \ SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \ - SQLITE_VEC_DEBUG_BUILD_RESCORE + SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF #define SQLITE_VEC_DEBUG_STRING \ "Version: " SQLITE_VEC_VERSION "\n" \ diff --git a/tests/conftest.py b/tests/conftest.py index 9549d37..3a24468 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,29 @@ import pytest import sqlite3 +import os + + +def _vec_debug(): + db = sqlite3.connect(":memory:") + db.enable_load_extension(True) + db.load_extension("dist/vec0") + db.enable_load_extension(False) + return db.execute("SELECT vec_debug()").fetchone()[0] + + +def _has_build_flag(flag): + return flag in _vec_debug().split("Build flags:")[-1] + + +def pytest_collection_modifyitems(config, items): + has_ivf = _has_build_flag("ivf") + if has_ivf: + return + skip_ivf = pytest.mark.skip(reason="IVF not enabled (compile with -DSQLITE_VEC_EXPERIMENTAL_IVF_ENABLE=1)") + ivf_prefixes = ("test-ivf",) + for item in items: + if any(item.fspath.basename.startswith(p) for p in ivf_prefixes): + item.add_marker(skip_ivf) @pytest.fixture() diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile index 0030c2e..a3405a4 100644 --- a/tests/fuzz/Makefile +++ b/tests/fuzz/Makefile @@ -93,13 +93,40 @@ $(TARGET_DIR)/rescore_quantize_edge: rescore-quantize-edge.c $(FUZZ_SRCS) | $(TA $(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR) $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ +$(TARGET_DIR)/ivf_create: ivf-create.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/ivf_operations: ivf-operations.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/ivf_quantize: ivf-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/ivf_kmeans: ivf-kmeans.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/ivf_shadow_corrupt: ivf-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/ivf_knn_deep: ivf-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/ivf_cell_overflow: ivf-cell-overflow.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + FUZZ_TARGETS = vec0_create exec json numpy \ shadow_corrupt vec0_operations scalar_functions \ vec0_create_full metadata_columns vec_each vec_mismatch \ vec0_delete_completeness \ rescore_operations rescore_create rescore_quantize \ rescore_shadow_corrupt rescore_knn_deep \ - rescore_quantize_edge rescore_interleave + rescore_quantize_edge rescore_interleave \ + ivf_create ivf_operations \ + ivf_quantize ivf_kmeans ivf_shadow_corrupt \ + ivf_knn_deep ivf_cell_overflow ivf_rescore all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS)) diff --git a/tests/fuzz/ivf-cell-overflow.c b/tests/fuzz/ivf-cell-overflow.c new file mode 100644 index 0000000..4b18ba2 --- /dev/null +++ b/tests/fuzz/ivf-cell-overflow.c @@ -0,0 +1,192 @@ +/** + * Fuzz target: IVF cell overflow and boundary conditions. + * + * Pushes cells past VEC0_IVF_CELL_MAX_VECTORS (64) to trigger cell + * splitting, then exercises blob I/O at slot boundaries. + * + * Targets: + * - Cell splitting when n_vectors reaches cap (64) + * - Blob offset arithmetic: slot * vecSize, slot / 8, slot % 8 + * - Validity bitmap at byte boundaries (slot 7->8, 15->16, etc.) + * - Insert into full cell -> create new cell path + * - Delete from various slot positions (first, last, middle) + * - Multiple cells per centroid + * - assign-vectors command with multi-cell centroids + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 8) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + // Use small dimensions for speed but enough vectors to overflow cells + int dim = (data[0] % 8) + 2; // 2..9 + int nlist = (data[1] % 4) + 1; // 1..4 + // We need >64 vectors to overflow a cell + int num_vecs = (data[2] % 64) + 65; // 65..128 + int delete_pattern = data[3]; // Controls which vectors to delete + + const uint8_t *payload = data + 4; + size_t payload_size = size - 4; + + char sql[256]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))", + dim, nlist, nlist); + + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + // Insert enough vectors to overflow at least one cell + sqlite3_stmt *stmtInsert = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + if (!stmtInsert) { sqlite3_close(db); return 0; } + + size_t offset = 0; + for (int i = 0; i < num_vecs; i++) { + float *vec = sqlite3_malloc(dim * sizeof(float)); + if (!vec) break; + for (int d = 0; d < dim; d++) { + if (offset < payload_size) { + vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f; + } else { + // Cluster vectors near specific centroids to ensure some cells overflow + int cluster = i % nlist; + vec[d] = (float)cluster + (float)(i % 10) * 0.01f + d * 0.001f; + } + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1)); + sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + sqlite3_free(vec); + } + sqlite3_finalize(stmtInsert); + + // Train to assign vectors to centroids (triggers cell building) + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + + // Delete vectors at boundary positions based on fuzz data + // This tests validity bitmap manipulation at different slot positions + for (int i = 0; i < num_vecs; i++) { + int byte_idx = i / 8; + if (byte_idx < (int)payload_size && (payload[byte_idx] & (1 << (i % 8)))) { + // Use delete_pattern to thin deletions + if ((delete_pattern + i) % 3 == 0) { + char delsql[64]; + snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i + 1); + sqlite3_exec(db, delsql, NULL, NULL, NULL); + } + } + } + + // Insert more vectors after deletions (into cells with holes) + { + sqlite3_stmt *si = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL); + if (si) { + for (int i = 0; i < 10; i++) { + float *vec = sqlite3_malloc(dim * sizeof(float)); + if (!vec) break; + for (int d = 0; d < dim; d++) + vec[d] = (float)(i + 200) * 0.01f; + sqlite3_reset(si); + sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1)); + sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(si); + sqlite3_free(vec); + } + sqlite3_finalize(si); + } + } + + // KNN query that must scan multiple cells per centroid + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) qvec[d] = 0.0f; + sqlite3_stmt *sk = NULL; + snprintf(sql, sizeof(sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 20"); + sqlite3_prepare_v2(db, sql, -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + sqlite3_free(qvec); + } + } + + // Test assign-vectors with multi-cell state + // First clear centroids + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('clear-centroids')", + NULL, NULL, NULL); + + // Set centroids manually, then assign + for (int c = 0; c < nlist; c++) { + float *cvec = sqlite3_malloc(dim * sizeof(float)); + if (!cvec) break; + for (int d = 0; d < dim; d++) cvec[d] = (float)c + d * 0.1f; + + char cmd[128]; + snprintf(cmd, sizeof(cmd), + "INSERT INTO v(rowid, emb) VALUES ('set-centroid:%d', ?)", c); + sqlite3_stmt *sc = NULL; + sqlite3_prepare_v2(db, cmd, -1, &sc, NULL); + if (sc) { + sqlite3_bind_blob(sc, 1, cvec, dim * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(sc); + sqlite3_finalize(sc); + } + sqlite3_free(cvec); + } + + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('assign-vectors')", + NULL, NULL, NULL); + + // Final query after assign-vectors + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) qvec[d] = 1.0f; + sqlite3_stmt *sk = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5", + -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + sqlite3_free(qvec); + } + } + + // Full scan + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/ivf-create.c b/tests/fuzz/ivf-create.c new file mode 100644 index 0000000..222b67b --- /dev/null +++ b/tests/fuzz/ivf-create.c @@ -0,0 +1,36 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + int rc; + sqlite3 *db; + sqlite3_stmt *stmt; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + sqlite3_str *s = sqlite3_str_new(NULL); + assert(s); + sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf("); + sqlite3_str_appendf(s, "%.*s", (int)size, data); + sqlite3_str_appendall(s, "))"); + const char *zSql = sqlite3_str_finish(s); + assert(zSql); + + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL); + sqlite3_free((void *)zSql); + if (rc == SQLITE_OK) { + sqlite3_step(stmt); + } + sqlite3_finalize(stmt); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/ivf-create.dict b/tests/fuzz/ivf-create.dict new file mode 100644 index 0000000..9a014e7 --- /dev/null +++ b/tests/fuzz/ivf-create.dict @@ -0,0 +1,16 @@ +"nlist" +"nprobe" +"quantizer" +"oversample" +"binary" +"int8" +"none" +"=" +"," +"(" +")" +"0" +"1" +"128" +"65536" +"65537" diff --git a/tests/fuzz/ivf-kmeans.c b/tests/fuzz/ivf-kmeans.c new file mode 100644 index 0000000..46804d0 --- /dev/null +++ b/tests/fuzz/ivf-kmeans.c @@ -0,0 +1,180 @@ +/** + * Fuzz target: IVF k-means clustering. + * + * Builds a table, inserts fuzz-controlled vectors, then runs + * compute-centroids with fuzz-controlled parameters (nlist, max_iter, seed). + * Targets: + * - kmeans with N < k (clamping), N == 1, k == 1 + * - kmeans with duplicate/identical vectors (all distances zero) + * - kmeans with NaN/Inf vectors + * - Empty cluster reassignment path (farthest-point heuristic) + * - Large nlist relative to N + * - The compute-centroids:{json} command parsing + * - clear-centroids followed by compute-centroids (round-trip) + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 10) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + // Parse fuzz header + // Byte 0-1: dimension (1..128) + // Byte 2: nlist for CREATE (1..64) + // Byte 3: nlist override for compute-centroids (0 = use default) + // Byte 4: max_iter (1..50) + // Byte 5-8: seed + // Byte 9: num_vectors (1..64) + // Remaining: vector float data + + int dim = (data[0] | (data[1] << 8)) % 128 + 1; + int nlist_create = (data[2] % 64) + 1; + int nlist_override = data[3] % 65; // 0 means use table default + int max_iter = (data[4] % 50) + 1; + uint32_t seed = (uint32_t)data[5] | ((uint32_t)data[6] << 8) | + ((uint32_t)data[7] << 16) | ((uint32_t)data[8] << 24); + int num_vecs = (data[9] % 64) + 1; + + const uint8_t *payload = data + 10; + size_t payload_size = size - 10; + + char sql[256]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))", + dim, nlist_create, nlist_create); + + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + // Insert vectors + sqlite3_stmt *stmtInsert = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + if (!stmtInsert) { sqlite3_close(db); return 0; } + + size_t offset = 0; + for (int i = 0; i < num_vecs; i++) { + float *vec = sqlite3_malloc(dim * sizeof(float)); + if (!vec) break; + + for (int d = 0; d < dim; d++) { + if (offset + 4 <= payload_size) { + memcpy(&vec[d], payload + offset, sizeof(float)); + offset += 4; + } else if (offset < payload_size) { + // Scale to interesting range including values > 1, < -1 + vec[d] = ((float)(int8_t)payload[offset++]) / 5.0f; + } else { + // Reuse earlier bytes to fill remaining dimensions + vec[d] = (float)(i * dim + d) * 0.01f; + } + } + + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1)); + sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + sqlite3_free(vec); + } + sqlite3_finalize(stmtInsert); + + // Exercise compute-centroids with JSON options + { + char cmd[256]; + snprintf(cmd, sizeof(cmd), + "INSERT INTO v(rowid) VALUES " + "('compute-centroids:{\"nlist\":%d,\"max_iterations\":%d,\"seed\":%u}')", + nlist_override, max_iter, seed); + sqlite3_exec(db, cmd, NULL, NULL, NULL); + } + + // KNN query after training + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) { + qvec[d] = (d < 3) ? 1.0f : 0.0f; + } + sqlite3_stmt *stmtKnn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5", + -1, &stmtKnn, NULL); + if (stmtKnn) { + sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + sqlite3_finalize(stmtKnn); + } + sqlite3_free(qvec); + } + } + + // Clear centroids and re-compute to test round-trip + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('clear-centroids')", + NULL, NULL, NULL); + + // Insert a few more vectors in untrained state + { + sqlite3_stmt *si = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL); + if (si) { + for (int i = 0; i < 3; i++) { + float *vec = sqlite3_malloc(dim * sizeof(float)); + if (!vec) break; + for (int d = 0; d < dim; d++) vec[d] = (float)(i + 100) * 0.1f; + sqlite3_reset(si); + sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1)); + sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(si); + sqlite3_free(vec); + } + sqlite3_finalize(si); + } + } + + // Re-train + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + + // Delete some rows after training, then query + sqlite3_exec(db, "DELETE FROM v WHERE rowid = 1", NULL, NULL, NULL); + sqlite3_exec(db, "DELETE FROM v WHERE rowid = 2", NULL, NULL, NULL); + + // Query after deletes + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) qvec[d] = 0.5f; + sqlite3_stmt *stmtKnn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 10", + -1, &stmtKnn, NULL); + if (stmtKnn) { + sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + sqlite3_finalize(stmtKnn); + } + sqlite3_free(qvec); + } + } + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/ivf-knn-deep.c b/tests/fuzz/ivf-knn-deep.c new file mode 100644 index 0000000..27d19a1 --- /dev/null +++ b/tests/fuzz/ivf-knn-deep.c @@ -0,0 +1,199 @@ +/** + * Fuzz target: IVF KNN search deep paths. + * + * Exercises the full KNN pipeline with fuzz-controlled: + * - nprobe values (including > nlist, =1, =nlist) + * - Query vectors (including adversarial floats) + * - Mix of trained/untrained state + * - Oversample + rescore path (quantizer=int8 with oversample>1) + * - Multiple interleaved KNN queries + * - Candidate array realloc path (many vectors in probed cells) + * + * Targets: + * - ivf_scan_cells_from_stmt: candidate realloc, distance computation + * - ivf_query_knn: centroid sorting, nprobe selection + * - Oversample rescore: re-ranking with full-precision vectors + * - qsort with NaN distances + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +static uint16_t read_u16(const uint8_t *p) { + return (uint16_t)(p[0] | (p[1] << 8)); +} + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 16) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + // Header + int dim = (data[0] % 32) + 2; // 2..33 + int nlist = (data[1] % 16) + 1; // 1..16 + int nprobe_initial = (data[2] % 20) + 1; // 1..20 (can be > nlist) + int quantizer_type = data[3] % 3; // 0=none, 1=int8, 2=binary + int oversample = (data[4] % 4) + 1; // 1..4 + int num_vecs = (data[5] % 80) + 4; // 4..83 + int num_queries = (data[6] % 8) + 1; // 1..8 + int k_limit = (data[7] % 20) + 1; // 1..20 + + const uint8_t *payload = data + 8; + size_t payload_size = size - 8; + + // For binary quantizer, dimension must be multiple of 8 + if (quantizer_type == 2) { + dim = ((dim + 7) / 8) * 8; + if (dim == 0) dim = 8; + } + + const char *qname; + switch (quantizer_type) { + case 1: qname = "int8"; break; + case 2: qname = "binary"; break; + default: qname = "none"; break; + } + + // Oversample only valid with quantization + if (quantizer_type == 0) oversample = 1; + + // Cap nprobe to nlist for CREATE (parser rejects nprobe > nlist) + int nprobe_create = nprobe_initial <= nlist ? nprobe_initial : nlist; + + char sql[512]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s%s))", + dim, nlist, nprobe_create, qname, + oversample > 1 ? ", oversample=2" : ""); + + // If that fails (e.g. oversample with none), try without oversample + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))", + dim, nlist, nprobe_create, qname); + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + } + + // Insert vectors + sqlite3_stmt *stmtInsert = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + if (!stmtInsert) { sqlite3_close(db); return 0; } + + size_t offset = 0; + for (int i = 0; i < num_vecs; i++) { + float *vec = sqlite3_malloc(dim * sizeof(float)); + if (!vec) break; + for (int d = 0; d < dim; d++) { + if (offset < payload_size) { + vec[d] = ((float)(int8_t)payload[offset++]) / 20.0f; + } else { + vec[d] = (float)((i * dim + d) % 256 - 128) / 128.0f; + } + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1)); + sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + sqlite3_free(vec); + } + sqlite3_finalize(stmtInsert); + + // Query BEFORE training (flat scan path) + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) qvec[d] = 0.5f; + sqlite3_stmt *sk = NULL; + snprintf(sql, sizeof(sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit); + sqlite3_prepare_v2(db, sql, -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + sqlite3_free(qvec); + } + } + + // Train + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + + // Change nprobe at runtime (can exceed nlist -- tests clamping in query) + { + char cmd[64]; + snprintf(cmd, sizeof(cmd), + "INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe_initial); + sqlite3_exec(db, cmd, NULL, NULL, NULL); + } + + // Multiple KNN queries with different fuzz-derived query vectors + for (int q = 0; q < num_queries; q++) { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (!qvec) break; + for (int d = 0; d < dim; d++) { + if (offset < payload_size) { + qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f; + } else { + qvec[d] = (q == 0) ? 1.0f : 0.0f; + } + } + + sqlite3_stmt *sk = NULL; + snprintf(sql, sizeof(sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit); + sqlite3_prepare_v2(db, sql, -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + sqlite3_free(qvec); + } + + // Delete half the vectors then query again + for (int i = 1; i <= num_vecs / 2; i++) { + char delsql[64]; + snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i); + sqlite3_exec(db, delsql, NULL, NULL, NULL); + } + + // Query after mass deletion + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) qvec[d] = -0.5f; + sqlite3_stmt *sk = NULL; + snprintf(sql, sizeof(sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit); + sqlite3_prepare_v2(db, sql, -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + sqlite3_free(qvec); + } + } + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/ivf-operations.c b/tests/fuzz/ivf-operations.c new file mode 100644 index 0000000..a955870 --- /dev/null +++ b/tests/fuzz/ivf-operations.c @@ -0,0 +1,121 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 6) return 0; + + int rc; + sqlite3 *db; + sqlite3_stmt *stmtInsert = NULL; + sqlite3_stmt *stmtDelete = NULL; + sqlite3_stmt *stmtKnn = NULL; + sqlite3_stmt *stmtScan = NULL; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + rc = sqlite3_exec(db, + "CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(nlist=4, nprobe=4))", + NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 3", + -1, &stmtKnn, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid FROM v", -1, &stmtScan, NULL); + + if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup; + + size_t i = 0; + while (i + 2 <= size) { + uint8_t op = data[i++] % 7; + uint8_t rowid_byte = data[i++]; + int64_t rowid = (int64_t)(rowid_byte % 32) + 1; + + switch (op) { + case 0: { + // INSERT: consume 16 bytes for 4 floats, or use what's left + float vec[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (int j = 0; j < 4 && i < size; j++, i++) { + vec[j] = (float)((int8_t)data[i]) / 10.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 1: { + // DELETE + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 2: { + // KNN query with a fixed query vector + float qvec[4] = {1.0f, 0.0f, 0.0f, 0.0f}; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 3: { + // Full scan + sqlite3_reset(stmtScan); + while (sqlite3_step(stmtScan) == SQLITE_ROW) {} + break; + } + case 4: { + // compute-centroids command + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + break; + } + case 5: { + // clear-centroids command + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('clear-centroids')", + NULL, NULL, NULL); + break; + } + case 6: { + // nprobe=N command + if (i < size) { + uint8_t n = data[i++]; + int nprobe = (n % 4) + 1; + char buf[64]; + snprintf(buf, sizeof(buf), + "INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe); + sqlite3_exec(db, buf, NULL, NULL, NULL); + } + break; + } + } + } + + // Final operations — must not crash regardless of prior state + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_finalize(stmtScan); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/ivf-quantize.c b/tests/fuzz/ivf-quantize.c new file mode 100644 index 0000000..22149ee --- /dev/null +++ b/tests/fuzz/ivf-quantize.c @@ -0,0 +1,129 @@ +/** + * Fuzz target: IVF quantization functions. + * + * Directly exercises ivf_quantize_int8 and ivf_quantize_binary with + * fuzz-controlled dimensions and float data. Targets: + * - ivf_quantize_int8: clamping, int8 overflow boundary + * - ivf_quantize_binary: D not divisible by 8, memset(D/8) undercount + * - Round-trip through CREATE TABLE + INSERT with quantized IVF + */ +#include +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 8) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + // Byte 0: quantizer type (0=int8, 1=binary) + // Byte 1: dimension (1..64, but we test edge cases) + // Byte 2: nlist (1..8) + // Byte 3: num_vectors to insert (1..32) + // Remaining: float data + int qtype = data[0] % 2; + int dim = (data[1] % 64) + 1; + int nlist = (data[2] % 8) + 1; + int num_vecs = (data[3] % 32) + 1; + const uint8_t *payload = data + 4; + size_t payload_size = size - 4; + + // For binary quantizer, D must be multiple of 8 to avoid the D/8 bug + // in production. But we explicitly want to test non-multiples too to + // find the bug. Use dim as-is. + const char *quantizer = qtype ? "binary" : "int8"; + + // Binary quantizer needs D multiple of 8 in current code, but let's + // test both valid and invalid dimensions to see what happens. + // For binary with non-multiple-of-8, the code does memset(dst, 0, D/8) + // which underallocates when D%8 != 0. + char sql[256]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))", + dim, nlist, nlist, quantizer); + + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + // Insert vectors with fuzz-controlled float values + sqlite3_stmt *stmtInsert = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + if (!stmtInsert) { sqlite3_close(db); return 0; } + + size_t offset = 0; + for (int i = 0; i < num_vecs && offset < payload_size; i++) { + // Build float vector from fuzz data + float *vec = sqlite3_malloc(dim * sizeof(float)); + if (!vec) break; + + for (int d = 0; d < dim; d++) { + if (offset + 4 <= payload_size) { + // Use raw bytes as float -- can produce NaN, Inf, denormals + memcpy(&vec[d], payload + offset, sizeof(float)); + offset += 4; + } else if (offset < payload_size) { + // Partial: use byte as scaled value + vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f; + } else { + vec[d] = 0.0f; + } + } + + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1)); + sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + sqlite3_free(vec); + } + sqlite3_finalize(stmtInsert); + + // Trigger compute-centroids to exercise kmeans + quantization together + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + + // KNN query with fuzz-derived query vector + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) { + if (offset < payload_size) { + qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f; + } else { + qvec[d] = 1.0f; + } + } + + sqlite3_stmt *stmtKnn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5", + -1, &stmtKnn, NULL); + if (stmtKnn) { + sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + sqlite3_finalize(stmtKnn); + } + sqlite3_free(qvec); + } + } + + // Full scan + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/ivf-rescore.c b/tests/fuzz/ivf-rescore.c new file mode 100644 index 0000000..1c3f34a --- /dev/null +++ b/tests/fuzz/ivf-rescore.c @@ -0,0 +1,182 @@ +/** + * Fuzz target: IVF oversample + rescore path. + * + * Specifically targets the code path where quantizer != none AND + * oversample > 1, which triggers: + * 1. Quantized KNN scan to collect oversample*k candidates + * 2. Full-precision vector lookup from _ivf_vectors table + * 3. Re-scoring with float32 distances + * 4. Re-sort and truncation + * + * This path has the most complex memory management in the KNN query: + * - Two separate distance computations (quantized + float) + * - Cross-table lookups (cells + vectors KV store) + * - Candidate array resizing + * - qsort over partially re-scored arrays + * + * Also tests the int8 + binary quantization round-trip fidelity + * under adversarial float inputs. + */ +#include +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 12) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + // Header + int quantizer_type = (data[0] % 2) + 1; // 1=int8, 2=binary (never none) + int dim = (data[1] % 32) + 8; // 8..39 + int nlist = (data[2] % 8) + 1; // 1..8 + int oversample = (data[3] % 4) + 2; // 2..5 (always > 1) + int num_vecs = (data[4] % 60) + 8; // 8..67 + int k_limit = (data[5] % 15) + 1; // 1..15 + + const uint8_t *payload = data + 6; + size_t payload_size = size - 6; + + // Binary quantizer needs D multiple of 8 + if (quantizer_type == 2) { + dim = ((dim + 7) / 8) * 8; + } + + const char *qname = (quantizer_type == 1) ? "int8" : "binary"; + + char sql[512]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s, oversample=%d))", + dim, nlist, nlist, qname, oversample); + + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + // Insert vectors with diverse values + sqlite3_stmt *stmtInsert = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + if (!stmtInsert) { sqlite3_close(db); return 0; } + + size_t offset = 0; + for (int i = 0; i < num_vecs; i++) { + float *vec = sqlite3_malloc(dim * sizeof(float)); + if (!vec) break; + for (int d = 0; d < dim; d++) { + if (offset + 4 <= payload_size) { + // Use raw bytes as float for adversarial values + memcpy(&vec[d], payload + offset, sizeof(float)); + offset += 4; + // Sanitize: replace NaN/Inf with bounded values to avoid + // poisoning the entire computation. We want edge values, + // not complete nonsense. + if (isnan(vec[d]) || isinf(vec[d])) { + vec[d] = (vec[d] > 0) ? 1e6f : -1e6f; + if (isnan(vec[d])) vec[d] = 0.0f; + } + } else if (offset < payload_size) { + vec[d] = ((float)(int8_t)payload[offset++]) / 30.0f; + } else { + vec[d] = (float)(i * dim + d) * 0.001f; + } + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1)); + sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + sqlite3_free(vec); + } + sqlite3_finalize(stmtInsert); + + // Train + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + + // Multiple KNN queries to exercise rescore path + for (int q = 0; q < 4; q++) { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (!qvec) break; + for (int d = 0; d < dim; d++) { + if (offset < payload_size) { + qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f; + } else { + qvec[d] = (q == 0) ? 1.0f : (q == 1) ? -1.0f : 0.0f; + } + } + + sqlite3_stmt *sk = NULL; + snprintf(sql, sizeof(sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit); + sqlite3_prepare_v2(db, sql, -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + sqlite3_free(qvec); + } + + // Delete some vectors, then query again (rescore with missing _ivf_vectors rows) + for (int i = 1; i <= num_vecs / 3; i++) { + char delsql[64]; + snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i); + sqlite3_exec(db, delsql, NULL, NULL, NULL); + } + + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) qvec[d] = 0.5f; + sqlite3_stmt *sk = NULL; + snprintf(sql, sizeof(sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit); + sqlite3_prepare_v2(db, sql, -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + sqlite3_free(qvec); + } + } + + // Retrain after deletions + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + + // Query after retrain + { + float *qvec = sqlite3_malloc(dim * sizeof(float)); + if (qvec) { + for (int d = 0; d < dim; d++) qvec[d] = -0.3f; + sqlite3_stmt *sk = NULL; + snprintf(sql, sizeof(sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit); + sqlite3_prepare_v2(db, sql, -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + sqlite3_free(qvec); + } + } + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/ivf-shadow-corrupt.c b/tests/fuzz/ivf-shadow-corrupt.c new file mode 100644 index 0000000..1153ac9 --- /dev/null +++ b/tests/fuzz/ivf-shadow-corrupt.c @@ -0,0 +1,228 @@ +/** + * Fuzz target: IVF shadow table corruption. + * + * Creates a trained IVF table, then corrupts IVF shadow table blobs + * (centroids, cells validity/rowids/vectors, rowid_map) with fuzz data. + * Then exercises all read/write paths. Must not crash. + * + * Targets: + * - Cell validity bitmap with wrong size + * - Cell rowids blob with wrong size/alignment + * - Cell vectors blob with wrong size + * - Centroid blob with wrong size + * - n_vectors inconsistent with validity bitmap + * - Missing rowid_map entries + * - KNN scan over corrupted cells + * - Insert/delete with corrupted rowid_map + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 4) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + // Create IVF table and insert enough vectors to train + rc = sqlite3_exec(db, + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[8] indexed by ivf(nlist=2, nprobe=2))", + NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + // Insert 10 vectors + { + sqlite3_stmt *si = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL); + if (!si) { sqlite3_close(db); return 0; } + for (int i = 0; i < 10; i++) { + float vec[8]; + for (int d = 0; d < 8; d++) { + vec[d] = (float)(i * 8 + d) * 0.1f; + } + sqlite3_reset(si); + sqlite3_bind_int64(si, 1, i + 1); + sqlite3_bind_blob(si, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(si); + } + sqlite3_finalize(si); + } + + // Train + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + + // Now corrupt shadow tables based on fuzz input + uint8_t target = data[0] % 10; + const uint8_t *payload = data + 1; + int payload_size = (int)(size - 1); + + // Limit payload to avoid huge allocations + if (payload_size > 4096) payload_size = 4096; + + sqlite3_stmt *stmt = NULL; + + switch (target) { + case 0: { + // Corrupt cell validity blob + rc = sqlite3_prepare_v2(db, + "UPDATE v_ivf_cells00 SET validity = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC); + sqlite3_step(stmt); sqlite3_finalize(stmt); + } + break; + } + case 1: { + // Corrupt cell rowids blob + rc = sqlite3_prepare_v2(db, + "UPDATE v_ivf_cells00 SET rowids = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC); + sqlite3_step(stmt); sqlite3_finalize(stmt); + } + break; + } + case 2: { + // Corrupt cell vectors blob + rc = sqlite3_prepare_v2(db, + "UPDATE v_ivf_cells00 SET vectors = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC); + sqlite3_step(stmt); sqlite3_finalize(stmt); + } + break; + } + case 3: { + // Corrupt centroid blob + rc = sqlite3_prepare_v2(db, + "UPDATE v_ivf_centroids00 SET centroid = ? WHERE centroid_id = 0", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC); + sqlite3_step(stmt); sqlite3_finalize(stmt); + } + break; + } + case 4: { + // Set n_vectors to a bogus value (larger than cell capacity) + int bogus_n = 99999; + if (payload_size >= 4) { + memcpy(&bogus_n, payload, 4); + bogus_n = abs(bogus_n) % 100000; + } + char sql[128]; + snprintf(sql, sizeof(sql), + "UPDATE v_ivf_cells00 SET n_vectors = %d WHERE rowid = 1", bogus_n); + sqlite3_exec(db, sql, NULL, NULL, NULL); + break; + } + case 5: { + // Delete rowid_map entries (orphan vectors) + sqlite3_exec(db, + "DELETE FROM v_ivf_rowid_map00 WHERE rowid IN (1, 2, 3)", + NULL, NULL, NULL); + break; + } + case 6: { + // Corrupt rowid_map slot values + char sql[128]; + int bogus_slot = payload_size > 0 ? (int)payload[0] * 1000 : 99999; + snprintf(sql, sizeof(sql), + "UPDATE v_ivf_rowid_map00 SET slot = %d WHERE rowid = 1", bogus_slot); + sqlite3_exec(db, sql, NULL, NULL, NULL); + break; + } + case 7: { + // Corrupt rowid_map cell_id values + sqlite3_exec(db, + "UPDATE v_ivf_rowid_map00 SET cell_id = 99999 WHERE rowid = 1", + NULL, NULL, NULL); + break; + } + case 8: { + // Delete all centroids (make trained but no centroids) + sqlite3_exec(db, + "DELETE FROM v_ivf_centroids00", + NULL, NULL, NULL); + break; + } + case 9: { + // Set validity to NULL + sqlite3_exec(db, + "UPDATE v_ivf_cells00 SET validity = NULL WHERE rowid = 1", + NULL, NULL, NULL); + break; + } + } + + // Exercise all read paths over corrupted state — must not crash + float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + + // KNN query + { + sqlite3_stmt *sk = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5", + -1, &sk, NULL); + if (sk) { + sqlite3_bind_blob(sk, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(sk) == SQLITE_ROW) {} + sqlite3_finalize(sk); + } + } + + // Full scan + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + + // Point query + sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL); + sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 5", NULL, NULL, NULL); + + // Delete + sqlite3_exec(db, "DELETE FROM v WHERE rowid = 3", NULL, NULL, NULL); + + // Insert after corruption + { + float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + sqlite3_stmt *si = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL); + if (si) { + sqlite3_bind_int64(si, 1, 100); + sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC); + sqlite3_step(si); + sqlite3_finalize(si); + } + } + + // compute-centroids over corrupted state + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('compute-centroids')", + NULL, NULL, NULL); + + // clear-centroids + sqlite3_exec(db, + "INSERT INTO v(rowid) VALUES ('clear-centroids')", + NULL, NULL, NULL); + + sqlite3_close(db); + return 0; +} diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h index ca04b74..67f1370 100644 --- a/tests/sqlite-vec-internal.h +++ b/tests/sqlite-vec-internal.h @@ -5,6 +5,10 @@ #include #include +#ifndef SQLITE_VEC_ENABLE_IVF +#define SQLITE_VEC_ENABLE_IVF 1 +#endif + int min_idx( const float *distances, int32_t n, @@ -68,8 +72,36 @@ enum Vec0IndexType { #ifdef SQLITE_VEC_ENABLE_RESCORE VEC0_INDEX_TYPE_RESCORE = 2, #endif + VEC0_INDEX_TYPE_IVF = 3, }; +enum Vec0RescoreQuantizerType { + VEC0_RESCORE_QUANTIZER_BIT = 1, + VEC0_RESCORE_QUANTIZER_INT8 = 2, +}; + +struct Vec0RescoreConfig { + enum Vec0RescoreQuantizerType quantizer_type; + int oversample; +}; + +#if SQLITE_VEC_ENABLE_IVF +enum Vec0IvfQuantizer { + VEC0_IVF_QUANTIZER_NONE = 0, + VEC0_IVF_QUANTIZER_INT8 = 1, + VEC0_IVF_QUANTIZER_BINARY = 2, +}; + +struct Vec0IvfConfig { + int nlist; + int nprobe; + int quantizer; + int oversample; +}; +#else +struct Vec0IvfConfig { char _unused; }; +#endif + #ifdef SQLITE_VEC_ENABLE_RESCORE enum Vec0RescoreQuantizerType { VEC0_RESCORE_QUANTIZER_BIT = 1, @@ -93,6 +125,7 @@ struct VectorColumnDefinition { #ifdef SQLITE_VEC_ENABLE_RESCORE struct Vec0RescoreConfig rescore; #endif + struct Vec0IvfConfig ivf; }; int vec0_parse_vector_column(const char *source, int source_length, @@ -114,6 +147,10 @@ void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t size_t _test_rescore_quantized_byte_size_bit(size_t dimensions); size_t _test_rescore_quantized_byte_size_int8(size_t dimensions); #endif +#if SQLITE_VEC_ENABLE_IVF +void ivf_quantize_int8(const float *src, int8_t *dst, int D); +void ivf_quantize_binary(const float *src, uint8_t *dst, int D); +#endif #endif #endif /* SQLITE_VEC_INTERNAL_H */ diff --git a/tests/test-ivf-mutations.py b/tests/test-ivf-mutations.py new file mode 100644 index 0000000..5c61119 --- /dev/null +++ b/tests/test-ivf-mutations.py @@ -0,0 +1,575 @@ +""" +Thorough IVF mutation tests: insert, delete, update, KNN correctness, +error cases, edge cases, and cell overflow scenarios. +""" +import pytest +import sqlite3 +import struct +import math +from helpers import _f32, exec + + +@pytest.fixture() +def db(): + db = sqlite3.connect(":memory:") + db.row_factory = sqlite3.Row + db.enable_load_extension(True) + db.load_extension("dist/vec0") + db.enable_load_extension(False) + return db + + +def ivf_total_vectors(db, table="t", col=0): + """Count total vectors across all IVF cells.""" + return db.execute( + f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d}" + ).fetchone()[0] + + +def ivf_unassigned_count(db, table="t", col=0): + return db.execute( + f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id = -1" + ).fetchone()[0] + + +def ivf_assigned_count(db, table="t", col=0): + return db.execute( + f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id >= 0" + ).fetchone()[0] + + +def knn(db, query, k, table="t", col="v"): + """Run a KNN query and return list of (rowid, distance) tuples.""" + rows = db.execute( + f"SELECT rowid, distance FROM {table} WHERE {col} MATCH ? AND k = ?", + [_f32(query), k], + ).fetchall() + return [(r[0], r[1]) for r in rows] + + +# ============================================================================ +# Single row insert + KNN +# ============================================================================ + + +def test_insert_single_row_knn(db): + db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())") + db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])]) + results = knn(db, [1, 0, 0, 0], 5) + assert len(results) == 1 + assert results[0][0] == 1 + assert results[0][1] < 0.001 + + +# ============================================================================ +# Batch insert + KNN recall +# ============================================================================ + + +def test_batch_insert_knn_recall(db): + """Insert 200 vectors, train, verify KNN recall with nprobe=nlist.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))" + ) + for i in range(200): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + assert ivf_total_vectors(db) == 200 + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + assert ivf_assigned_count(db) == 200 + + # Query near 100 -- closest should be rowid 100 + results = knn(db, [100.0, 0, 0, 0], 10) + assert len(results) == 10 + assert results[0][0] == 100 + assert results[0][1] < 0.01 + + # All results should be near 100 + rowids = {r[0] for r in results} + assert all(95 <= r <= 105 for r in rowids) + + +# ============================================================================ +# Delete rows, verify they're gone from KNN +# ============================================================================ + + +def test_delete_rows_gone_from_knn(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + # Delete rowid 10 + db.execute("DELETE FROM t WHERE rowid = 10") + + results = knn(db, [10.0, 0, 0, 0], 20) + rowids = {r[0] for r in results} + assert 10 not in rowids + + +def test_delete_all_rows_empty_results(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + for i in range(10): + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + assert ivf_total_vectors(db) == 0 + results = knn(db, [5.0, 0, 0, 0], 10) + assert len(results) == 0 + + +# ============================================================================ +# Insert after delete (reuse rowids) +# ============================================================================ + + +def test_insert_after_delete_reuse_rowid(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + # Delete rowid 5 + db.execute("DELETE FROM t WHERE rowid = 5") + + # Re-insert rowid 5 with a very different vector + db.execute( + "INSERT INTO t(rowid, v) VALUES (5, ?)", [_f32([999.0, 0, 0, 0])] + ) + + # KNN near 999 should find rowid 5 + results = knn(db, [999.0, 0, 0, 0], 1) + assert len(results) >= 1 + assert results[0][0] == 5 + + +# ============================================================================ +# Update vectors (INSERT OR REPLACE), verify KNN reflects new values +# ============================================================================ + + +def test_update_vector_via_delete_insert(db): + """vec0 IVF update: delete then re-insert with new vector.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + # "Update" rowid 3: delete and re-insert with new vector + db.execute("DELETE FROM t WHERE rowid = 3") + db.execute( + "INSERT INTO t(rowid, v) VALUES (3, ?)", + [_f32([100.0, 0, 0, 0])], + ) + + # KNN near 100 should find rowid 3 + results = knn(db, [100.0, 0, 0, 0], 1) + assert results[0][0] == 3 + + +# ============================================================================ +# Error cases: IVF + auxiliary/metadata/partition key columns +# ============================================================================ + + +def test_error_ivf_with_auxiliary_column(db): + result = exec( + db, + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), +extra text)", + ) + assert "error" in result + assert "auxiliary" in result.get("message", "").lower() + + +def test_error_ivf_with_metadata_column(db): + result = exec( + db, + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), genre text)", + ) + assert "error" in result + assert "metadata" in result.get("message", "").lower() + + +def test_error_ivf_with_partition_key(db): + result = exec( + db, + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), user_id integer partition key)", + ) + assert "error" in result + assert "partition" in result.get("message", "").lower() + + +def test_flat_with_auxiliary_still_works(db): + """Regression guard: flat-indexed tables with aux columns should still work.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4], +extra text)" + ) + db.execute( + "INSERT INTO t(rowid, v, extra) VALUES (1, ?, 'hello')", + [_f32([1, 0, 0, 0])], + ) + row = db.execute("SELECT extra FROM t WHERE rowid = 1").fetchone() + assert row[0] == "hello" + + +def test_flat_with_metadata_still_works(db): + """Regression guard: flat-indexed tables with metadata columns should still work.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4], genre text)" + ) + db.execute( + "INSERT INTO t(rowid, v, genre) VALUES (1, ?, 'rock')", + [_f32([1, 0, 0, 0])], + ) + row = db.execute("SELECT genre FROM t WHERE rowid = 1").fetchone() + assert row[0] == "rock" + + +def test_flat_with_partition_key_still_works(db): + """Regression guard: flat-indexed tables with partition key should still work.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4], user_id integer partition key)" + ) + db.execute( + "INSERT INTO t(rowid, v, user_id) VALUES (1, ?, 42)", + [_f32([1, 0, 0, 0])], + ) + row = db.execute("SELECT user_id FROM t WHERE rowid = 1").fetchone() + assert row[0] == 42 + + +# ============================================================================ +# Edge cases +# ============================================================================ + + +def test_zero_vectors(db): + """Insert zero vectors, verify KNN still works.""" + db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())") + for i in range(5): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([0, 0, 0, 0])], + ) + results = knn(db, [0, 0, 0, 0], 5) + assert len(results) == 5 + # All distances should be 0 + for _, dist in results: + assert dist < 0.001 + + +def test_large_values(db): + """Insert vectors with very large and small values.""" + db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())") + db.execute( + "INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1e6, 1e6, 1e6, 1e6])] + ) + db.execute( + "INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([1e-6, 1e-6, 1e-6, 1e-6])] + ) + db.execute( + "INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([-1e6, -1e6, -1e6, -1e6])] + ) + + results = knn(db, [1e6, 1e6, 1e6, 1e6], 3) + assert results[0][0] == 1 + + +def test_single_row_compute_centroids(db): + """Single row table, compute-centroids should still work.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=1))" + ) + db.execute( + "INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])] + ) + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + assert ivf_assigned_count(db) == 1 + + results = knn(db, [1, 2, 3, 4], 1) + assert len(results) == 1 + assert results[0][0] == 1 + + +# ============================================================================ +# Cell overflow (many vectors in one cell) +# ============================================================================ + + +def test_cell_overflow_many_vectors(db): + """Insert >64 vectors that all go to same centroid. Should create multiple cells.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))" + ) + # Insert 100 very similar vectors + for i in range(100): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([1.0 + i * 0.001, 0, 0, 0])], + ) + + # Set a single centroid so all vectors go there + db.execute( + "INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)", + [_f32([1.0, 0, 0, 0])], + ) + db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')") + + assert ivf_assigned_count(db) == 100 + + # Should have more than 1 cell (64 max per cell) + cell_count = db.execute( + "SELECT count(*) FROM t_ivf_cells00 WHERE centroid_id = 0" + ).fetchone()[0] + assert cell_count >= 2 # 100 / 64 = 2 cells needed + + # All vectors should be queryable + results = knn(db, [1.0, 0, 0, 0], 100) + assert len(results) == 100 + + +# ============================================================================ +# Large batch with training +# ============================================================================ + + +def test_large_batch_with_training(db): + """Insert 500, train, insert 500 more, verify total is 1000.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=16, nprobe=16))" + ) + for i in range(500): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + for i in range(500, 1000): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + assert ivf_total_vectors(db) == 1000 + + # KNN should still work + results = knn(db, [750.0, 0, 0, 0], 5) + assert len(results) == 5 + assert results[0][0] == 750 + + +# ============================================================================ +# KNN after interleaved insert/delete +# ============================================================================ + + +def test_knn_after_interleaved_insert_delete(db): + """Insert 20, train, delete 10 closest to query, verify remaining.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + # Delete rowids 0-9 (closest to query at 5.0) + for i in range(10): + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + results = knn(db, [5.0, 0, 0, 0], 10) + rowids = {r[0] for r in results} + # None of the deleted rowids should appear + assert all(r >= 10 for r in rowids) + assert len(results) == 10 + + +def test_knn_empty_centroids_after_deletes(db): + """Some centroids may become empty after deletes. Should not crash.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=2))" + ) + # Insert vectors clustered near 4 points + for i in range(40): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i % 10) * 10, 0, 0, 0])], + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + # Delete a bunch, potentially emptying some centroids + for i in range(30): + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + # Should not crash even with empty centroids + results = knn(db, [50.0, 0, 0, 0], 20) + assert len(results) <= 10 # only 10 left + + +# ============================================================================ +# KNN returns correct distances +# ============================================================================ + + +def test_knn_correct_distances(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))" + ) + db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([3, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([0, 4, 0, 0])]) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + results = knn(db, [0, 0, 0, 0], 3) + result_map = {r[0]: r[1] for r in results} + + # L2 distances (sqrt of sum of squared differences) + assert abs(result_map[1] - 0.0) < 0.01 + assert abs(result_map[2] - 3.0) < 0.01 # sqrt(3^2) = 3 + assert abs(result_map[3] - 4.0) < 0.01 # sqrt(4^2) = 4 + + +# ============================================================================ +# Delete in flat mode leaves no orphan rowid_map entries +# ============================================================================ + + +def test_delete_flat_mode_rowid_map_count(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))" + ) + for i in range(5): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + db.execute("DELETE FROM t WHERE rowid = 0") + db.execute("DELETE FROM t WHERE rowid = 2") + db.execute("DELETE FROM t WHERE rowid = 4") + + assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 2 + assert ivf_unassigned_count(db) == 2 + + +# ============================================================================ +# Duplicate rowid insert +# ============================================================================ + + +def test_delete_reinsert_as_update(db): + """Simulate update via delete + insert on same rowid.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))" + ) + db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])]) + + # Delete then re-insert as "update" + db.execute("DELETE FROM t WHERE rowid = 1") + db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([99, 0, 0, 0])]) + + results = knn(db, [99, 0, 0, 0], 1) + assert len(results) == 1 + assert results[0][0] == 1 + assert results[0][1] < 0.01 + + +def test_duplicate_rowid_insert_fails(db): + """Inserting a duplicate rowid should fail with a constraint error.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))" + ) + db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])]) + + result = exec( + db, + "INSERT INTO t(rowid, v) VALUES (1, ?)", + [_f32([99, 0, 0, 0])], + ) + assert "error" in result + + +# ============================================================================ +# Interleaved insert/delete with KNN correctness +# ============================================================================ + + +def test_interleaved_ops_correctness(db): + """Complex sequence of inserts and deletes, verify KNN is always correct.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))" + ) + + # Phase 1: Insert 50 vectors + for i in range(50): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + # Phase 2: Delete even-numbered rowids + for i in range(0, 50, 2): + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + # Phase 3: Insert new vectors with higher rowids + for i in range(50, 75): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(i), 0, 0, 0])], + ) + + # Phase 4: Delete some of the new ones + for i in range(60, 70): + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + # KNN query: should only find existing vectors + results = knn(db, [25.0, 0, 0, 0], 50) + rowids = {r[0] for r in results} + + # Verify no deleted rowids appear + deleted = set(range(0, 50, 2)) | set(range(60, 70)) + assert len(rowids & deleted) == 0 + + # Verify we get the right count (25 odd + 15 new - 10 deleted new = 30) + expected_alive = set(range(1, 50, 2)) | set(range(50, 60)) | set(range(70, 75)) + assert rowids.issubset(expected_alive) diff --git a/tests/test-ivf-quantization.py b/tests/test-ivf-quantization.py new file mode 100644 index 0000000..9790680 --- /dev/null +++ b/tests/test-ivf-quantization.py @@ -0,0 +1,255 @@ +import pytest +import sqlite3 +from helpers import _f32, exec + + +@pytest.fixture() +def db(): + db = sqlite3.connect(":memory:") + db.row_factory = sqlite3.Row + db.enable_load_extension(True) + db.load_extension("dist/vec0") + db.enable_load_extension(False) + return db + + +# ============================================================================ +# Parser tests — quantizer and oversample options +# ============================================================================ + + +def test_ivf_quantizer_binary(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] indexed by ivf(nlist=64, quantizer=binary, oversample=10))" + ) + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1" + ).fetchall() + ] + assert "t_ivf_centroids00" in tables + assert "t_ivf_cells00" in tables + + +def test_ivf_quantizer_int8(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] indexed by ivf(nlist=64, quantizer=int8))" + ) + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1" + ).fetchall() + ] + assert "t_ivf_centroids00" in tables + + +def test_ivf_quantizer_none_explicit(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] indexed by ivf(quantizer=none))" + ) + # Should work — same as no quantizer + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1" + ).fetchall() + ] + assert "t_ivf_centroids00" in tables + + +def test_ivf_quantizer_all_params(db): + """All params together: nlist, nprobe, quantizer, oversample.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] distance_metric=cosine " + "indexed by ivf(nlist=128, nprobe=16, quantizer=int8, oversample=4))" + ) + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1" + ).fetchall() + ] + assert "t_ivf_centroids00" in tables + + +def test_ivf_error_oversample_without_quantizer(db): + """oversample > 1 without quantizer should error.""" + result = exec( + db, + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] indexed by ivf(oversample=10))", + ) + assert "error" in result + + +def test_ivf_error_unknown_quantizer(db): + result = exec( + db, + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] indexed by ivf(quantizer=pq))", + ) + assert "error" in result + + +def test_ivf_oversample_1_without_quantizer_ok(db): + """oversample=1 (default) is fine without quantizer.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] indexed by ivf(nlist=64))" + ) + # Should succeed — oversample defaults to 1 + + +# ============================================================================ +# Functional tests — insert, train, query with quantized IVF +# ============================================================================ + + +def test_ivf_int8_insert_and_query(db): + """int8 quantized IVF: insert, train, query.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=4))" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + # Should be able to query + rows = db.execute( + "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5", + [_f32([10.0, 0, 0, 0])], + ).fetchall() + assert len(rows) == 5 + # Top result should be close to 10 + assert rows[0][0] in range(8, 13) + + # Full vectors should be in _ivf_vectors table + fv_count = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] + assert fv_count == 20 + + +def test_ivf_binary_insert_and_query(db): + """Binary quantized IVF: insert, train, query.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[8] indexed by ivf(nlist=2, quantizer=binary, oversample=4))" + ) + for i in range(20): + # Vectors with varying sign patterns + v = [(i * 0.1 - 1.0) + j * 0.3 for j in range(8)] + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + rows = db.execute( + "SELECT rowid FROM t WHERE v MATCH ? AND k = 5", + [_f32([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])], + ).fetchall() + assert len(rows) == 5 + + # Full vectors stored + fv_count = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] + assert fv_count == 20 + + +def test_ivf_int8_cell_sizes_smaller(db): + """Cell blobs should be smaller with int8 quantization.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] indexed by ivf(nlist=2, quantizer=int8, oversample=1))" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(x) / 768 for x in range(768)])], + ) + + # Cell vectors blob: 10 vectors at int8 = 10 * 768 = 7680 bytes + # vs float32 = 10 * 768 * 4 = 30720 bytes + # But cells have capacity 64, so blob = 64 * 768 = 49152 (int8) vs 64*768*4=196608 (float32) + blob_size = db.execute( + "SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1" + ).fetchone()[0] + # int8: 64 slots * 768 bytes = 49152 + assert blob_size == 64 * 768 + + +def test_ivf_binary_cell_sizes_smaller(db): + """Cell blobs should be much smaller with binary quantization.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[768] indexed by ivf(nlist=2, quantizer=binary, oversample=1))" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([float(x) / 768 for x in range(768)])], + ) + + blob_size = db.execute( + "SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1" + ).fetchone()[0] + # binary: 64 slots * 768/8 bytes = 6144 + assert blob_size == 64 * (768 // 8) + + +def test_ivf_int8_oversample_improves_recall(db): + """Oversample re-ranking should improve recall over oversample=1.""" + # Create two tables: one with oversample=1, one with oversample=10 + db.execute( + "CREATE VIRTUAL TABLE t1 USING vec0(" + "v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=1))" + ) + db.execute( + "CREATE VIRTUAL TABLE t2 USING vec0(" + "v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=10))" + ) + for i in range(100): + v = _f32([i * 0.1, (i * 0.1) % 3, (i * 0.3) % 5, i * 0.01]) + db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v]) + db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v]) + + db.execute("INSERT INTO t1(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t2(rowid) VALUES ('compute-centroids')") + db.execute("INSERT INTO t1(rowid) VALUES ('nprobe=4')") + db.execute("INSERT INTO t2(rowid) VALUES ('nprobe=4')") + + query = _f32([5.0, 1.5, 2.5, 0.5]) + r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall() + r2 = db.execute("SELECT rowid FROM t2 WHERE v MATCH ? AND k=10", [query]).fetchall() + + # Both should return 10 results + assert len(r1) == 10 + assert len(r2) == 10 + # oversample=10 should have at least as good recall (same or better ordering) + + +def test_ivf_quantized_delete(db): + """Delete should remove from both cells and _ivf_vectors.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + "v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=1))" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10 + + db.execute("DELETE FROM t WHERE rowid = 5") + # _ivf_vectors should have 9 rows + assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 9 diff --git a/tests/test-ivf.py b/tests/test-ivf.py new file mode 100644 index 0000000..18a7532 --- /dev/null +++ b/tests/test-ivf.py @@ -0,0 +1,426 @@ +import pytest +import sqlite3 +import struct +import math +from helpers import _f32, exec + + +@pytest.fixture() +def db(): + db = sqlite3.connect(":memory:") + db.row_factory = sqlite3.Row + db.enable_load_extension(True) + db.load_extension("dist/vec0") + db.enable_load_extension(False) + return db + + +def ivf_total_vectors(db, table="t", col=0): + """Count total vectors across all IVF cells.""" + return db.execute( + f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d}" + ).fetchone()[0] + + +def ivf_unassigned_count(db, table="t", col=0): + """Count vectors in unassigned cells (centroid_id=-1).""" + return db.execute( + f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id = -1" + ).fetchone()[0] + + +def ivf_assigned_count(db, table="t", col=0): + """Count vectors in trained cells (centroid_id >= 0).""" + return db.execute( + f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id >= 0" + ).fetchone()[0] + + +# ============================================================================ +# Parser tests +# ============================================================================ + + +def test_ivf_create_defaults(db): + """ivf() with no args uses defaults.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())" + ) + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1" + ).fetchall() + ] + assert "t_ivf_centroids00" in tables + assert "t_ivf_cells00" in tables + assert "t_ivf_rowid_map00" in tables + + +def test_ivf_create_custom_params(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=64, nprobe=8))" + ) + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1" + ).fetchall() + ] + assert "t_ivf_centroids00" in tables + assert "t_ivf_cells00" in tables + + +def test_ivf_create_with_distance_metric(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] distance_metric=cosine indexed by ivf(nlist=16))" + ) + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1" + ).fetchall() + ] + assert "t_ivf_centroids00" in tables + + +def test_ivf_create_error_unknown_key(db): + result = exec( + db, + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(bogus=1))", + ) + assert "error" in result + + +def test_ivf_create_error_nprobe_gt_nlist(db): + result = exec( + db, + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=10))", + ) + assert "error" in result + + +# ============================================================================ +# Shadow table tests +# ============================================================================ + + +def test_ivf_shadow_tables_created(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8))" + ) + trained = db.execute( + "SELECT value FROM t_info WHERE key = 'ivf_trained_0'" + ).fetchone()[0] + assert str(trained) == "0" + + # No cells yet (created lazily on first insert) + count = db.execute( + "SELECT count(*) FROM t_ivf_cells00" + ).fetchone()[0] + assert count == 0 + + +def test_ivf_drop_cleans_up(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))" + ) + db.execute("DROP TABLE t") + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ).fetchall() + ] + assert not any("ivf" in t for t in tables) + + +# ============================================================================ +# Insert tests (flat mode) +# ============================================================================ + + +def test_ivf_insert_flat_mode(db): + """Before training, vectors go to unassigned cell.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))" + ) + db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])]) + db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])]) + + assert ivf_unassigned_count(db) == 2 + assert ivf_assigned_count(db) == 0 + + # rowid_map should have 2 entries + assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 2 + + +def test_ivf_delete_flat_mode(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))" + ) + db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])]) + db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])]) + db.execute("DELETE FROM t WHERE rowid = 1") + + assert ivf_unassigned_count(db) == 1 + assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 1 + + +# ============================================================================ +# KNN flat mode tests +# ============================================================================ + + +def test_ivf_knn_flat_mode(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))" + ) + db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([2, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([9, 0, 0, 0])]) + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 2", + [_f32([1.5, 0, 0, 0])], + ).fetchall() + assert len(rows) == 2 + rowids = {r[0] for r in rows} + assert rowids == {1, 2} + + +def test_ivf_knn_flat_empty(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))" + ) + rows = db.execute( + "SELECT rowid FROM t WHERE v MATCH ? AND k = 5", + [_f32([1, 0, 0, 0])], + ).fetchall() + assert len(rows) == 0 + + +# ============================================================================ +# compute-centroids tests +# ============================================================================ + + +def test_compute_centroids(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))" + ) + for i in range(40): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", + [i, _f32([i % 10, i // 10, 0, 0])], + ) + + assert ivf_unassigned_count(db) == 40 + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + # After training: unassigned cell should be gone (or empty), vectors in trained cells + assert ivf_unassigned_count(db) == 0 + assert ivf_assigned_count(db) == 40 + assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 4 + trained = db.execute( + "SELECT value FROM t_info WHERE key='ivf_trained_0'" + ).fetchone()[0] + assert str(trained) == "1" + + +def test_compute_centroids_recompute(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2 + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2 + assert ivf_assigned_count(db) == 20 + + +# ============================================================================ +# Insert after training (assigned mode) +# ============================================================================ + + +def test_ivf_insert_after_training(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + db.execute( + "INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])] + ) + + # Should be in a trained cell, not unassigned + row = db.execute( + "SELECT m.cell_id, c.centroid_id FROM t_ivf_rowid_map00 m " + "JOIN t_ivf_cells00 c ON c.rowid = m.cell_id " + "WHERE m.rowid = 100" + ).fetchone() + assert row is not None + assert row[1] >= 0 # centroid_id >= 0 means trained cell + + +# ============================================================================ +# KNN after training (IVF probe mode) +# ============================================================================ + + +def test_ivf_knn_after_training(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))" + ) + for i in range(100): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5", + [_f32([50.0, 0, 0, 0])], + ).fetchall() + assert len(rows) == 5 + assert rows[0][0] == 50 + assert rows[0][1] < 0.01 + + +def test_ivf_knn_k_larger_than_n(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))" + ) + for i in range(5): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + rows = db.execute( + "SELECT rowid FROM t WHERE v MATCH ? AND k = 100", + [_f32([0, 0, 0, 0])], + ).fetchall() + assert len(rows) == 5 + + +# ============================================================================ +# Manual centroid import (set-centroid, assign-vectors) +# ============================================================================ + + +def test_set_centroid_and_assign(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute( + "INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)", + [_f32([5, 0, 0, 0])], + ) + db.execute( + "INSERT INTO t(rowid, v) VALUES ('set-centroid:1', ?)", + [_f32([15, 0, 0, 0])], + ) + + assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2 + + db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')") + + assert ivf_unassigned_count(db) == 0 + assert ivf_assigned_count(db) == 20 + + +# ============================================================================ +# clear-centroids +# ============================================================================ + + +def test_clear_centroids(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2 + + db.execute("INSERT INTO t(rowid) VALUES ('clear-centroids')") + assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0 + assert ivf_unassigned_count(db) == 20 + trained = db.execute( + "SELECT value FROM t_info WHERE key='ivf_trained_0'" + ).fetchone()[0] + assert str(trained) == "0" + + +# ============================================================================ +# Delete after training +# ============================================================================ + + +def test_ivf_delete_after_training(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + assert ivf_assigned_count(db) == 10 + + db.execute("DELETE FROM t WHERE rowid = 5") + assert ivf_assigned_count(db) == 9 + assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 9 + + +# ============================================================================ +# Recall test +# ============================================================================ + + +def test_ivf_recall_nprobe_equals_nlist(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))" + ) + for i in range(100): + db.execute( + "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])] + ) + + db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')") + + rows = db.execute( + "SELECT rowid FROM t WHERE v MATCH ? AND k = 10", + [_f32([50.0, 0, 0, 0])], + ).fetchall() + rowids = {r[0] for r in rows} + + # 45 and 55 are equidistant from 50, so either may appear in top 10 + assert 50 in rowids + assert len(rowids) == 10 + assert all(45 <= r <= 55 for r in rowids) diff --git a/tests/test-unit.c b/tests/test-unit.c index b180625..27a469d 100644 --- a/tests/test-unit.c +++ b/tests/test-unit.c @@ -577,6 +577,182 @@ void test_vec0_parse_vector_column() { assert(rc == SQLITE_ERROR); } +#if SQLITE_VEC_ENABLE_IVF + // IVF: indexed by ivf() — defaults + { + const char *input = "v float[4] indexed by ivf()"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.dimensions == 4); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.ivf.nlist == 128); // default + assert(col.ivf.nprobe == 10); // default + sqlite3_free(col.name); + } + + // IVF: indexed by ivf(nlist=8) — nprobe auto-clamped to 8 + { + const char *input = "v float[4] indexed by ivf(nlist=8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.ivf.nlist == 8); + assert(col.ivf.nprobe == 8); // clamped from default 10 + sqlite3_free(col.name); + } + + // IVF: indexed by ivf(nlist=64, nprobe=8) + { + const char *input = "v float[4] indexed by ivf(nlist=64, nprobe=8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.ivf.nlist == 64); + assert(col.ivf.nprobe == 8); + sqlite3_free(col.name); + } + + // IVF: with distance_metric before indexed by + { + const char *input = "v float[4] distance_metric=cosine indexed by ivf(nlist=16)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.ivf.nlist == 16); + sqlite3_free(col.name); + } + + // IVF: nlist=0 (deferred) + { + const char *input = "v float[4] indexed by ivf(nlist=0)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.nlist == 0); + sqlite3_free(col.name); + } + + // IVF error: nprobe > nlist + { + const char *input = "v float[4] indexed by ivf(nlist=4, nprobe=10)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // IVF error: unknown key + { + const char *input = "v float[4] indexed by ivf(bogus=1)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // IVF error: unknown index type (hnsw not supported) + { + const char *input = "v float[4] indexed by hnsw()"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Not IVF: no ivf config + { + const char *input = "v float[4]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_FLAT); + sqlite3_free(col.name); + } + + // IVF: quantizer=binary + { + const char *input = "v float[768] indexed by ivf(nlist=128, quantizer=binary)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.ivf.nlist == 128); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY); + assert(col.ivf.oversample == 1); + sqlite3_free(col.name); + } + + // IVF: quantizer=int8 + { + const char *input = "v float[768] indexed by ivf(nlist=64, quantizer=int8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + // IVF: quantizer=none (explicit) + { + const char *input = "v float[768] indexed by ivf(quantizer=none)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_NONE); + sqlite3_free(col.name); + } + + // IVF: oversample=10 with quantizer + { + const char *input = "v float[768] indexed by ivf(nlist=128, quantizer=binary, oversample=10)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY); + assert(col.ivf.oversample == 10); + assert(col.ivf.nlist == 128); + sqlite3_free(col.name); + } + + // IVF: all params + { + const char *input = "v float[768] distance_metric=cosine indexed by ivf(nlist=256, nprobe=16, quantizer=int8, oversample=4)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE); + assert(col.ivf.nlist == 256); + assert(col.ivf.nprobe == 16); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8); + assert(col.ivf.oversample == 4); + sqlite3_free(col.name); + } + + // IVF error: oversample > 1 without quantizer + { + const char *input = "v float[768] indexed by ivf(oversample=10)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // IVF error: unknown quantizer value + { + const char *input = "v float[768] indexed by ivf(quantizer=pq)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // IVF: quantizer with defaults (nlist=128 default, nprobe=10 default) + { + const char *input = "v float[768] indexed by ivf(quantizer=binary, oversample=5)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.nlist == 128); + assert(col.ivf.nprobe == 10); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY); + assert(col.ivf.oversample == 5); + sqlite3_free(col.name); + } +#else + // When IVF is disabled, parsing "ivf" should fail + { + const char *input = "v float[4] indexed by ivf()"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } +#endif /* SQLITE_VEC_ENABLE_IVF */ + printf(" All vec0_parse_vector_column tests passed.\n"); } @@ -821,6 +997,38 @@ void test_rescore_quantize_float_to_int8() { float src[8] = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}; _test_rescore_quantize_float_to_int8(src, dst, 8); for (int i = 0; i < 8; i++) { +#if SQLITE_VEC_ENABLE_IVF +void test_ivf_quantize_int8() { + printf("Starting %s...\n", __func__); + + // Basic values in [-1, 1] range + { + float src[] = {0.0f, 1.0f, -1.0f, 0.5f}; + int8_t dst[4]; + ivf_quantize_int8(src, dst, 4); + assert(dst[0] == 0); + assert(dst[1] == 127); + assert(dst[2] == -127); + assert(dst[3] == 63); // 0.5 * 127 = 63.5, truncated to 63 + } + + // Clamping: values beyond [-1, 1] + { + float src[] = {2.0f, -3.0f, 100.0f, -0.01f}; + int8_t dst[4]; + ivf_quantize_int8(src, dst, 4); + assert(dst[0] == 127); // clamped to 1.0 + assert(dst[1] == -127); // clamped to -1.0 + assert(dst[2] == 127); // clamped to 1.0 + assert(dst[3] == (int8_t)(-0.01f * 127.0f)); + } + + // Zero vector + { + float src[] = {0.0f, 0.0f, 0.0f, 0.0f}; + int8_t dst[4]; + ivf_quantize_int8(src, dst, 4); + for (int i = 0; i < 4; i++) { assert(dst[i] == 0); } } @@ -882,6 +1090,103 @@ void test_rescore_quantized_byte_size() { } void test_vec0_parse_vector_column_rescore() { + // Negative zero + { + float src[] = {-0.0f}; + int8_t dst[1]; + ivf_quantize_int8(src, dst, 1); + assert(dst[0] == 0); + } + + // Single element + { + float src[] = {0.75f}; + int8_t dst[1]; + ivf_quantize_int8(src, dst, 1); + assert(dst[0] == (int8_t)(0.75f * 127.0f)); + } + + // Boundary: exactly 1.0 and -1.0 + { + float src[] = {1.0f, -1.0f}; + int8_t dst[2]; + ivf_quantize_int8(src, dst, 2); + assert(dst[0] == 127); + assert(dst[1] == -127); + } + + printf(" All ivf_quantize_int8 tests passed.\n"); +} + +void test_ivf_quantize_binary() { + printf("Starting %s...\n", __func__); + + // Basic sign-bit quantization: positive -> 1, negative/zero -> 0 + { + float src[] = {1.0f, -1.0f, 0.5f, -0.5f, 0.0f, 0.1f, -0.1f, 2.0f}; + uint8_t dst[1]; + ivf_quantize_binary(src, dst, 8); + // bit 0: 1.0 > 0 -> 1 (LSB) + // bit 1: -1.0 -> 0 + // bit 2: 0.5 > 0 -> 1 + // bit 3: -0.5 -> 0 + // bit 4: 0.0 -> 0 (not > 0) + // bit 5: 0.1 > 0 -> 1 + // bit 6: -0.1 -> 0 + // bit 7: 2.0 > 0 -> 1 + // Expected: bits 0,2,5,7 = 0b10100101 = 0xA5 + assert(dst[0] == 0xA5); + } + + // All positive + { + float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + uint8_t dst[1]; + ivf_quantize_binary(src, dst, 8); + assert(dst[0] == 0xFF); + } + + // All negative + { + float src[] = {-1.0f, -2.0f, -3.0f, -4.0f, -5.0f, -6.0f, -7.0f, -8.0f}; + uint8_t dst[1]; + ivf_quantize_binary(src, dst, 8); + assert(dst[0] == 0x00); + } + + // All zero (zero is NOT > 0, so all bits should be 0) + { + float src[] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + uint8_t dst[1]; + ivf_quantize_binary(src, dst, 8); + assert(dst[0] == 0x00); + } + + // Multi-byte: 16 dimensions -> 2 bytes + { + float src[16]; + for (int i = 0; i < 16; i++) src[i] = (i % 2 == 0) ? 1.0f : -1.0f; + uint8_t dst[2]; + ivf_quantize_binary(src, dst, 16); + // Even indices are positive: bits 0,2,4,6 in each byte + // byte 0: bits 0,2,4,6 = 0b01010101 = 0x55 + // byte 1: same pattern = 0x55 + assert(dst[0] == 0x55); + assert(dst[1] == 0x55); + } + + // Single byte, only first bit set + { + float src[] = {0.1f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + uint8_t dst[1]; + ivf_quantize_binary(src, dst, 8); + assert(dst[0] == 0x01); + } + + printf(" All ivf_quantize_binary tests passed.\n"); +} + +void test_ivf_config_parsing() { printf("Starting %s...\n", __func__); struct VectorColumnDefinition col; int rc; @@ -955,6 +1260,116 @@ void test_vec0_parse_vector_column_rescore() { } #endif /* SQLITE_VEC_ENABLE_RESCORE */ + // Default IVF config + { + const char *s = "v float[4] indexed by ivf()"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_IVF); + assert(col.ivf.nlist == 128); // default + assert(col.ivf.nprobe == 10); // default + assert(col.ivf.quantizer == 0); // VEC0_IVF_QUANTIZER_NONE + sqlite3_free(col.name); + } + + // Custom nlist and nprobe + { + const char *s = "v float[4] indexed by ivf(nlist=64, nprobe=8)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.nlist == 64); + assert(col.ivf.nprobe == 8); + sqlite3_free(col.name); + } + + // nlist=0 (deferred) + { + const char *s = "v float[4] indexed by ivf(nlist=0)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.nlist == 0); + sqlite3_free(col.name); + } + + // Quantizer options + { + const char *s = "v float[8] indexed by ivf(quantizer=int8)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + { + const char *s = "v float[8] indexed by ivf(quantizer=binary)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY); + sqlite3_free(col.name); + } + + // nprobe > nlist (explicit) should fail + { + const char *s = "v float[4] indexed by ivf(nlist=4, nprobe=10)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_ERROR); + } + + // Unknown key + { + const char *s = "v float[4] indexed by ivf(bogus=1)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_ERROR); + } + + // nlist > max (65536) should fail + { + const char *s = "v float[4] indexed by ivf(nlist=65537)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_ERROR); + } + + // nlist at max boundary (65536) should succeed + { + const char *s = "v float[4] indexed by ivf(nlist=65536)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.nlist == 65536); + sqlite3_free(col.name); + } + + // oversample > 1 without quantization should fail + { + const char *s = "v float[4] indexed by ivf(oversample=4)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_ERROR); + } + + // oversample with quantizer should succeed + { + const char *s = "v float[8] indexed by ivf(quantizer=int8, oversample=4)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.oversample == 4); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + // All options combined + { + const char *s = "v float[8] indexed by ivf(nlist=32, nprobe=4, quantizer=int8, oversample=2)"; + rc = vec0_parse_vector_column(s, (int)strlen(s), &col); + assert(rc == SQLITE_OK); + assert(col.ivf.nlist == 32); + assert(col.ivf.nprobe == 4); + assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8); + assert(col.ivf.oversample == 2); + sqlite3_free(col.name); + } + + printf(" All ivf_config_parsing tests passed.\n"); +} +#endif /* SQLITE_VEC_ENABLE_IVF */ int main() { printf("Starting unit tests...\n"); @@ -982,6 +1397,10 @@ int main() { test_rescore_quantize_float_to_int8(); test_rescore_quantized_byte_size(); test_vec0_parse_vector_column_rescore(); +#if SQLITE_VEC_ENABLE_IVF + test_ivf_quantize_int8(); + test_ivf_quantize_binary(); + test_ivf_config_parsing(); #endif printf("All unit tests passed.\n"); }