diff --git a/IVF_PLAN.md b/IVF_PLAN.md
new file mode 100644
index 0000000..91bb85a
--- /dev/null
+++ b/IVF_PLAN.md
@@ -0,0 +1,264 @@
+# IVF Index for sqlite-vec
+
+## Overview
+
+IVF (Inverted File Index) is an approximate nearest neighbor index for
+sqlite-vec's `vec0` virtual table. It partitions vectors into clusters via
+k-means, then at query time only scans the nearest clusters instead of all
+vectors. Combined with scalar or binary quantization, this gives 5-20x query
+speedups over brute-force with tunable recall.
+
+## SQL API
+
+### Table Creation
+
+```sql
+CREATE VIRTUAL TABLE vec_items USING vec0(
+ id INTEGER PRIMARY KEY,
+ embedding float[768] distance_metric=cosine
+ INDEXED BY ivf(nlist=128, nprobe=16)
+);
+
+-- With quantization (4x smaller cells, rescore for recall)
+CREATE VIRTUAL TABLE vec_items USING vec0(
+ id INTEGER PRIMARY KEY,
+ embedding float[768] distance_metric=cosine
+ INDEXED BY ivf(nlist=128, nprobe=16, quantizer=int8, oversample=4)
+);
+```
+
+### Parameters
+
+| Parameter | Values | Default | Description |
+|-----------|--------|---------|-------------|
+| `nlist` | 1-65536, or 0 | 128 | Number of k-means clusters. Rule of thumb: `sqrt(N)` |
+| `nprobe` | 1-nlist | 10 | Clusters to search at query time. More = better recall, slower |
+| `quantizer` | `none`, `int8`, `binary` | `none` | How vectors are stored in cells |
+| `oversample` | >= 1 | 1 | Re-rank `oversample * k` candidates with full-precision distance |
+
+### Inserting Vectors
+
+```sql
+-- Works immediately, even before training
+INSERT INTO vec_items(id, embedding) VALUES (1, :vector);
+```
+
+Before centroids exist, vectors go to an "unassigned" partition and queries do
+brute-force. After training, new inserts are assigned to the nearest centroid.
+
+### Training (Computing Centroids)
+
+```sql
+-- Run built-in k-means on all vectors
+INSERT INTO vec_items(id) VALUES ('compute-centroids');
+```
+
+This loads all vectors into memory, runs k-means++ with Lloyd's algorithm,
+creates quantized centroids, and redistributes all vectors into cluster cells.
+It's a blocking operation — run it once after bulk insert.
+
+### Manual Centroid Import
+
+```sql
+-- Import externally-computed centroids
+INSERT INTO vec_items(id, embedding) VALUES ('set-centroid:0', :centroid_0);
+INSERT INTO vec_items(id, embedding) VALUES ('set-centroid:1', :centroid_1);
+
+-- Assign vectors to imported centroids
+INSERT INTO vec_items(id) VALUES ('assign-vectors');
+```
+
+### Runtime Parameter Tuning
+
+```sql
+-- Change nprobe without rebuilding the index
+INSERT INTO vec_items(id) VALUES ('nprobe=32');
+```
+
+### KNN Queries
+
+```sql
+-- Same syntax as standard vec0
+SELECT id, distance
+FROM vec_items
+WHERE embedding MATCH :query AND k = 10;
+```
+
+### Other Commands
+
+```sql
+-- Remove centroids, move all vectors back to unassigned
+INSERT INTO vec_items(id) VALUES ('clear-centroids');
+```
+
+## How It Works
+
+### Architecture
+
+```
+User vector (float32)
+ → quantize to int8/binary (if quantizer != none)
+ → find nearest centroid (quantized distance)
+ → store quantized vector in cell blob
+ → store full vector in KV table (if quantizer != none)
+ → query:
+ 1. quantize query vector
+ 2. find top nprobe centroids by quantized distance
+ 3. scan cell blobs: quantized distance (fast, small I/O)
+ 4. if oversample > 1: re-score top N*k with full vectors
+ 5. return top k
+```
+
+### Shadow Tables
+
+For a table `vec_items` with vector column index 0:
+
+| Table | Schema | Purpose |
+|-------|--------|---------|
+| `vec_items_ivf_centroids00` | `centroid_id PK, centroid BLOB` | K-means centroids (quantized) |
+| `vec_items_ivf_cells00` | `centroid_id, n_vectors, validity BLOB, rowids BLOB, vectors BLOB` | Packed vector cells, 64 vectors max per row. Multiple rows per centroid. Index on centroid_id. |
+| `vec_items_ivf_rowid_map00` | `rowid PK, cell_id, slot` | Maps vector rowid → cell location for O(1) delete |
+| `vec_items_ivf_vectors00` | `rowid PK, vector BLOB` | Full-precision vectors (only when quantizer != none) |
+
+### Cell Storage
+
+Cells use packed blob storage identical to vec0's chunk layout:
+- **validity**: bitmap (1 bit per slot) marking live vectors
+- **rowids**: packed i64 array
+- **vectors**: packed array of quantized vectors
+
+Cells are capped at 64 vectors (~200KB at 768-dim float32, ~48KB for int8,
+~6KB for binary). When a cell fills, a new row is created for the same
+centroid. This avoids SQLite overflow page traversal which was a 110x
+performance bottleneck with unbounded cells.
+
+### Quantization
+
+**int8**: Each float32 dimension clamped to [-1,1] and scaled to int8
+[-127,127]. 4x storage reduction. Distance computed via int8 L2.
+
+**binary**: Sign-bit quantization — each bit is 1 if the float is positive.
+32x storage reduction. Distance computed via hamming distance.
+
+**Oversample re-ranking**: When `oversample > 1`, the quantized scan collects
+`oversample * k` candidates, then looks up each candidate's full-precision
+vector from the KV table and re-computes exact distance. This recovers nearly
+all recall lost from quantization. At oversample=4 with int8, recall matches
+non-quantized IVF exactly.
+
+### K-Means
+
+Uses Lloyd's algorithm with k-means++ initialization:
+1. K-means++ picks initial centroids weighted by distance
+2. Lloyd's iterations: assign vectors to nearest centroid, recompute centroids as cluster means
+3. Empty cluster handling: reassign to farthest point
+4. K-means runs in float32; centroids are quantized before storage
+
+Training data: recommend 16× nlist vectors. At nlist=1000, that's 16k
+vectors — k-means takes ~140s on 768-dim data.
+
+## Performance
+
+### 100k vectors (COHERE 768-dim cosine)
+
+```
+ name qry(ms) recall
+───────────────────────────────────────────────
+ ivf(q=int8,os=4),p=8 5.3ms 0.934 ← 6x faster than flat
+ ivf(q=int8,os=4),p=16 5.4ms 0.968
+ ivf(q=none),p=8 5.3ms 0.934
+ ivf(q=binary,os=10),p=16 1.3ms 0.832 ← 26x faster than flat
+ ivf(q=int8,os=4),p=32 7.4ms 0.990
+ ivf(q=none),p=32 15.5ms 0.992
+ int8(os=4) 18.7ms 0.996
+ bit(os=8) 18.7ms 0.884
+ flat 33.7ms 1.000
+```
+
+### 1M vectors (COHERE 768-dim cosine)
+
+```
+ name insert train MB qry(ms) recall
+──────────────────────────────────────────────────────────────────────
+ ivf(q=int8,os=4),p=8 163s 142s 4725 16.3ms 0.892
+ ivf(q=binary,os=10),p=16 118s 144s 4073 17.7ms 0.830
+ ivf(q=int8,os=4),p=16 163s 142s 4725 24.3ms 0.950
+ ivf(q=int8,os=4),p=32 163s 142s 4725 41.6ms 0.980
+ ivf(q=none),p=8 497s 144s 3101 52.1ms 0.890
+ ivf(q=none),p=16 497s 144s 3101 56.6ms 0.950
+ bit(os=8) 18s - 3048 83.5ms 0.918
+ ivf(q=none),p=32 497s 144s 3101 103.9ms 0.980
+ int8(os=4) 19s - 3689 169.1ms 0.994
+ flat 20s - 2955 338.0ms 1.000
+```
+
+**Best config at 1M: `ivf(quantizer=int8, oversample=4, nprobe=16)`** —
+24ms query, 0.95 recall, 14x faster than flat, 7x faster than int8 baseline.
+
+### Scaling Characteristics
+
+| Metric | 100k | 1M | Scaling |
+|--------|------|-----|---------|
+| Flat query | 34ms | 338ms | 10x (linear) |
+| IVF int8 p=16 | 5.4ms | 24.3ms | 4.5x (sublinear) |
+| IVF insert rate | ~10k/s | ~6k/s | Slight degradation |
+| Training (nlist=1000) | 13s | 142s | ~11x |
+
+## Implementation
+
+### File Structure
+
+```
+sqlite-vec-ivf-kmeans.c K-means++ algorithm (pure C, no SQLite deps)
+sqlite-vec-ivf.c All IVF logic: parser, shadow tables, insert,
+ delete, query, centroid commands, quantization
+sqlite-vec.c ~50 lines of additions: struct fields, #includes,
+ dispatch hooks in parse/create/insert/delete/filter
+```
+
+Both IVF files are `#include`d into `sqlite-vec.c`. No Makefile changes needed.
+
+### Key Design Decisions
+
+1. **Fixed-size cells (64 vectors)** instead of one blob per centroid. Avoids
+ SQLite overflow page traversal which caused 110x insert slowdown.
+
+2. **Multiple cell rows per centroid** with an index on centroid_id. When a
+ cell fills, a new row is created. Query scans all rows for probed centroids
+ via `WHERE centroid_id IN (...)`.
+
+3. **Always store full vectors** when quantizer != none (in `_ivf_vectors` KV
+ table). Enables oversample re-ranking and point queries returning original
+ precision.
+
+4. **K-means in float32, quantize after**. Simpler than quantized k-means,
+ and assignment accuracy doesn't suffer much since nprobe compensates.
+
+5. **NEON SIMD for cosine distance**. Added `cosine_float_neon()` with 4-wide
+ FMA for dot product + magnitudes. Benefits all vec0 queries, not just IVF.
+
+6. **Runtime nprobe tuning**. `INSERT INTO t(id) VALUES ('nprobe=N')` changes
+ the probe count without rebuilding — enables fast parameter sweeps.
+
+### Optimization History
+
+| Optimization | Impact |
+|-------------|--------|
+| Fixed-size cells (64 max) | 110x insert speedup |
+| Skip chunk writes for IVF | 2x DB size reduction |
+| NEON cosine distance | 2x query speedup + 13% recall improvement (correct metric) |
+| Cached prepared statements | Eliminated per-insert prepare/finalize |
+| Batched cell reads (IN clause) | Fewer SQLite queries per KNN |
+| int8 quantization | 2.5x query speedup at same recall |
+| Binary quantization | 32x less cell I/O |
+| Oversample re-ranking | Recovers quantization recall loss |
+
+## Remaining Work
+
+See `ivf-benchmarks/TODO.md` for the full list. Key items:
+
+- **Cache centroids in memory** — each insert re-reads all centroids from SQLite
+- **Runtime oversample** — same pattern as nprobe runtime command
+- **SIMD k-means** — training uses scalar distance, could be 4x faster
+- **Top-k heap** — replace qsort with min-heap for large nprobe
+- **IVF-PQ** — product quantization for better compression/recall tradeoff
diff --git a/Makefile b/Makefile
index b50751b..2758ee5 100644
--- a/Makefile
+++ b/Makefile
@@ -206,6 +206,21 @@ test-loadable-watch:
test-unit:
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
+# Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking,
+# profiling (has debug symbols), and scripting without .load_extension.
+# make cli
+# dist/sqlite3 :memory: "SELECT vec_version()"
+# dist/sqlite3 < script.sql
+cli: sqlite-vec.h $(prefix)
+ $(CC) -O2 -g \
+ -DSQLITE_CORE \
+ -DSQLITE_EXTRA_INIT=core_init \
+ -DSQLITE_THREADSAFE=0 \
+ -Ivendor/ -I./ \
+ $(CFLAGS) \
+ vendor/sqlite3.c vendor/shell.c sqlite-vec.c examples/sqlite3-cli/core_init.c \
+ -ldl -lm -o $(prefix)/sqlite3
+
fuzz-build:
$(MAKE) -C tests/fuzz all
diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile
index 0789d38..6081457 100644
--- a/benchmarks-ann/Makefile
+++ b/benchmarks-ann/Makefile
@@ -8,27 +8,20 @@ BASELINES = \
"brute-int8:type=baseline,variant=int8" \
"brute-bit:type=baseline,variant=bit"
-# --- Index-specific configs ---
-# Each index branch should add its own configs here. Example:
-#
-# DISKANN_CONFIGS = \
-# "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
-# "diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8"
-#
-# IVF_CONFIGS = \
-# "ivf-n128-p16:type=ivf,nlist=128,nprobe=16"
-#
-# ANNOY_CONFIGS = \
-# "annoy-t50:type=annoy,n_trees=50"
+# --- IVF configs ---
+IVF_CONFIGS = \
+ "ivf-n32-p8:type=ivf,nlist=32,nprobe=8" \
+ "ivf-n128-p16:type=ivf,nlist=128,nprobe=16" \
+ "ivf-n512-p32:type=ivf,nlist=512,nprobe=32"
RESCORE_CONFIGS = \
"rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" \
"rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \
"rescore-int8-os8:type=rescore,quantizer=int8,oversample=8"
-ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS)
+ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS)
-.PHONY: seed ground-truth bench-smoke bench-rescore bench-10k bench-50k bench-100k bench-all \
+.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-10k bench-50k bench-100k bench-all \
report clean
# --- Data preparation ---
@@ -43,7 +36,8 @@ ground-truth: seed
# --- Quick smoke test ---
bench-smoke: seed
$(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \
- $(BASELINES)
+ "brute-float:type=baseline,variant=float" \
+ "ivf-quick:type=ivf,nlist=16,nprobe=4"
bench-rescore: seed
$(BENCH) --subset-size 10000 -k 10 -o runs/rescore \
@@ -62,6 +56,12 @@ bench-100k: seed
bench-all: bench-10k bench-50k bench-100k
+# --- IVF across sizes ---
+bench-ivf: seed
+ $(BENCH) --subset-size 10000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
+ $(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
+ $(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
+
# --- Report ---
report:
@echo "Use: sqlite3 runs/
/results.db 'SELECT * FROM bench_results ORDER BY recall DESC'"
diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py
index c1179d6..c640628 100644
--- a/benchmarks-ann/bench.py
+++ b/benchmarks-ann/bench.py
@@ -173,6 +173,48 @@ INDEX_REGISTRY["rescore"] = {
}
+# ============================================================================
+# IVF implementation
+# ============================================================================
+
+
+def _ivf_create_table_sql(params):
+ return (
+ f"CREATE VIRTUAL TABLE vec_items USING vec0("
+ f" id integer primary key,"
+ f" embedding float[768] distance_metric=cosine"
+ f" indexed by ivf("
+ f" nlist={params['nlist']},"
+ f" nprobe={params['nprobe']}"
+ f" )"
+ f")"
+ )
+
+
+def _ivf_post_insert_hook(conn, params):
+ print(" Training k-means centroids...", flush=True)
+ t0 = time.perf_counter()
+ conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')")
+ conn.commit()
+ elapsed = time.perf_counter() - t0
+ print(f" Training done in {elapsed:.1f}s", flush=True)
+ return elapsed
+
+
+def _ivf_describe(params):
+ return f"ivf nlist={params['nlist']:<4} nprobe={params['nprobe']}"
+
+
+INDEX_REGISTRY["ivf"] = {
+ "defaults": {"nlist": 128, "nprobe": 16},
+ "create_table_sql": _ivf_create_table_sql,
+ "insert_sql": None,
+ "post_insert_hook": _ivf_post_insert_hook,
+ "run_query": None,
+ "describe": _ivf_describe,
+}
+
+
# ============================================================================
# Config parsing
# ============================================================================
diff --git a/sqlite-vec-ivf-kmeans.c b/sqlite-vec-ivf-kmeans.c
new file mode 100644
index 0000000..0faa803
--- /dev/null
+++ b/sqlite-vec-ivf-kmeans.c
@@ -0,0 +1,214 @@
+/**
+ * sqlite-vec-ivf-kmeans.c — Pure k-means clustering algorithm.
+ *
+ * No SQLite dependency. Operates on float arrays in memory.
+ * #include'd into sqlite-vec.c after struct definitions.
+ */
+
+#ifndef SQLITE_VEC_IVF_KMEANS_C
+#define SQLITE_VEC_IVF_KMEANS_C
+
+// When opened standalone in an editor, pull in types so the LSP is happy.
+// When #include'd from sqlite-vec.c, SQLITE_VEC_H is already defined.
+#ifndef SQLITE_VEC_H
+#include "sqlite-vec.c" // IWYU pragma: keep
+#endif
+
+#include
+#include
+
+#define VEC0_IVF_KMEANS_MAX_ITER 25
+#define VEC0_IVF_KMEANS_DEFAULT_SEED 0
+
+// Simple xorshift32 PRNG
+static uint32_t ivf_xorshift32(uint32_t *state) {
+ uint32_t x = *state;
+ x ^= x << 13;
+ x ^= x >> 17;
+ x ^= x << 5;
+ *state = x;
+ return x;
+}
+
+// L2 squared distance between two float vectors
+static float ivf_l2_dist(const float *a, const float *b, int D) {
+ float sum = 0.0f;
+ for (int d = 0; d < D; d++) {
+ float diff = a[d] - b[d];
+ sum += diff * diff;
+ }
+ return sum;
+}
+
+// Find nearest centroid for a single vector. Returns centroid index.
+static int ivf_nearest_centroid(const float *vec, const float *centroids,
+ int D, int k) {
+ float min_dist = FLT_MAX;
+ int best = 0;
+ for (int c = 0; c < k; c++) {
+ float dist = ivf_l2_dist(vec, ¢roids[c * D], D);
+ if (dist < min_dist) {
+ min_dist = dist;
+ best = c;
+ }
+ }
+ return best;
+}
+
+/**
+ * K-means++ initialization.
+ * Picks k initial centroids from the data with probability proportional
+ * to squared distance from nearest existing centroid.
+ */
+static int ivf_kmeans_init_plusplus(const float *vectors, int N, int D,
+ int k, uint32_t seed, float *centroids) {
+ if (N <= 0 || k <= 0 || D <= 0)
+ return -1;
+ if (seed == 0)
+ seed = 42;
+
+ // Pick first centroid randomly
+ int first = ivf_xorshift32(&seed) % N;
+ memcpy(centroids, &vectors[first * D], D * sizeof(float));
+
+ if (k == 1)
+ return 0;
+
+ // Allocate distance array
+ float *dists = sqlite3_malloc64((i64)N * sizeof(float));
+ if (!dists)
+ return -1;
+
+ for (int c = 1; c < k; c++) {
+ // Compute D(x) = distance to nearest existing centroid
+ double total = 0.0;
+ for (int i = 0; i < N; i++) {
+ float d = ivf_l2_dist(&vectors[i * D], ¢roids[(c - 1) * D], D);
+ if (c == 1 || d < dists[i]) {
+ dists[i] = d;
+ }
+ total += dists[i];
+ }
+
+ // Weighted random selection
+ if (total <= 0.0) {
+ // All distances zero — pick randomly
+ int pick = ivf_xorshift32(&seed) % N;
+ memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float));
+ } else {
+ double threshold = ((double)ivf_xorshift32(&seed) / (double)0xFFFFFFFF) * total;
+ double cumulative = 0.0;
+ int pick = N - 1;
+ for (int i = 0; i < N; i++) {
+ cumulative += dists[i];
+ if (cumulative >= threshold) {
+ pick = i;
+ break;
+ }
+ }
+ memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float));
+ }
+ }
+
+ sqlite3_free(dists);
+ return 0;
+}
+
+/**
+ * Lloyd's k-means algorithm.
+ *
+ * @param vectors N*D float array (row-major)
+ * @param N number of vectors
+ * @param D dimensionality
+ * @param k number of clusters
+ * @param max_iter maximum iterations
+ * @param seed PRNG seed for initialization
+ * @param out_centroids output: k*D float array (caller-allocated)
+ * @return 0 on success, -1 on error
+ */
+static int ivf_kmeans(const float *vectors, int N, int D, int k,
+ int max_iter, uint32_t seed, float *out_centroids) {
+ if (N <= 0 || D <= 0 || k <= 0)
+ return -1;
+
+ // Clamp k to N
+ if (k > N)
+ k = N;
+
+ // Allocate working memory
+ int *assignments = sqlite3_malloc64((i64)N * sizeof(int));
+ float *new_centroids = sqlite3_malloc64((i64)k * D * sizeof(float));
+ int *counts = sqlite3_malloc64((i64)k * sizeof(int));
+
+ if (!assignments || !new_centroids || !counts) {
+ sqlite3_free(assignments);
+ sqlite3_free(new_centroids);
+ sqlite3_free(counts);
+ return -1;
+ }
+
+ memset(assignments, -1, N * sizeof(int));
+
+ // Initialize centroids via k-means++
+ if (ivf_kmeans_init_plusplus(vectors, N, D, k, seed, out_centroids) != 0) {
+ sqlite3_free(assignments);
+ sqlite3_free(new_centroids);
+ sqlite3_free(counts);
+ return -1;
+ }
+
+ for (int iter = 0; iter < max_iter; iter++) {
+ // Assignment step
+ int changed = 0;
+ for (int i = 0; i < N; i++) {
+ int nearest = ivf_nearest_centroid(&vectors[i * D], out_centroids, D, k);
+ if (nearest != assignments[i]) {
+ assignments[i] = nearest;
+ changed++;
+ }
+ }
+ if (changed == 0)
+ break;
+
+ // Update step
+ memset(new_centroids, 0, (size_t)k * D * sizeof(float));
+ memset(counts, 0, k * sizeof(int));
+
+ for (int i = 0; i < N; i++) {
+ int c = assignments[i];
+ counts[c]++;
+ for (int d = 0; d < D; d++) {
+ new_centroids[c * D + d] += vectors[i * D + d];
+ }
+ }
+
+ for (int c = 0; c < k; c++) {
+ if (counts[c] == 0) {
+ // Empty cluster: reassign to farthest point from its nearest centroid
+ float max_dist = -1.0f;
+ int farthest = 0;
+ for (int i = 0; i < N; i++) {
+ float d = ivf_l2_dist(&vectors[i * D],
+ &out_centroids[assignments[i] * D], D);
+ if (d > max_dist) {
+ max_dist = d;
+ farthest = i;
+ }
+ }
+ memcpy(&out_centroids[c * D], &vectors[farthest * D],
+ D * sizeof(float));
+ } else {
+ for (int d = 0; d < D; d++) {
+ out_centroids[c * D + d] = new_centroids[c * D + d] / counts[c];
+ }
+ }
+ }
+ }
+
+ sqlite3_free(assignments);
+ sqlite3_free(new_centroids);
+ sqlite3_free(counts);
+ return 0;
+}
+
+#endif /* SQLITE_VEC_IVF_KMEANS_C */
diff --git a/sqlite-vec-ivf.c b/sqlite-vec-ivf.c
new file mode 100644
index 0000000..5bc8edb
--- /dev/null
+++ b/sqlite-vec-ivf.c
@@ -0,0 +1,1445 @@
+/**
+ * sqlite-vec-ivf.c — IVF (Inverted File Index) for sqlite-vec.
+ *
+ * #include'd into sqlite-vec.c after struct definitions and before vec0_init().
+ *
+ * Storage: fixed-size packed blob cells (capped at IVF_CELL_MAX_VECTORS).
+ * Multiple cell rows per centroid. cell_id is auto-increment rowid,
+ * centroid_id is indexed for lookup. This keeps blobs small (~200KB)
+ * and avoids expensive overflow page traversal on insert.
+ */
+
+#ifndef SQLITE_VEC_IVF_C
+#define SQLITE_VEC_IVF_C
+
+#ifdef SQLITE_VEC_TEST
+#define IVF_STATIC
+#else
+#define IVF_STATIC static
+#endif
+
+// When opened standalone in an editor, pull in sqlite-vec.c so the LSP
+// can resolve all types (vec0_vtab, VectorColumnDefinition, etc.).
+// When #include'd from sqlite-vec.c, SQLITE_VEC_H is already defined.
+#ifndef SQLITE_VEC_H
+#include "sqlite-vec.c" // IWYU pragma: keep
+#endif
+
+#define VEC0_IVF_DEFAULT_NLIST 128
+#define VEC0_IVF_DEFAULT_NPROBE 10
+#define VEC0_IVF_MAX_NLIST 65536
+#define VEC0_IVF_CELL_MAX_VECTORS 64 // ~200KB per cell at 768-dim f32
+#define VEC0_IVF_UNASSIGNED_CENTROID_ID (-1)
+
+#define VEC0_SHADOW_IVF_CENTROIDS_NAME "\"%w\".\"%w_ivf_centroids%02d\""
+#define VEC0_SHADOW_IVF_CELLS_NAME "\"%w\".\"%w_ivf_cells%02d\""
+#define VEC0_SHADOW_IVF_ROWID_MAP_NAME "\"%w\".\"%w_ivf_rowid_map%02d\""
+#define VEC0_SHADOW_IVF_VECTORS_NAME "\"%w\".\"%w_ivf_vectors%02d\""
+
+// ============================================================================
+// Parser
+// ============================================================================
+
+static int vec0_parse_ivf_options(struct Vec0Scanner *scanner,
+ struct Vec0IvfConfig *config) {
+ struct Vec0Token token;
+ int rc;
+ config->nlist = VEC0_IVF_DEFAULT_NLIST;
+ config->nprobe = -1;
+ config->quantizer = VEC0_IVF_QUANTIZER_NONE;
+ config->oversample = 1;
+ int nprobe_explicit = 0;
+
+ rc = vec0_scanner_next(scanner, &token);
+ if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_LPAREN)
+ return SQLITE_ERROR;
+
+ rc = vec0_scanner_next(scanner, &token);
+ if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) {
+ config->nprobe = VEC0_IVF_DEFAULT_NPROBE;
+ return SQLITE_OK;
+ }
+
+ while (1) {
+ if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_IDENTIFIER)
+ return SQLITE_ERROR;
+ char *key = token.start;
+ int keyLength = token.end - token.start;
+
+ rc = vec0_scanner_next(scanner, &token);
+ if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ)
+ return SQLITE_ERROR;
+
+ // Read value — can be digit or identifier
+ rc = vec0_scanner_next(scanner, &token);
+ if (rc != VEC0_TOKEN_RESULT_SOME) return SQLITE_ERROR;
+ if (token.token_type != TOKEN_TYPE_DIGIT &&
+ token.token_type != TOKEN_TYPE_IDENTIFIER)
+ return SQLITE_ERROR;
+
+ char *val = token.start;
+ int valLength = token.end - token.start;
+
+ if (sqlite3_strnicmp(key, "nlist", keyLength) == 0) {
+ if (token.token_type != TOKEN_TYPE_DIGIT) return SQLITE_ERROR;
+ int v = atoi(val);
+ if (v < 0 || v > VEC0_IVF_MAX_NLIST) return SQLITE_ERROR;
+ config->nlist = v;
+ } else if (sqlite3_strnicmp(key, "nprobe", keyLength) == 0) {
+ if (token.token_type != TOKEN_TYPE_DIGIT) return SQLITE_ERROR;
+ int v = atoi(val);
+ if (v < 1 || v > VEC0_IVF_MAX_NLIST) return SQLITE_ERROR;
+ config->nprobe = v;
+ nprobe_explicit = 1;
+ } else if (sqlite3_strnicmp(key, "quantizer", keyLength) == 0) {
+ if (token.token_type != TOKEN_TYPE_IDENTIFIER) return SQLITE_ERROR;
+ if (sqlite3_strnicmp(val, "none", valLength) == 0) {
+ config->quantizer = VEC0_IVF_QUANTIZER_NONE;
+ } else if (sqlite3_strnicmp(val, "int8", valLength) == 0) {
+ config->quantizer = VEC0_IVF_QUANTIZER_INT8;
+ } else if (sqlite3_strnicmp(val, "binary", valLength) == 0) {
+ config->quantizer = VEC0_IVF_QUANTIZER_BINARY;
+ } else {
+ return SQLITE_ERROR;
+ }
+ } else if (sqlite3_strnicmp(key, "oversample", keyLength) == 0) {
+ if (token.token_type != TOKEN_TYPE_DIGIT) return SQLITE_ERROR;
+ int v = atoi(val);
+ if (v < 1) return SQLITE_ERROR;
+ config->oversample = v;
+ } else {
+ return SQLITE_ERROR;
+ }
+
+ rc = vec0_scanner_next(scanner, &token);
+ if (rc != VEC0_TOKEN_RESULT_SOME) return SQLITE_ERROR;
+ if (token.token_type == TOKEN_TYPE_RPAREN) break;
+ if (token.token_type != TOKEN_TYPE_COMMA) return SQLITE_ERROR;
+ rc = vec0_scanner_next(scanner, &token);
+ }
+
+ if (config->nprobe < 0) config->nprobe = VEC0_IVF_DEFAULT_NPROBE;
+ if (config->nlist > 0 && config->nprobe > config->nlist) {
+ if (nprobe_explicit) return SQLITE_ERROR;
+ config->nprobe = config->nlist;
+ }
+
+ // Validation: oversample > 1 only makes sense with quantization
+ if (config->oversample > 1 && config->quantizer == VEC0_IVF_QUANTIZER_NONE) {
+ return SQLITE_ERROR;
+ }
+
+ return SQLITE_OK;
+}
+
+// ============================================================================
+// Helpers
+// ============================================================================
+
+/**
+ * Size of a stored vector in bytes, accounting for quantization.
+ */
+static int ivf_vec_size(vec0_vtab *p, int col_idx) {
+ int D = (int)p->vector_columns[col_idx].dimensions;
+ switch (p->vector_columns[col_idx].ivf.quantizer) {
+ case VEC0_IVF_QUANTIZER_INT8: return D;
+ case VEC0_IVF_QUANTIZER_BINARY: return D / 8;
+ default: return D * (int)sizeof(float);
+ }
+}
+
+/**
+ * Size of the full-precision vector in bytes (always float32).
+ */
+static int ivf_full_vec_size(vec0_vtab *p, int col_idx) {
+ return (int)(p->vector_columns[col_idx].dimensions * sizeof(float));
+}
+
+/**
+ * Quantize float32 vector to int8.
+ * Uses unit normalization: clamp to [-1,1], scale to [-127,127].
+ */
+IVF_STATIC void ivf_quantize_int8(const float *src, int8_t *dst, int D) {
+ for (int i = 0; i < D; i++) {
+ float v = src[i];
+ if (v > 1.0f) v = 1.0f;
+ if (v < -1.0f) v = -1.0f;
+ dst[i] = (int8_t)(v * 127.0f);
+ }
+}
+
+/**
+ * Quantize float32 vector to binary (sign-bit quantization).
+ * Each bit = 1 if src[i] > 0, else 0.
+ */
+IVF_STATIC void ivf_quantize_binary(const float *src, uint8_t *dst, int D) {
+ memset(dst, 0, D / 8);
+ for (int i = 0; i < D; i++) {
+ if (src[i] > 0.0f) {
+ dst[i / 8] |= (1 << (i % 8));
+ }
+ }
+}
+
+/**
+ * Quantize a float32 vector to the target type based on config.
+ * dst must be pre-allocated to ivf_vec_size() bytes.
+ * If quantizer=none, copies src as-is.
+ */
+static void ivf_quantize(vec0_vtab *p, int col_idx,
+ const float *src, void *dst) {
+ int D = (int)p->vector_columns[col_idx].dimensions;
+ switch (p->vector_columns[col_idx].ivf.quantizer) {
+ case VEC0_IVF_QUANTIZER_INT8:
+ ivf_quantize_int8(src, (int8_t *)dst, D);
+ break;
+ case VEC0_IVF_QUANTIZER_BINARY:
+ ivf_quantize_binary(src, (uint8_t *)dst, D);
+ break;
+ default:
+ memcpy(dst, src, D * sizeof(float));
+ break;
+ }
+}
+
+// Forward declaration
+static float ivf_distance(vec0_vtab *p, int col_idx, const void *a, const void *b);
+
+/**
+ * Find nearest centroid. Works with quantized or float centroids.
+ * vec and centroids must be in the same representation (both quantized or both float).
+ * vecSize = size of one vector in bytes.
+ */
+static int ivf_find_nearest_centroid(vec0_vtab *p, int col_idx,
+ const void *vec, const void *centroids,
+ int vecSize, int k) {
+ float min_dist = FLT_MAX;
+ int best = 0;
+ const unsigned char *cdata = (const unsigned char *)centroids;
+ for (int c = 0; c < k; c++) {
+ float dist = ivf_distance(p, col_idx, vec, cdata + c * vecSize);
+ if (dist < min_dist) { min_dist = dist; best = c; }
+ }
+ return best;
+}
+
+/**
+ * Compute distance between two vectors using the column's distance_metric.
+ * Dispatches to SIMD-optimized functions (NEON/AVX) via distance_*_float().
+ * For float32 (non-quantized) vectors.
+ */
+static float ivf_distance_float(vec0_vtab *p, int col_idx,
+ const float *a, const float *b) {
+ size_t dims = p->vector_columns[col_idx].dimensions;
+ switch (p->vector_columns[col_idx].distance_metric) {
+ case VEC0_DISTANCE_METRIC_COSINE:
+ return distance_cosine_float(a, b, &dims);
+ case VEC0_DISTANCE_METRIC_L1:
+ return (float)distance_l1_f32(a, b, &dims);
+ case VEC0_DISTANCE_METRIC_L2:
+ default:
+ return distance_l2_sqr_float(a, b, &dims);
+ }
+}
+
+/**
+ * Compute distance between two quantized vectors.
+ * For int8: uses L2 or cosine on int8.
+ * For binary: uses hamming distance.
+ * For none: delegates to ivf_distance_float.
+ */
+static float ivf_distance(vec0_vtab *p, int col_idx,
+ const void *a, const void *b) {
+ size_t dims = p->vector_columns[col_idx].dimensions;
+ switch (p->vector_columns[col_idx].ivf.quantizer) {
+ case VEC0_IVF_QUANTIZER_INT8:
+ return distance_l2_sqr_int8(a, b, &dims);
+ case VEC0_IVF_QUANTIZER_BINARY:
+ return distance_hamming(a, b, &dims);
+ default:
+ return ivf_distance_float(p, col_idx, (const float *)a, (const float *)b);
+ }
+}
+
+static int ivf_ensure_stmt(vec0_vtab *p, sqlite3_stmt **pStmt, const char *fmt,
+ int col_idx) {
+ if (*pStmt) return SQLITE_OK;
+ char *zSql = sqlite3_mprintf(fmt, p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ int rc = sqlite3_prepare_v2(p->db, zSql, -1, pStmt, NULL);
+ sqlite3_free(zSql);
+ return rc;
+}
+
+static int ivf_exec(vec0_vtab *p, const char *fmt, int col_idx) {
+ sqlite3_stmt *stmt = NULL;
+ char *zSql = sqlite3_mprintf(fmt, p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
+ sqlite3_free(zSql);
+ if (rc == SQLITE_OK) { sqlite3_step(stmt); sqlite3_finalize(stmt); }
+ return SQLITE_OK;
+}
+
+static int ivf_is_trained(vec0_vtab *p, int col_idx) {
+ if (p->ivfTrainedCache[col_idx] >= 0) return p->ivfTrainedCache[col_idx];
+ sqlite3_stmt *stmt = NULL;
+ int trained = 0;
+ char *zSql = sqlite3_mprintf(
+ "SELECT value FROM " VEC0_SHADOW_INFO_NAME " WHERE key = 'ivf_trained_%d'",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return 0;
+ if (sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL) == SQLITE_OK) {
+ if (sqlite3_step(stmt) == SQLITE_ROW)
+ trained = (sqlite3_column_int(stmt, 0) == 1);
+ }
+ sqlite3_free(zSql);
+ sqlite3_finalize(stmt);
+ p->ivfTrainedCache[col_idx] = trained;
+ return trained;
+}
+
+// ============================================================================
+// Cell operations — fixed-size cells, multiple rows per centroid
+// ============================================================================
+
+/**
+ * Create a new cell row. Returns the new cell_id (rowid) via *out_cell_id.
+ */
+static int ivf_cell_create(vec0_vtab *p, int col_idx, i64 centroid_id,
+ i64 *out_cell_id) {
+ sqlite3_stmt *stmt = NULL;
+ int rc;
+ int cap = VEC0_IVF_CELL_MAX_VECTORS;
+ int vecSize = ivf_vec_size(p, col_idx);
+ char *zSql = sqlite3_mprintf(
+ "INSERT INTO " VEC0_SHADOW_IVF_CELLS_NAME
+ " (centroid_id, n_vectors, validity, rowids, vectors) VALUES (?, 0, ?, ?, ?)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
+ sqlite3_free(zSql);
+ if (rc != SQLITE_OK) return rc;
+ sqlite3_bind_int64(stmt, 1, centroid_id);
+ sqlite3_bind_zeroblob(stmt, 2, cap / 8);
+ sqlite3_bind_zeroblob(stmt, 3, cap * (int)sizeof(i64));
+ sqlite3_bind_zeroblob(stmt, 4, cap * vecSize);
+ rc = sqlite3_step(stmt);
+ sqlite3_finalize(stmt);
+ if (rc != SQLITE_DONE) return SQLITE_ERROR;
+ if (out_cell_id) *out_cell_id = sqlite3_last_insert_rowid(p->db);
+ return SQLITE_OK;
+}
+
+/**
+ * Find a cell with space for the given centroid, or create one.
+ * Returns cell_id (rowid) and current n_vectors.
+ */
+static int ivf_cell_find_or_create(vec0_vtab *p, int col_idx, i64 centroid_id,
+ i64 *out_cell_id, int *out_n) {
+ int rc;
+ // Find existing cell with space
+ rc = ivf_ensure_stmt(p, &p->stmtIvfCellMeta[col_idx],
+ "SELECT rowid, n_vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME
+ " WHERE centroid_id = ? AND n_vectors < %d LIMIT 1",
+ col_idx);
+ // The %d in the format won't work with ivf_ensure_stmt since it only has 3
+ // format args. Use a direct approach instead.
+ sqlite3_finalize(p->stmtIvfCellMeta[col_idx]);
+ p->stmtIvfCellMeta[col_idx] = NULL;
+
+ char *zSql = sqlite3_mprintf(
+ "SELECT rowid, n_vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME
+ " WHERE centroid_id = ? AND n_vectors < %d LIMIT 1",
+ p->schemaName, p->tableName, col_idx, VEC0_IVF_CELL_MAX_VECTORS);
+ if (!zSql) return SQLITE_NOMEM;
+ // Cache this manually
+ if (!p->stmtIvfCellMeta[col_idx]) {
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &p->stmtIvfCellMeta[col_idx], NULL);
+ sqlite3_free(zSql);
+ if (rc != SQLITE_OK) return rc;
+ } else {
+ sqlite3_free(zSql);
+ }
+
+ sqlite3_stmt *stmt = p->stmtIvfCellMeta[col_idx];
+ sqlite3_reset(stmt);
+ sqlite3_bind_int64(stmt, 1, centroid_id);
+
+ if (sqlite3_step(stmt) == SQLITE_ROW) {
+ *out_cell_id = sqlite3_column_int64(stmt, 0);
+ *out_n = sqlite3_column_int(stmt, 1);
+ return SQLITE_OK;
+ }
+
+ // No cell with space — create new one
+ rc = ivf_cell_create(p, col_idx, centroid_id, out_cell_id);
+ *out_n = 0;
+ return rc;
+}
+
+/**
+ * Insert vector into cell at slot = n_vectors (append).
+ * Cell must have space (n_vectors < VEC0_IVF_CELL_MAX_VECTORS).
+ */
+static int ivf_cell_insert(vec0_vtab *p, int col_idx, i64 centroid_id,
+ i64 rowid, const void *vectorData, int vectorSize) {
+ int rc;
+ i64 cell_id;
+ int n_vectors;
+
+ rc = ivf_cell_find_or_create(p, col_idx, centroid_id, &cell_id, &n_vectors);
+ if (rc != SQLITE_OK) return rc;
+
+ int slot = n_vectors;
+ char *cellsTable = p->shadowIvfCellsNames[col_idx];
+
+ // Set validity bit
+ sqlite3_blob *blob = NULL;
+ rc = sqlite3_blob_open(p->db, p->schemaName, cellsTable, "validity",
+ cell_id, 1, &blob);
+ if (rc != SQLITE_OK) return rc;
+ unsigned char bx;
+ sqlite3_blob_read(blob, &bx, 1, slot / 8);
+ bx |= (1 << (slot % 8));
+ sqlite3_blob_write(blob, &bx, 1, slot / 8);
+ sqlite3_blob_close(blob);
+
+ // Write rowid
+ rc = sqlite3_blob_open(p->db, p->schemaName, cellsTable, "rowids",
+ cell_id, 1, &blob);
+ if (rc == SQLITE_OK) {
+ sqlite3_blob_write(blob, &rowid, sizeof(i64), slot * (int)sizeof(i64));
+ sqlite3_blob_close(blob);
+ }
+
+ // Write vector
+ rc = sqlite3_blob_open(p->db, p->schemaName, cellsTable, "vectors",
+ cell_id, 1, &blob);
+ if (rc == SQLITE_OK) {
+ sqlite3_blob_write(blob, vectorData, vectorSize, slot * vectorSize);
+ sqlite3_blob_close(blob);
+ }
+
+ // Increment n_vectors (cached)
+ ivf_ensure_stmt(p, &p->stmtIvfCellUpdateN[col_idx],
+ "UPDATE " VEC0_SHADOW_IVF_CELLS_NAME
+ " SET n_vectors = n_vectors + 1 WHERE rowid = ?", col_idx);
+ if (p->stmtIvfCellUpdateN[col_idx]) {
+ sqlite3_stmt *s = p->stmtIvfCellUpdateN[col_idx];
+ sqlite3_reset(s);
+ sqlite3_bind_int64(s, 1, cell_id);
+ sqlite3_step(s);
+ }
+
+ // Insert rowid_map (cached)
+ ivf_ensure_stmt(p, &p->stmtIvfRowidMapInsert[col_idx],
+ "INSERT INTO " VEC0_SHADOW_IVF_ROWID_MAP_NAME
+ " (rowid, cell_id, slot) VALUES (?, ?, ?)", col_idx);
+ if (p->stmtIvfRowidMapInsert[col_idx]) {
+ sqlite3_stmt *s = p->stmtIvfRowidMapInsert[col_idx];
+ sqlite3_reset(s);
+ sqlite3_bind_int64(s, 1, rowid);
+ sqlite3_bind_int64(s, 2, cell_id);
+ sqlite3_bind_int(s, 3, slot);
+ sqlite3_step(s);
+ }
+
+ return SQLITE_OK;
+}
+
+// ============================================================================
+// Shadow tables
+// ============================================================================
+
+static int ivf_create_shadow_tables(vec0_vtab *p, int col_idx) {
+ sqlite3_stmt *stmt = NULL;
+ int rc;
+ char *zSql;
+
+ zSql = sqlite3_mprintf(
+ "CREATE TABLE " VEC0_SHADOW_IVF_CENTROIDS_NAME
+ " (centroid_id INTEGER PRIMARY KEY, centroid BLOB NOT NULL)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; }
+ sqlite3_finalize(stmt);
+
+ // cell_id is rowid (auto-increment), centroid_id is indexed
+ zSql = sqlite3_mprintf(
+ "CREATE TABLE " VEC0_SHADOW_IVF_CELLS_NAME
+ " (centroid_id INTEGER NOT NULL,"
+ " n_vectors INTEGER NOT NULL DEFAULT 0,"
+ " validity BLOB NOT NULL,"
+ " rowids BLOB NOT NULL,"
+ " vectors BLOB NOT NULL)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; }
+ sqlite3_finalize(stmt);
+
+ // Index on centroid_id for cell lookup
+ zSql = sqlite3_mprintf(
+ "CREATE INDEX \"%w_ivf_cells%02d_centroid\" ON \"%w_ivf_cells%02d\" (centroid_id)",
+ p->tableName, col_idx, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; }
+ sqlite3_finalize(stmt);
+
+ zSql = sqlite3_mprintf(
+ "CREATE TABLE " VEC0_SHADOW_IVF_ROWID_MAP_NAME
+ " (rowid INTEGER PRIMARY KEY, cell_id INTEGER NOT NULL, slot INTEGER NOT NULL)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; }
+ sqlite3_finalize(stmt);
+
+ // _ivf_vectors — full-precision KV store (only when quantizer != none)
+ if (p->vector_columns[col_idx].ivf.quantizer != VEC0_IVF_QUANTIZER_NONE) {
+ zSql = sqlite3_mprintf(
+ "CREATE TABLE " VEC0_SHADOW_IVF_VECTORS_NAME
+ " (rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; }
+ sqlite3_finalize(stmt);
+ }
+
+ zSql = sqlite3_mprintf(
+ "INSERT INTO " VEC0_SHADOW_INFO_NAME " (key, value) VALUES ('ivf_trained_%d', '0')",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK || sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); return SQLITE_ERROR; }
+ sqlite3_finalize(stmt);
+
+ return SQLITE_OK;
+}
+
+static int ivf_drop_shadow_tables(vec0_vtab *p, int col_idx) {
+ ivf_exec(p, "DROP TABLE IF EXISTS " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx);
+ ivf_exec(p, "DROP TABLE IF EXISTS " VEC0_SHADOW_IVF_CELLS_NAME, col_idx);
+ ivf_exec(p, "DROP TABLE IF EXISTS " VEC0_SHADOW_IVF_ROWID_MAP_NAME, col_idx);
+ ivf_exec(p, "DROP TABLE IF EXISTS " VEC0_SHADOW_IVF_VECTORS_NAME, col_idx);
+ return SQLITE_OK;
+}
+
+// ============================================================================
+// Insert / Delete
+// ============================================================================
+
+static int ivf_insert(vec0_vtab *p, int col_idx, i64 rowid,
+ const void *vectorData, int vectorSize) {
+ UNUSED_PARAMETER(vectorSize);
+ int quantizer = p->vector_columns[col_idx].ivf.quantizer;
+ int qvecSize = ivf_vec_size(p, col_idx);
+ int rc;
+
+ // Quantize the input vector (or copy as-is if no quantization)
+ void *qvec = sqlite3_malloc(qvecSize);
+ if (!qvec) return SQLITE_NOMEM;
+ ivf_quantize(p, col_idx, (const float *)vectorData, qvec);
+
+ if (!ivf_is_trained(p, col_idx)) {
+ rc = ivf_cell_insert(p, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID,
+ rowid, qvec, qvecSize);
+ } else {
+ // Find nearest centroid using quantized distance
+ int best_centroid = -1;
+ float min_dist = FLT_MAX;
+
+ rc = ivf_ensure_stmt(p, &p->stmtIvfCentroidsAll[col_idx],
+ "SELECT centroid_id, centroid FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx);
+ if (rc != SQLITE_OK) { sqlite3_free(qvec); return rc; }
+ sqlite3_stmt *stmt = p->stmtIvfCentroidsAll[col_idx];
+ sqlite3_reset(stmt);
+ while (sqlite3_step(stmt) == SQLITE_ROW) {
+ int cid = sqlite3_column_int(stmt, 0);
+ const void *c = sqlite3_column_blob(stmt, 1);
+ int cBytes = sqlite3_column_bytes(stmt, 1);
+ if (!c || cBytes != qvecSize) continue;
+ float dist = ivf_distance(p, col_idx, qvec, c);
+ if (dist < min_dist) { min_dist = dist; best_centroid = cid; }
+ }
+ if (best_centroid < 0) { sqlite3_free(qvec); return SQLITE_ERROR; }
+
+ rc = ivf_cell_insert(p, col_idx, best_centroid, rowid, qvec, qvecSize);
+ }
+
+ sqlite3_free(qvec);
+ if (rc != SQLITE_OK) return rc;
+
+ // Store full-precision vector in KV table when quantized
+ if (quantizer != VEC0_IVF_QUANTIZER_NONE) {
+ sqlite3_stmt *stmt = NULL;
+ char *zSql = sqlite3_mprintf(
+ "INSERT INTO " VEC0_SHADOW_IVF_VECTORS_NAME " (rowid, vector) VALUES (?, ?)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK) return rc;
+ sqlite3_bind_int64(stmt, 1, rowid);
+ sqlite3_bind_blob(stmt, 2, vectorData, ivf_full_vec_size(p, col_idx), SQLITE_STATIC);
+ rc = sqlite3_step(stmt);
+ sqlite3_finalize(stmt);
+ if (rc != SQLITE_DONE) return SQLITE_ERROR;
+ }
+
+ return SQLITE_OK;
+}
+
+static int ivf_delete(vec0_vtab *p, int col_idx, i64 rowid) {
+ int rc;
+ i64 cell_id = 0;
+ int slot = -1;
+
+ rc = ivf_ensure_stmt(p, &p->stmtIvfRowidMapLookup[col_idx],
+ "SELECT cell_id, slot FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME
+ " WHERE rowid = ?", col_idx);
+ if (rc != SQLITE_OK) return rc;
+ sqlite3_stmt *s = p->stmtIvfRowidMapLookup[col_idx];
+ sqlite3_reset(s);
+ sqlite3_bind_int64(s, 1, rowid);
+ if (sqlite3_step(s) == SQLITE_ROW) {
+ cell_id = sqlite3_column_int64(s, 0);
+ slot = sqlite3_column_int(s, 1);
+ }
+ if (slot < 0) return SQLITE_OK;
+
+ // Clear validity bit
+ char *cellsTable = p->shadowIvfCellsNames[col_idx];
+ sqlite3_blob *blob = NULL;
+ rc = sqlite3_blob_open(p->db, p->schemaName, cellsTable, "validity",
+ cell_id, 1, &blob);
+ if (rc == SQLITE_OK) {
+ unsigned char bx;
+ sqlite3_blob_read(blob, &bx, 1, slot / 8);
+ bx &= ~(1 << (slot % 8));
+ sqlite3_blob_write(blob, &bx, 1, slot / 8);
+ sqlite3_blob_close(blob);
+ }
+
+ // Decrement n_vectors
+ if (p->stmtIvfCellUpdateN[col_idx]) {
+ // This stmt does +1, but we want -1. Use a different cached stmt.
+ }
+ // Just use inline for decrement (not hot path)
+ {
+ sqlite3_stmt *stmtDec = NULL;
+ char *zSql = sqlite3_mprintf(
+ "UPDATE " VEC0_SHADOW_IVF_CELLS_NAME
+ " SET n_vectors = n_vectors - 1 WHERE rowid = ?",
+ p->schemaName, p->tableName, col_idx);
+ if (zSql) {
+ sqlite3_prepare_v2(p->db, zSql, -1, &stmtDec, NULL); sqlite3_free(zSql);
+ if (stmtDec) { sqlite3_bind_int64(stmtDec, 1, cell_id); sqlite3_step(stmtDec); sqlite3_finalize(stmtDec); }
+ }
+ }
+
+ // Delete from rowid_map
+ ivf_ensure_stmt(p, &p->stmtIvfRowidMapDelete[col_idx],
+ "DELETE FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME " WHERE rowid = ?", col_idx);
+ if (p->stmtIvfRowidMapDelete[col_idx]) {
+ sqlite3_stmt *sd = p->stmtIvfRowidMapDelete[col_idx];
+ sqlite3_reset(sd);
+ sqlite3_bind_int64(sd, 1, rowid);
+ sqlite3_step(sd);
+ }
+
+ // Delete from _ivf_vectors (full-precision KV) when quantized
+ if (p->vector_columns[col_idx].ivf.quantizer != VEC0_IVF_QUANTIZER_NONE) {
+ sqlite3_stmt *stmtDelVec = NULL;
+ char *zSql = sqlite3_mprintf(
+ "DELETE FROM " VEC0_SHADOW_IVF_VECTORS_NAME " WHERE rowid = ?",
+ p->schemaName, p->tableName, col_idx);
+ if (zSql) {
+ sqlite3_prepare_v2(p->db, zSql, -1, &stmtDelVec, NULL); sqlite3_free(zSql);
+ if (stmtDelVec) { sqlite3_bind_int64(stmtDelVec, 1, rowid); sqlite3_step(stmtDelVec); sqlite3_finalize(stmtDelVec); }
+ }
+ }
+
+ return SQLITE_OK;
+}
+
+// ============================================================================
+// Point query
+// ============================================================================
+
+static int ivf_get_vector_data(vec0_vtab *p, i64 rowid, int col_idx,
+ void **outVector, int *outVectorSize) {
+ int rc;
+ int vecSize = ivf_vec_size(p, col_idx);
+ i64 cell_id = 0;
+ int slot = -1;
+
+ rc = ivf_ensure_stmt(p, &p->stmtIvfRowidMapLookup[col_idx],
+ "SELECT cell_id, slot FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME
+ " WHERE rowid = ?", col_idx);
+ if (rc != SQLITE_OK) return rc;
+ sqlite3_stmt *s = p->stmtIvfRowidMapLookup[col_idx];
+ sqlite3_reset(s);
+ sqlite3_bind_int64(s, 1, rowid);
+ if (sqlite3_step(s) != SQLITE_ROW) return SQLITE_EMPTY;
+ cell_id = sqlite3_column_int64(s, 0);
+ slot = sqlite3_column_int(s, 1);
+
+ void *buf = sqlite3_malloc(vecSize);
+ if (!buf) return SQLITE_NOMEM;
+
+ sqlite3_blob *blob = NULL;
+ rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowIvfCellsNames[col_idx],
+ "vectors", cell_id, 0, &blob);
+ if (rc != SQLITE_OK) { sqlite3_free(buf); return rc; }
+ rc = sqlite3_blob_read(blob, buf, vecSize, slot * vecSize);
+ sqlite3_blob_close(blob);
+ if (rc != SQLITE_OK) { sqlite3_free(buf); return rc; }
+
+ *outVector = buf;
+ if (outVectorSize) *outVectorSize = vecSize;
+ return SQLITE_OK;
+}
+
+// ============================================================================
+// Centroid commands
+// ============================================================================
+
+static int ivf_load_all_vectors(vec0_vtab *p, int col_idx,
+ float **out_vectors, i64 **out_rowids, int *out_N) {
+ sqlite3_stmt *stmt = NULL;
+ int rc;
+ int D = (int)p->vector_columns[col_idx].dimensions;
+ int vecSize = D * (int)sizeof(float);
+ int quantizer = p->vector_columns[col_idx].ivf.quantizer;
+
+ // When quantized, load full-precision vectors from _ivf_vectors KV table
+ if (quantizer != VEC0_IVF_QUANTIZER_NONE) {
+ int total = 0;
+ char *zSql = sqlite3_mprintf(
+ "SELECT count(*) FROM " VEC0_SHADOW_IVF_VECTORS_NAME,
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) total = sqlite3_column_int(stmt, 0);
+ sqlite3_finalize(stmt);
+ if (total == 0) { *out_vectors = NULL; *out_rowids = NULL; *out_N = 0; return SQLITE_OK; }
+
+ float *vectors = sqlite3_malloc64((i64)total * D * sizeof(float));
+ i64 *rowids = sqlite3_malloc64((i64)total * sizeof(i64));
+ if (!vectors || !rowids) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; }
+
+ int idx = 0;
+ zSql = sqlite3_mprintf(
+ "SELECT rowid, vector FROM " VEC0_SHADOW_IVF_VECTORS_NAME,
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; }
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc == SQLITE_OK) {
+ while (sqlite3_step(stmt) == SQLITE_ROW && idx < total) {
+ rowids[idx] = sqlite3_column_int64(stmt, 0);
+ const void *blob = sqlite3_column_blob(stmt, 1);
+ int blobBytes = sqlite3_column_bytes(stmt, 1);
+ if (blob && blobBytes == vecSize) {
+ memcpy(&vectors[idx * D], blob, vecSize);
+ idx++;
+ }
+ }
+ }
+ sqlite3_finalize(stmt);
+ *out_vectors = vectors; *out_rowids = rowids; *out_N = idx;
+ return SQLITE_OK;
+ }
+
+ // Non-quantized: load from cells (existing path)
+
+ // Count total
+ int total = 0;
+ char *zSql = sqlite3_mprintf(
+ "SELECT COALESCE(SUM(n_vectors),0) FROM " VEC0_SHADOW_IVF_CELLS_NAME,
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) total = sqlite3_column_int(stmt, 0);
+ sqlite3_finalize(stmt);
+
+ if (total == 0) { *out_vectors = NULL; *out_rowids = NULL; *out_N = 0; return SQLITE_OK; }
+
+ float *vectors = sqlite3_malloc64((i64)total * D * sizeof(float));
+ i64 *rowids = sqlite3_malloc64((i64)total * sizeof(i64));
+ if (!vectors || !rowids) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; }
+
+ int idx = 0;
+ zSql = sqlite3_mprintf(
+ "SELECT n_vectors, validity, rowids, vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME,
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; }
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK) { sqlite3_free(vectors); sqlite3_free(rowids); return rc; }
+
+ while (sqlite3_step(stmt) == SQLITE_ROW) {
+ int n = sqlite3_column_int(stmt, 0);
+ if (n == 0) continue;
+ const unsigned char *val = (const unsigned char *)sqlite3_column_blob(stmt, 1);
+ const i64 *rids = (const i64 *)sqlite3_column_blob(stmt, 2);
+ const float *vecs = (const float *)sqlite3_column_blob(stmt, 3);
+ int valBytes = sqlite3_column_bytes(stmt, 1);
+ int ridsBytes = sqlite3_column_bytes(stmt, 2);
+ int vecsBytes = sqlite3_column_bytes(stmt, 3);
+ if (!val || !rids || !vecs) continue;
+ int cap = valBytes * 8;
+ // Clamp cap to the number of entries actually backed by the rowids and vectors blobs
+ if (ridsBytes / (int)sizeof(i64) < cap) cap = ridsBytes / (int)sizeof(i64);
+ if (vecsBytes / vecSize < cap) cap = vecsBytes / vecSize;
+ for (int i = 0; i < cap && idx < total; i++) {
+ if (val[i / 8] & (1 << (i % 8))) {
+ rowids[idx] = rids[i];
+ memcpy(&vectors[idx * D], &vecs[i * D], vecSize);
+ idx++;
+ }
+ }
+ }
+ sqlite3_finalize(stmt);
+ *out_vectors = vectors; *out_rowids = rowids; *out_N = idx;
+ return SQLITE_OK;
+}
+
+static void ivf_invalidate_cached(vec0_vtab *p, int col_idx) {
+ sqlite3_finalize(p->stmtIvfCellMeta[col_idx]); p->stmtIvfCellMeta[col_idx] = NULL;
+ sqlite3_finalize(p->stmtIvfCentroidsAll[col_idx]); p->stmtIvfCentroidsAll[col_idx] = NULL;
+ sqlite3_finalize(p->stmtIvfCellUpdateN[col_idx]); p->stmtIvfCellUpdateN[col_idx] = NULL;
+ sqlite3_finalize(p->stmtIvfRowidMapInsert[col_idx]); p->stmtIvfRowidMapInsert[col_idx] = NULL;
+}
+
+static int ivf_cmd_compute_centroids(vec0_vtab *p, int col_idx, int nlist_override,
+ int max_iter, uint32_t seed) {
+ int rc;
+ int D = (int)p->vector_columns[col_idx].dimensions;
+ int vecSize = D * (int)sizeof(float);
+ int quantizer = p->vector_columns[col_idx].ivf.quantizer;
+ int nlist = nlist_override > 0 ? nlist_override : p->vector_columns[col_idx].ivf.nlist;
+ if (nlist <= 0) { vtab_set_error(&p->base, "nlist must be specified"); return SQLITE_ERROR; }
+
+ float *vectors = NULL; i64 *rowids = NULL; int N = 0;
+ rc = ivf_load_all_vectors(p, col_idx, &vectors, &rowids, &N);
+ if (rc != SQLITE_OK) return rc;
+ if (N == 0) { vtab_set_error(&p->base, "No vectors"); sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_ERROR; }
+ if (nlist > N) nlist = N;
+
+ float *centroids = sqlite3_malloc64((i64)nlist * D * sizeof(float));
+ if (!centroids) { sqlite3_free(vectors); sqlite3_free(rowids); return SQLITE_NOMEM; }
+ if (ivf_kmeans(vectors, N, D, nlist, max_iter, seed, centroids) != 0) {
+ sqlite3_free(vectors); sqlite3_free(rowids); sqlite3_free(centroids); return SQLITE_ERROR;
+ }
+
+ // Compute assignments
+ int *assignments = sqlite3_malloc64((i64)N * sizeof(int));
+ if (!assignments) { sqlite3_free(vectors); sqlite3_free(rowids); sqlite3_free(centroids); return SQLITE_NOMEM; }
+ // Assignment uses float32 distances (k-means operates in float32 space)
+ for (int i = 0; i < N; i++) {
+ float min_d = FLT_MAX;
+ int best = 0;
+ for (int c = 0; c < nlist; c++) {
+ float d = ivf_distance_float(p, col_idx, &vectors[i * D], ¢roids[c * D]);
+ if (d < min_d) { min_d = d; best = c; }
+ }
+ assignments[i] = best;
+ }
+
+ // Invalidate all cached stmts before dropping/recreating tables
+ ivf_invalidate_cached(p, col_idx);
+
+ sqlite3_exec(p->db, "SAVEPOINT ivf_train", NULL, NULL, NULL);
+ sqlite3_stmt *stmt = NULL;
+ char *zSql;
+
+ // Clear all data
+ ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx);
+ ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_CELLS_NAME, col_idx);
+ ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME, col_idx);
+
+ // Write centroids (quantized if quantizer is set)
+ int qvecSize = ivf_vec_size(p, col_idx);
+ void *qbuf = sqlite3_malloc(qvecSize > vecSize ? qvecSize : vecSize);
+ if (!qbuf) { rc = SQLITE_NOMEM; goto train_error; }
+
+ zSql = sqlite3_mprintf(
+ "INSERT INTO " VEC0_SHADOW_IVF_CENTROIDS_NAME " (centroid_id, centroid) VALUES (?, ?)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) { sqlite3_free(qbuf); rc = SQLITE_NOMEM; goto train_error; }
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK) { sqlite3_free(qbuf); goto train_error; }
+ for (int i = 0; i < nlist; i++) {
+ ivf_quantize(p, col_idx, ¢roids[i * D], qbuf);
+ sqlite3_reset(stmt);
+ sqlite3_bind_int(stmt, 1, i);
+ sqlite3_bind_blob(stmt, 2, qbuf, qvecSize, SQLITE_TRANSIENT);
+ if (sqlite3_step(stmt) != SQLITE_DONE) { sqlite3_finalize(stmt); sqlite3_free(qbuf); rc = SQLITE_ERROR; goto train_error; }
+ }
+ sqlite3_finalize(stmt);
+
+ // Build cells: group vectors by centroid, create fixed-size cells
+ {
+ // Prepare INSERT statements
+ sqlite3_stmt *stmtCell = NULL;
+ zSql = sqlite3_mprintf(
+ "INSERT INTO " VEC0_SHADOW_IVF_CELLS_NAME
+ " (centroid_id, n_vectors, validity, rowids, vectors) VALUES (?, ?, ?, ?, ?)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) { rc = SQLITE_NOMEM; goto train_error; }
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtCell, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK) goto train_error;
+
+ sqlite3_stmt *stmtMap = NULL;
+ zSql = sqlite3_mprintf(
+ "INSERT INTO " VEC0_SHADOW_IVF_ROWID_MAP_NAME " (rowid, cell_id, slot) VALUES (?, ?, ?)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) { sqlite3_finalize(stmtCell); rc = SQLITE_NOMEM; goto train_error; }
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtMap, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK) { sqlite3_finalize(stmtCell); goto train_error; }
+
+ int cap = VEC0_IVF_CELL_MAX_VECTORS;
+ unsigned char *val = sqlite3_malloc(cap / 8);
+ i64 *rids = sqlite3_malloc64((i64)cap * sizeof(i64));
+ unsigned char *vecs = sqlite3_malloc64((i64)cap * qvecSize); // quantized size
+ if (!val || !rids || !vecs) {
+ sqlite3_free(val); sqlite3_free(rids); sqlite3_free(vecs);
+ sqlite3_finalize(stmtCell); sqlite3_finalize(stmtMap);
+ sqlite3_free(qbuf);
+ rc = SQLITE_NOMEM; goto train_error;
+ }
+
+ // Process one centroid at a time
+ for (int c = 0; c < nlist; c++) {
+ int slot = 0;
+ memset(val, 0, cap / 8);
+ memset(rids, 0, cap * sizeof(i64));
+
+ for (int i = 0; i < N; i++) {
+ if (assignments[i] != c) continue;
+
+ if (slot >= cap) {
+ // Flush current cell
+ sqlite3_reset(stmtCell);
+ sqlite3_bind_int(stmtCell, 1, c);
+ sqlite3_bind_int(stmtCell, 2, slot);
+ sqlite3_bind_blob(stmtCell, 3, val, cap / 8, SQLITE_TRANSIENT);
+ sqlite3_bind_blob(stmtCell, 4, rids, cap * (int)sizeof(i64), SQLITE_TRANSIENT);
+ sqlite3_bind_blob(stmtCell, 5, vecs, cap * qvecSize, SQLITE_TRANSIENT);
+ sqlite3_step(stmtCell);
+ i64 flushed_cell_id = sqlite3_last_insert_rowid(p->db);
+
+ for (int s = 0; s < slot; s++) {
+ sqlite3_reset(stmtMap);
+ sqlite3_bind_int64(stmtMap, 1, rids[s]);
+ sqlite3_bind_int64(stmtMap, 2, flushed_cell_id);
+ sqlite3_bind_int(stmtMap, 3, s);
+ sqlite3_step(stmtMap);
+ }
+
+ slot = 0;
+ memset(val, 0, cap / 8);
+ memset(rids, 0, cap * sizeof(i64));
+ }
+
+ val[slot / 8] |= (1 << (slot % 8));
+ rids[slot] = rowids[i];
+ // Quantize float32 vector into cell blob
+ ivf_quantize(p, col_idx, &vectors[i * D], &vecs[slot * qvecSize]);
+ slot++;
+ }
+
+ // Flush remaining
+ if (slot > 0) {
+ sqlite3_reset(stmtCell);
+ sqlite3_bind_int(stmtCell, 1, c);
+ sqlite3_bind_int(stmtCell, 2, slot);
+ sqlite3_bind_blob(stmtCell, 3, val, cap / 8, SQLITE_TRANSIENT);
+ sqlite3_bind_blob(stmtCell, 4, rids, cap * (int)sizeof(i64), SQLITE_TRANSIENT);
+ sqlite3_bind_blob(stmtCell, 5, vecs, cap * qvecSize, SQLITE_TRANSIENT);
+ sqlite3_step(stmtCell);
+ i64 flushed_cell_id = sqlite3_last_insert_rowid(p->db);
+
+ for (int s = 0; s < slot; s++) {
+ sqlite3_reset(stmtMap);
+ sqlite3_bind_int64(stmtMap, 1, rids[s]);
+ sqlite3_bind_int64(stmtMap, 2, flushed_cell_id);
+ sqlite3_bind_int(stmtMap, 3, s);
+ sqlite3_step(stmtMap);
+ }
+ }
+ }
+
+ sqlite3_free(val); sqlite3_free(rids); sqlite3_free(vecs);
+ sqlite3_finalize(stmtCell); sqlite3_finalize(stmtMap);
+ }
+
+ sqlite3_free(qbuf);
+
+ // Store full-precision vectors in _ivf_vectors when quantized
+ if (quantizer != VEC0_IVF_QUANTIZER_NONE) {
+ ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_VECTORS_NAME, col_idx);
+ zSql = sqlite3_mprintf(
+ "INSERT INTO " VEC0_SHADOW_IVF_VECTORS_NAME " (rowid, vector) VALUES (?, ?)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) { rc = SQLITE_NOMEM; goto train_error; }
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK) goto train_error;
+ for (int i = 0; i < N; i++) {
+ sqlite3_reset(stmt);
+ sqlite3_bind_int64(stmt, 1, rowids[i]);
+ sqlite3_bind_blob(stmt, 2, &vectors[i * D], vecSize, SQLITE_STATIC);
+ sqlite3_step(stmt);
+ }
+ sqlite3_finalize(stmt);
+ }
+
+ // Set trained = 1
+ {
+ zSql = sqlite3_mprintf(
+ "INSERT OR REPLACE INTO " VEC0_SHADOW_INFO_NAME " (key, value) VALUES ('ivf_trained_%d', '1')",
+ p->schemaName, p->tableName, col_idx);
+ if (zSql) { sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ sqlite3_step(stmt); sqlite3_finalize(stmt); }
+ }
+ p->ivfTrainedCache[col_idx] = 1;
+
+ sqlite3_exec(p->db, "RELEASE ivf_train", NULL, NULL, NULL);
+ sqlite3_free(vectors); sqlite3_free(rowids); sqlite3_free(centroids); sqlite3_free(assignments);
+ return SQLITE_OK;
+
+train_error:
+ sqlite3_exec(p->db, "ROLLBACK TO ivf_train", NULL, NULL, NULL);
+ sqlite3_exec(p->db, "RELEASE ivf_train", NULL, NULL, NULL);
+ sqlite3_free(vectors); sqlite3_free(rowids); sqlite3_free(centroids); sqlite3_free(assignments);
+ return rc;
+}
+
+static int ivf_cmd_set_centroid(vec0_vtab *p, int col_idx, int centroid_id,
+ const void *vectorData, int vectorSize) {
+ sqlite3_stmt *stmt = NULL;
+ int rc;
+ int D = (int)p->vector_columns[col_idx].dimensions;
+ if (vectorSize != (int)(D * sizeof(float))) { vtab_set_error(&p->base, "Dimension mismatch"); return SQLITE_ERROR; }
+
+ char *zSql = sqlite3_mprintf(
+ "INSERT OR REPLACE INTO " VEC0_SHADOW_IVF_CENTROIDS_NAME " (centroid_id, centroid) VALUES (?, ?)",
+ p->schemaName, p->tableName, col_idx);
+ if (!zSql) return SQLITE_NOMEM;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc != SQLITE_OK) return rc;
+ sqlite3_bind_int(stmt, 1, centroid_id);
+ sqlite3_bind_blob(stmt, 2, vectorData, vectorSize, SQLITE_STATIC);
+ rc = sqlite3_step(stmt); sqlite3_finalize(stmt);
+ if (rc != SQLITE_DONE) return SQLITE_ERROR;
+
+ zSql = sqlite3_mprintf(
+ "INSERT OR REPLACE INTO " VEC0_SHADOW_INFO_NAME " (key, value) VALUES ('ivf_trained_%d', '1')",
+ p->schemaName, p->tableName, col_idx);
+ if (zSql) { sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ sqlite3_step(stmt); sqlite3_finalize(stmt); }
+ p->ivfTrainedCache[col_idx] = 1;
+ sqlite3_finalize(p->stmtIvfCentroidsAll[col_idx]); p->stmtIvfCentroidsAll[col_idx] = NULL;
+ return SQLITE_OK;
+}
+
+static int ivf_cmd_assign_vectors(vec0_vtab *p, int col_idx) {
+ if (!ivf_is_trained(p, col_idx)) { vtab_set_error(&p->base, "No centroids"); return SQLITE_ERROR; }
+
+ int D = (int)p->vector_columns[col_idx].dimensions;
+ int vecSize = D * (int)sizeof(float);
+ int rc;
+ sqlite3_stmt *stmt = NULL;
+ char *zSql;
+
+ // Load centroids
+ int nlist = 0;
+ float *centroids = NULL;
+ zSql = sqlite3_mprintf("SELECT count(*) FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME,
+ p->schemaName, p->tableName, col_idx);
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) nlist = sqlite3_column_int(stmt, 0);
+ sqlite3_finalize(stmt);
+ if (nlist == 0) { vtab_set_error(&p->base, "No centroids"); return SQLITE_ERROR; }
+
+ centroids = sqlite3_malloc64((i64)nlist * D * sizeof(float));
+ if (!centroids) return SQLITE_NOMEM;
+ zSql = sqlite3_mprintf("SELECT centroid_id, centroid FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME " ORDER BY centroid_id",
+ p->schemaName, p->tableName, col_idx);
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ { int ci = 0; while (sqlite3_step(stmt) == SQLITE_ROW && ci < nlist) {
+ const void *b = sqlite3_column_blob(stmt, 1);
+ int bBytes = sqlite3_column_bytes(stmt, 1);
+ if (b && bBytes == vecSize) memcpy(¢roids[ci * D], b, vecSize);
+ ci++;
+ }}
+ sqlite3_finalize(stmt);
+
+ // Read unassigned cells, re-insert into trained cells
+ zSql = sqlite3_mprintf(
+ "SELECT rowid, n_vectors, validity, rowids, vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME
+ " WHERE centroid_id = %d",
+ p->schemaName, p->tableName, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID);
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+
+ // Invalidate cached stmts since we'll be modifying cells
+ ivf_invalidate_cached(p, col_idx);
+
+ while (sqlite3_step(stmt) == SQLITE_ROW) {
+ int n = sqlite3_column_int(stmt, 1);
+ const unsigned char *val = (const unsigned char *)sqlite3_column_blob(stmt, 2);
+ const i64 *rids = (const i64 *)sqlite3_column_blob(stmt, 3);
+ const float *vecs = (const float *)sqlite3_column_blob(stmt, 4);
+ int valBytes = sqlite3_column_bytes(stmt, 2);
+ int ridsBytes = sqlite3_column_bytes(stmt, 3);
+ int vecsBytes = sqlite3_column_bytes(stmt, 4);
+ if (!val || !rids || !vecs) continue;
+ int cap = valBytes * 8;
+ if (ridsBytes / (int)sizeof(i64) < cap) cap = ridsBytes / (int)sizeof(i64);
+ if (vecsBytes / vecSize < cap) cap = vecsBytes / vecSize;
+
+ for (int i = 0; i < cap && n > 0; i++) {
+ if (!(val[i / 8] & (1 << (i % 8)))) continue;
+ n--;
+ int cid = ivf_find_nearest_centroid(p, col_idx, &vecs[i * D], centroids, D, nlist);
+
+ // Delete old rowid_map entry
+ sqlite3_stmt *sd = NULL;
+ char *zd = sqlite3_mprintf("DELETE FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME " WHERE rowid = ?",
+ p->schemaName, p->tableName, col_idx);
+ if (zd) { sqlite3_prepare_v2(p->db, zd, -1, &sd, NULL); sqlite3_free(zd);
+ sqlite3_bind_int64(sd, 1, rids[i]); sqlite3_step(sd); sqlite3_finalize(sd); }
+
+ ivf_cell_insert(p, col_idx, cid, rids[i], &vecs[i * D], vecSize);
+ }
+ }
+ sqlite3_finalize(stmt);
+
+ // Delete unassigned cells
+ zSql = sqlite3_mprintf(
+ "DELETE FROM " VEC0_SHADOW_IVF_CELLS_NAME " WHERE centroid_id = %d",
+ p->schemaName, p->tableName, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID);
+ if (zSql) { sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ sqlite3_step(stmt); sqlite3_finalize(stmt); }
+
+ sqlite3_free(centroids);
+ return SQLITE_OK;
+}
+
+static int ivf_cmd_clear_centroids(vec0_vtab *p, int col_idx) {
+ float *vectors = NULL; i64 *rowids = NULL; int N = 0;
+ int vecSize = ivf_vec_size(p, col_idx);
+ int D = (int)p->vector_columns[col_idx].dimensions;
+ int rc;
+ sqlite3_stmt *stmt = NULL;
+ char *zSql;
+
+ rc = ivf_load_all_vectors(p, col_idx, &vectors, &rowids, &N);
+ if (rc != SQLITE_OK) return rc;
+
+ ivf_invalidate_cached(p, col_idx);
+
+ ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx);
+ ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_CELLS_NAME, col_idx);
+ ivf_exec(p, "DELETE FROM " VEC0_SHADOW_IVF_ROWID_MAP_NAME, col_idx);
+
+ // Re-insert all vectors into unassigned cells
+ for (int i = 0; i < N; i++) {
+ ivf_cell_insert(p, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID,
+ rowids[i], &vectors[i * D], vecSize);
+ }
+
+ zSql = sqlite3_mprintf(
+ "INSERT OR REPLACE INTO " VEC0_SHADOW_INFO_NAME " (key, value) VALUES ('ivf_trained_%d', '0')",
+ p->schemaName, p->tableName, col_idx);
+ if (zSql) { sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); sqlite3_free(zSql);
+ sqlite3_step(stmt); sqlite3_finalize(stmt); }
+ p->ivfTrainedCache[col_idx] = 0;
+
+ sqlite3_free(vectors); sqlite3_free(rowids);
+ return SQLITE_OK;
+}
+
+// ============================================================================
+// KNN Query — scan all cells for probed centroids
+// ============================================================================
+
+struct IvfCentroidDist { int id; float dist; };
+struct IvfCandidate { i64 rowid; float distance; };
+
+static int ivf_candidate_cmp(const void *a, const void *b) {
+ float da = ((const struct IvfCandidate *)a)->distance;
+ float db = ((const struct IvfCandidate *)b)->distance;
+ if (da < db) return -1;
+ if (da > db) return 1;
+ return 0;
+}
+
+/**
+ * Scan cell rows from a prepared statement, computing distances in-memory.
+ * The statement must return (n_vectors, validity, rowids, vectors) columns.
+ * queryVecQ is the quantized query (same type as cell vectors).
+ * qvecSize is the size of one quantized vector in bytes.
+ */
+static int ivf_scan_cells_from_stmt(vec0_vtab *p, int col_idx,
+ sqlite3_stmt *stmt,
+ const void *queryVecQ, int qvecSize,
+ struct IvfCandidate **candidates,
+ int *nCandidates, int *cap) {
+ while (sqlite3_step(stmt) == SQLITE_ROW) {
+ int n = sqlite3_column_int(stmt, 0);
+ if (n == 0) continue;
+ const unsigned char *validity = (const unsigned char *)sqlite3_column_blob(stmt, 1);
+ const i64 *rowids = (const i64 *)sqlite3_column_blob(stmt, 2);
+ const unsigned char *vectors = (const unsigned char *)sqlite3_column_blob(stmt, 3);
+ int valBytes = sqlite3_column_bytes(stmt, 1);
+ int ridsBytes = sqlite3_column_bytes(stmt, 2);
+ int vecsBytes = sqlite3_column_bytes(stmt, 3);
+ if (!validity || !rowids || !vectors) continue;
+ int cell_cap = valBytes * 8;
+ if (ridsBytes / (int)sizeof(i64) < cell_cap) cell_cap = ridsBytes / (int)sizeof(i64);
+ if (vecsBytes / qvecSize < cell_cap) cell_cap = vecsBytes / qvecSize;
+
+ int found = 0;
+ for (int i = 0; i < cell_cap && found < n; i++) {
+ if (!(validity[i / 8] & (1 << (i % 8)))) continue;
+ found++;
+ if (*nCandidates >= *cap) {
+ *cap *= 2;
+ struct IvfCandidate *tmp = sqlite3_realloc64(*candidates, (i64)*cap * sizeof(struct IvfCandidate));
+ if (!tmp) return SQLITE_NOMEM;
+ *candidates = tmp;
+ }
+ (*candidates)[*nCandidates].rowid = rowids[i];
+ (*candidates)[*nCandidates].distance = ivf_distance(p, col_idx,
+ queryVecQ, &vectors[i * qvecSize]);
+ (*nCandidates)++;
+ }
+ }
+ return SQLITE_OK;
+}
+
+static int ivf_query_knn(vec0_vtab *p, int col_idx,
+ const void *queryVector, int queryVectorSize,
+ i64 k, struct vec0_query_knn_data *knn_data) {
+ UNUSED_PARAMETER(queryVectorSize);
+ int rc;
+ int nprobe = p->vector_columns[col_idx].ivf.nprobe;
+ int trained = ivf_is_trained(p, col_idx);
+ int quantizer = p->vector_columns[col_idx].ivf.quantizer;
+ int oversample = p->vector_columns[col_idx].ivf.oversample;
+ int qvecSize = ivf_vec_size(p, col_idx);
+
+ // Quantize query vector for scanning
+ void *queryQ = sqlite3_malloc(qvecSize);
+ if (!queryQ) return SQLITE_NOMEM;
+ ivf_quantize(p, col_idx, (const float *)queryVector, queryQ);
+
+ // With oversample, collect more candidates for re-ranking
+ i64 collect_k = (oversample > 1) ? k * oversample : k;
+
+ int cap = (collect_k < 1024) ? 1024 : (int)collect_k * 2;
+ int nCandidates = 0;
+ struct IvfCandidate *candidates = sqlite3_malloc64((i64)cap * sizeof(struct IvfCandidate));
+ if (!candidates) { sqlite3_free(queryQ); return SQLITE_NOMEM; }
+
+ if (trained) {
+ // Find top nprobe centroids using quantized distance
+ int nlist = 0;
+ rc = ivf_ensure_stmt(p, &p->stmtIvfCentroidsAll[col_idx],
+ "SELECT centroid_id, centroid FROM " VEC0_SHADOW_IVF_CENTROIDS_NAME, col_idx);
+ if (rc != SQLITE_OK) { sqlite3_free(queryQ); sqlite3_free(candidates); return rc; }
+ sqlite3_stmt *stmt = p->stmtIvfCentroidsAll[col_idx];
+ sqlite3_reset(stmt);
+
+ int centroid_cap = 64;
+ struct IvfCentroidDist *cd = sqlite3_malloc64(centroid_cap * sizeof(*cd));
+ if (!cd) { sqlite3_free(queryQ); sqlite3_free(candidates); return SQLITE_NOMEM; }
+
+ while (sqlite3_step(stmt) == SQLITE_ROW) {
+ if (nlist >= centroid_cap) {
+ centroid_cap *= 2;
+ struct IvfCentroidDist *tmp = sqlite3_realloc64(cd, centroid_cap * sizeof(*cd));
+ if (!tmp) { sqlite3_free(cd); sqlite3_free(queryQ); sqlite3_free(candidates); return SQLITE_NOMEM; }
+ cd = tmp;
+ }
+ cd[nlist].id = sqlite3_column_int(stmt, 0);
+ const void *c = sqlite3_column_blob(stmt, 1);
+ int cBytes = sqlite3_column_bytes(stmt, 1);
+ // Compare quantized query with quantized centroid
+ cd[nlist].dist = (c && cBytes == qvecSize) ? ivf_distance(p, col_idx, queryQ, c) : FLT_MAX;
+ nlist++;
+ }
+
+ int actual_nprobe = nprobe < nlist ? nprobe : nlist;
+ for (int i = 0; i < actual_nprobe; i++) {
+ int min_j = i;
+ for (int j = i + 1; j < nlist; j++) {
+ if (cd[j].dist < cd[min_j].dist) min_j = j;
+ }
+ if (min_j != i) { struct IvfCentroidDist tmp = cd[i]; cd[i] = cd[min_j]; cd[min_j] = tmp; }
+ }
+
+ // Scan probed cells + unassigned with quantized distance
+ {
+ sqlite3_str *s = sqlite3_str_new(NULL);
+ sqlite3_str_appendf(s,
+ "SELECT n_vectors, validity, rowids, vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME
+ " WHERE centroid_id IN (",
+ p->schemaName, p->tableName, col_idx);
+ for (int i = 0; i < actual_nprobe; i++) {
+ if (i > 0) sqlite3_str_appendall(s, ",");
+ sqlite3_str_appendf(s, "%d", cd[i].id);
+ }
+ sqlite3_str_appendf(s, ",%d)", VEC0_IVF_UNASSIGNED_CENTROID_ID);
+ char *zSql = sqlite3_str_finish(s);
+ if (!zSql) { sqlite3_free(cd); sqlite3_free(queryQ); sqlite3_free(candidates); return SQLITE_NOMEM; }
+
+ sqlite3_stmt *stmtScan = NULL;
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtScan, NULL);
+ sqlite3_free(zSql);
+ if (rc != SQLITE_OK) { sqlite3_free(cd); sqlite3_free(queryQ); sqlite3_free(candidates); return rc; }
+
+ rc = ivf_scan_cells_from_stmt(p, col_idx, stmtScan, queryQ, qvecSize,
+ &candidates, &nCandidates, &cap);
+ sqlite3_finalize(stmtScan);
+ if (rc != SQLITE_OK) { sqlite3_free(cd); sqlite3_free(queryQ); sqlite3_free(candidates); return rc; }
+ }
+
+ sqlite3_free(cd);
+ } else {
+ // Flat mode: scan only unassigned cells
+ sqlite3_stmt *stmtScan = NULL;
+ char *zSql = sqlite3_mprintf(
+ "SELECT n_vectors, validity, rowids, vectors FROM " VEC0_SHADOW_IVF_CELLS_NAME
+ " WHERE centroid_id = %d",
+ p->schemaName, p->tableName, col_idx, VEC0_IVF_UNASSIGNED_CENTROID_ID);
+ if (!zSql) { sqlite3_free(queryQ); sqlite3_free(candidates); return SQLITE_NOMEM; }
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtScan, NULL); sqlite3_free(zSql);
+ if (rc == SQLITE_OK) {
+ rc = ivf_scan_cells_from_stmt(p, col_idx, stmtScan, queryQ, qvecSize,
+ &candidates, &nCandidates, &cap);
+ sqlite3_finalize(stmtScan);
+ if (rc != SQLITE_OK) { sqlite3_free(queryQ); sqlite3_free(candidates); return rc; }
+ }
+ }
+
+ sqlite3_free(queryQ);
+
+ // Sort candidates by quantized distance
+ qsort(candidates, nCandidates, sizeof(struct IvfCandidate), ivf_candidate_cmp);
+
+ // Oversample re-ranking: re-score top (oversample*k) with full-precision vectors
+ if (oversample > 1 && quantizer != VEC0_IVF_QUANTIZER_NONE && nCandidates > 0) {
+ i64 rescore_n = collect_k < nCandidates ? collect_k : nCandidates;
+ sqlite3_stmt *stmtVec = NULL;
+ char *zSql = sqlite3_mprintf(
+ "SELECT vector FROM " VEC0_SHADOW_IVF_VECTORS_NAME " WHERE rowid = ?",
+ p->schemaName, p->tableName, col_idx);
+ if (zSql) {
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtVec, NULL); sqlite3_free(zSql);
+ if (rc == SQLITE_OK) {
+ for (i64 i = 0; i < rescore_n; i++) {
+ sqlite3_reset(stmtVec);
+ sqlite3_bind_int64(stmtVec, 1, candidates[i].rowid);
+ if (sqlite3_step(stmtVec) == SQLITE_ROW) {
+ const float *fullVec = (const float *)sqlite3_column_blob(stmtVec, 0);
+ int fullVecBytes = sqlite3_column_bytes(stmtVec, 0);
+ if (fullVec && fullVecBytes == (int)p->vector_columns[col_idx].dimensions * (int)sizeof(float)) {
+ candidates[i].distance = ivf_distance_float(p, col_idx,
+ (const float *)queryVector, fullVec);
+ }
+ }
+ }
+ sqlite3_finalize(stmtVec);
+ }
+ }
+ // Re-sort after re-scoring
+ qsort(candidates, (size_t)rescore_n, sizeof(struct IvfCandidate), ivf_candidate_cmp);
+ nCandidates = (int)rescore_n;
+ }
+
+ qsort(candidates, nCandidates, sizeof(struct IvfCandidate), ivf_candidate_cmp);
+ i64 nResults = nCandidates < k ? nCandidates : k;
+
+ if (nResults == 0) {
+ knn_data->rowids = NULL; knn_data->distances = NULL;
+ knn_data->k = k; knn_data->k_used = 0; knn_data->current_idx = 0;
+ sqlite3_free(candidates); return SQLITE_OK;
+ }
+
+ knn_data->rowids = sqlite3_malloc64(nResults * sizeof(i64));
+ knn_data->distances = sqlite3_malloc64(nResults * sizeof(f32));
+ if (!knn_data->rowids || !knn_data->distances) {
+ sqlite3_free(knn_data->rowids); sqlite3_free(knn_data->distances);
+ sqlite3_free(candidates); return SQLITE_NOMEM;
+ }
+ for (i64 i = 0; i < nResults; i++) {
+ knn_data->rowids[i] = candidates[i].rowid;
+ knn_data->distances[i] = candidates[i].distance;
+ }
+ knn_data->k = k; knn_data->k_used = nResults; knn_data->current_idx = 0;
+ sqlite3_free(candidates);
+ return SQLITE_OK;
+}
+
+// ============================================================================
+// Command dispatch
+// ============================================================================
+
+static int ivf_handle_command(vec0_vtab *p, const char *command,
+ int argc, sqlite3_value **argv) {
+ UNUSED_PARAMETER(argc);
+ int col_idx = -1;
+ for (int i = 0; i < p->numVectorColumns; i++) {
+ if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF) { col_idx = i; break; }
+ }
+ if (col_idx < 0) return SQLITE_EMPTY;
+
+ // nprobe=N — change nprobe at runtime without rebuilding
+ if (strncmp(command, "nprobe=", 7) == 0) {
+ int new_nprobe = atoi(command + 7);
+ if (new_nprobe < 1) {
+ vtab_set_error(&p->base, "nprobe must be >= 1");
+ return SQLITE_ERROR;
+ }
+ p->vector_columns[col_idx].ivf.nprobe = new_nprobe;
+ return SQLITE_OK;
+ }
+
+ if (strcmp(command, "compute-centroids") == 0)
+ return ivf_cmd_compute_centroids(p, col_idx, 0, VEC0_IVF_KMEANS_MAX_ITER, VEC0_IVF_KMEANS_DEFAULT_SEED);
+
+ if (strncmp(command, "compute-centroids:", 18) == 0) {
+ const char *json = command + 18;
+ int nlist = 0, max_iter = VEC0_IVF_KMEANS_MAX_ITER;
+ uint32_t seed = VEC0_IVF_KMEANS_DEFAULT_SEED;
+ const char *pn = strstr(json, "\"nlist\":"); if (pn) nlist = atoi(pn + 8);
+ const char *pi = strstr(json, "\"max_iterations\":"); if (pi) max_iter = atoi(pi + 17);
+ const char *ps = strstr(json, "\"seed\":"); if (ps) seed = (uint32_t)atoi(ps + 7);
+ return ivf_cmd_compute_centroids(p, col_idx, nlist, max_iter, seed);
+ }
+
+ if (strncmp(command, "set-centroid:", 13) == 0) {
+ int centroid_id = atoi(command + 13);
+ for (int i = 0; i < (int)(p->numVectorColumns + p->numPartitionColumns +
+ p->numAuxiliaryColumns + p->numMetadataColumns); i++) {
+ if (p->user_column_kinds[i] == SQLITE_VEC0_USER_COLUMN_KIND_VECTOR &&
+ p->user_column_idxs[i] == col_idx) {
+ sqlite3_value *v = argv[2 + VEC0_COLUMN_USERN_START + i];
+ if (sqlite3_value_type(v) == SQLITE_NULL) { vtab_set_error(&p->base, "set-centroid requires vector"); return SQLITE_ERROR; }
+ return ivf_cmd_set_centroid(p, col_idx, centroid_id, sqlite3_value_blob(v), sqlite3_value_bytes(v));
+ }
+ }
+ return SQLITE_ERROR;
+ }
+
+ if (strcmp(command, "assign-vectors") == 0) return ivf_cmd_assign_vectors(p, col_idx);
+ if (strcmp(command, "clear-centroids") == 0) return ivf_cmd_clear_centroids(p, col_idx);
+ return SQLITE_EMPTY;
+}
+
+#endif /* SQLITE_VEC_IVF_C */
diff --git a/sqlite-vec.c b/sqlite-vec.c
index 7079f7e..88f60b9 100644
--- a/sqlite-vec.c
+++ b/sqlite-vec.c
@@ -93,6 +93,10 @@ typedef size_t usize;
#define COMPILER_SUPPORTS_VTAB_IN 1
#endif
+#ifndef SQLITE_VEC_ENABLE_IVF
+#define SQLITE_VEC_ENABLE_IVF 1
+#endif
+
#ifndef SQLITE_SUBTYPE
#define SQLITE_SUBTYPE 0x000100000
#endif
@@ -2539,6 +2543,7 @@ enum Vec0IndexType {
#if SQLITE_VEC_ENABLE_RESCORE
VEC0_INDEX_TYPE_RESCORE = 2,
#endif
+ VEC0_INDEX_TYPE_IVF = 3,
};
#if SQLITE_VEC_ENABLE_RESCORE
@@ -2553,6 +2558,22 @@ struct Vec0RescoreConfig {
};
#endif
+#if SQLITE_VEC_ENABLE_IVF
+enum Vec0IvfQuantizer {
+ VEC0_IVF_QUANTIZER_NONE = 0,
+ VEC0_IVF_QUANTIZER_INT8 = 1,
+ VEC0_IVF_QUANTIZER_BINARY = 2,
+};
+
+struct Vec0IvfConfig {
+ int nlist; // number of centroids (0 = deferred)
+ int nprobe; // cells to probe at query time
+ int quantizer; // VEC0_IVF_QUANTIZER_NONE / INT8 / BINARY
+ int oversample; // >= 1 (1 = no oversampling)
+};
+#else
+struct Vec0IvfConfig { char _unused; };
+#endif
struct VectorColumnDefinition {
char *name;
@@ -2564,6 +2585,7 @@ struct VectorColumnDefinition {
#if SQLITE_VEC_ENABLE_RESCORE
struct Vec0RescoreConfig rescore;
#endif
+ struct Vec0IvfConfig ivf;
};
struct Vec0PartitionColumnDefinition {
@@ -2715,6 +2737,12 @@ static int vec0_parse_rescore_options(struct Vec0Scanner *scanner,
* @return int SQLITE_OK on success, SQLITE_EMPTY is it's not a vector column
* definition, SQLITE_ERROR on error.
*/
+#if SQLITE_VEC_ENABLE_IVF
+// Forward declaration — defined in sqlite-vec-ivf.c
+static int vec0_parse_ivf_options(struct Vec0Scanner *scanner,
+ struct Vec0IvfConfig *config);
+#endif
+
int vec0_parse_vector_column(const char *source, int source_length,
struct VectorColumnDefinition *outColumn) {
// parses a vector column definition like so:
@@ -2733,6 +2761,8 @@ int vec0_parse_vector_column(const char *source, int source_length,
struct Vec0RescoreConfig rescoreConfig;
memset(&rescoreConfig, 0, sizeof(rescoreConfig));
#endif
+ struct Vec0IvfConfig ivfConfig;
+ memset(&ivfConfig, 0, sizeof(ivfConfig));
int dimensions;
vec0_scanner_init(&scanner, source, source_length);
@@ -2891,7 +2921,18 @@ int vec0_parse_vector_column(const char *source, int source_length,
}
}
#endif
- else {
+ else if (sqlite3_strnicmp(token.start, "ivf", indexNameLen) == 0) {
+#if SQLITE_VEC_ENABLE_IVF
+ indexType = VEC0_INDEX_TYPE_IVF;
+ memset(&ivfConfig, 0, sizeof(ivfConfig));
+ rc = vec0_parse_ivf_options(&scanner, &ivfConfig);
+ if (rc != SQLITE_OK) {
+ return SQLITE_ERROR;
+ }
+#else
+ return SQLITE_ERROR; // IVF not compiled in
+#endif
+ } else {
// unknown index type
return SQLITE_ERROR;
}
@@ -2914,6 +2955,7 @@ int vec0_parse_vector_column(const char *source, int source_length,
#if SQLITE_VEC_ENABLE_RESCORE
outColumn->rescore = rescoreConfig;
#endif
+ outColumn->ivf = ivfConfig;
return SQLITE_OK;
}
@@ -3279,6 +3321,18 @@ struct vec0_vtab {
int chunk_size;
+#if SQLITE_VEC_ENABLE_IVF
+ // IVF cached state per vector column
+ char *shadowIvfCellsNames[VEC0_MAX_VECTOR_COLUMNS]; // table name for blob_open
+ int ivfTrainedCache[VEC0_MAX_VECTOR_COLUMNS]; // -1=unknown, 0=no, 1=yes
+ sqlite3_stmt *stmtIvfCellMeta[VEC0_MAX_VECTOR_COLUMNS]; // SELECT n_vectors, length(validity)*8 FROM cells WHERE cell_id=?
+ sqlite3_stmt *stmtIvfCellUpdateN[VEC0_MAX_VECTOR_COLUMNS]; // UPDATE cells SET n_vectors=n_vectors+? WHERE cell_id=?
+ sqlite3_stmt *stmtIvfRowidMapInsert[VEC0_MAX_VECTOR_COLUMNS]; // INSERT INTO rowid_map(rowid,cell_id,slot) VALUES(?,?,?)
+ sqlite3_stmt *stmtIvfRowidMapLookup[VEC0_MAX_VECTOR_COLUMNS]; // SELECT cell_id,slot FROM rowid_map WHERE rowid=?
+ sqlite3_stmt *stmtIvfRowidMapDelete[VEC0_MAX_VECTOR_COLUMNS]; // DELETE FROM rowid_map WHERE rowid=?
+ sqlite3_stmt *stmtIvfCentroidsAll[VEC0_MAX_VECTOR_COLUMNS]; // SELECT centroid_id,centroid FROM centroids
+#endif
+
// select latest chunk from _chunks, getting chunk_id
sqlite3_stmt *stmtLatestChunk;
@@ -3364,6 +3418,17 @@ void vec0_free_resources(vec0_vtab *p) {
p->stmtRowidsUpdatePosition = NULL;
sqlite3_finalize(p->stmtRowidsGetChunkPosition);
p->stmtRowidsGetChunkPosition = NULL;
+
+#if SQLITE_VEC_ENABLE_IVF
+ for (int i = 0; i < VEC0_MAX_VECTOR_COLUMNS; i++) {
+ sqlite3_finalize(p->stmtIvfCellMeta[i]); p->stmtIvfCellMeta[i] = NULL;
+ sqlite3_finalize(p->stmtIvfCellUpdateN[i]); p->stmtIvfCellUpdateN[i] = NULL;
+ sqlite3_finalize(p->stmtIvfRowidMapInsert[i]); p->stmtIvfRowidMapInsert[i] = NULL;
+ sqlite3_finalize(p->stmtIvfRowidMapLookup[i]); p->stmtIvfRowidMapLookup[i] = NULL;
+ sqlite3_finalize(p->stmtIvfRowidMapDelete[i]); p->stmtIvfRowidMapDelete[i] = NULL;
+ sqlite3_finalize(p->stmtIvfCentroidsAll[i]); p->stmtIvfCentroidsAll[i] = NULL;
+ }
+#endif
}
/**
@@ -3386,6 +3451,10 @@ void vec0_free(vec0_vtab *p) {
for (int i = 0; i < p->numVectorColumns; i++) {
sqlite3_free(p->shadowVectorChunksNames[i]);
p->shadowVectorChunksNames[i] = NULL;
+#if SQLITE_VEC_ENABLE_IVF
+ sqlite3_free(p->shadowIvfCellsNames[i]);
+ p->shadowIvfCellsNames[i] = NULL;
+#endif
#if SQLITE_VEC_ENABLE_RESCORE
sqlite3_free(p->shadowRescoreChunksNames[i]);
@@ -3674,12 +3743,25 @@ int vec0_result_id(vec0_vtab *p, sqlite3_context *context, i64 rowid) {
* will be stored.
* @return int SQLITE_OK on success.
*/
+#if SQLITE_VEC_ENABLE_IVF
+// Forward declaration — defined in sqlite-vec-ivf.c (included later)
+static int ivf_get_vector_data(vec0_vtab *p, i64 rowid, int col_idx,
+ void **outVector, int *outVectorSize);
+#endif
+
int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
void **outVector, int *outVectorSize) {
vec0_vtab *p = pVtab;
int rc, brc;
i64 chunk_id;
i64 chunk_offset;
+
+#if SQLITE_VEC_ENABLE_IVF
+ // IVF-indexed columns store vectors in _ivf_cells, not _vector_chunks
+ if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_IVF) {
+ return ivf_get_vector_data(p, rowid, vector_column_idx, outVector, outVectorSize);
+ }
+#endif
size_t size;
void *buf = NULL;
int blobOffset;
@@ -4327,8 +4409,12 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_value ** partitionKeyValues, i64 *chunk
int vector_column_idx = p->user_column_idxs[i];
#if SQLITE_VEC_ENABLE_RESCORE
- // Rescore columns don't use _vector_chunks for float storage
- if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE) {
+ // Rescore and IVF columns don't use _vector_chunks for float storage
+ if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE
+#if SQLITE_VEC_ENABLE_IVF
+ || p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_IVF
+#endif
+ ) {
continue;
}
#endif
@@ -4500,6 +4586,12 @@ void vec0_cursor_clear(vec0_cursor *pCur) {
}
}
+// IVF index implementation — #include'd here after all struct/helper definitions
+#if SQLITE_VEC_ENABLE_IVF
+#include "sqlite-vec-ivf-kmeans.c"
+#include "sqlite-vec-ivf.c"
+#endif
+
#define VEC_CONSTRUCTOR_ERROR "vec0 constructor error: "
static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
sqlite3_vtab **ppVtab, char **pzErr, bool isCreate) {
@@ -4761,6 +4853,34 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
}
#endif
+ // IVF indexes do not support auxiliary, metadata, or partition key columns.
+ {
+ int has_ivf = 0;
+ for (int i = 0; i < numVectorColumns; i++) {
+ if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF) {
+ has_ivf = 1;
+ break;
+ }
+ }
+ if (has_ivf) {
+ if (numPartitionColumns > 0) {
+ *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
+ "partition key columns are not supported with IVF indexes");
+ goto error;
+ }
+ if (numAuxiliaryColumns > 0) {
+ *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
+ "auxiliary columns are not supported with IVF indexes");
+ goto error;
+ }
+ if (numMetadataColumns > 0) {
+ *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
+ "metadata columns are not supported with IVF indexes");
+ goto error;
+ }
+ }
+ }
+
sqlite3_str *createStr = sqlite3_str_new(NULL);
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
if (pkColumnName) {
@@ -4866,6 +4986,15 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
}
#endif
}
+#if SQLITE_VEC_ENABLE_IVF
+ for (int i = 0; i < pNew->numVectorColumns; i++) {
+ if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
+ pNew->shadowIvfCellsNames[i] =
+ sqlite3_mprintf("%s_ivf_cells%02d", tableName, i);
+ if (!pNew->shadowIvfCellsNames[i]) goto error;
+ pNew->ivfTrainedCache[i] = -1; // unknown
+ }
+#endif
for (int i = 0; i < pNew->numMetadataColumns; i++) {
pNew->shadowMetadataChunksNames[i] =
sqlite3_mprintf("%s_metadatachunks%02d", tableName, i);
@@ -4989,8 +5118,8 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
for (int i = 0; i < pNew->numVectorColumns; i++) {
#if SQLITE_VEC_ENABLE_RESCORE
- // Rescore columns don't use _vector_chunks
- if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
+ // Rescore and IVF columns don't use _vector_chunks
+ if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
continue;
#endif
char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE,
@@ -5018,6 +5147,18 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
}
#endif
+#if SQLITE_VEC_ENABLE_IVF
+ // Create IVF shadow tables for IVF-indexed vector columns
+ for (int i = 0; i < pNew->numVectorColumns; i++) {
+ if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
+ rc = ivf_create_shadow_tables(pNew, i);
+ if (rc != SQLITE_OK) {
+ *pzErr = sqlite3_mprintf("Could not create IVF shadow tables for column %d", i);
+ goto error;
+ }
+ }
+#endif
+
// See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY"
// without INTEGER type issue applies here.
for (int i = 0; i < pNew->numMetadataColumns; i++) {
@@ -5153,7 +5294,7 @@ static int vec0Destroy(sqlite3_vtab *pVtab) {
for (int i = 0; i < p->numVectorColumns; i++) {
#if SQLITE_VEC_ENABLE_RESCORE
- if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
+ if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
continue;
#endif
zSql = sqlite3_mprintf("DROP TABLE \"%w\".\"%w\"", p->schemaName,
@@ -5174,6 +5315,14 @@ static int vec0Destroy(sqlite3_vtab *pVtab) {
}
#endif
+#if SQLITE_VEC_ENABLE_IVF
+ // Drop IVF shadow tables
+ for (int i = 0; i < p->numVectorColumns; i++) {
+ if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
+ ivf_drop_shadow_tables(p, i);
+ }
+#endif
+
if(p->numAuxiliaryColumns > 0) {
zSql = sqlite3_mprintf("DROP TABLE " VEC0_SHADOW_AUXILIARY_NAME, p->schemaName, p->tableName);
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
@@ -7186,6 +7335,21 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
}
#endif
+#if SQLITE_VEC_ENABLE_IVF
+ // IVF dispatch: if vector column has IVF, use IVF query instead of chunk scan
+ if (vector_column->index_type == VEC0_INDEX_TYPE_IVF) {
+ rc = ivf_query_knn(p, vectorColumnIdx, queryVector,
+ (int)vector_column_byte_size(*vector_column), k, knn_data);
+ if (rc != SQLITE_OK) {
+ goto cleanup;
+ }
+ pCur->knn_data = knn_data;
+ pCur->query_plan = VEC0_QUERY_PLAN_KNN;
+ rc = SQLITE_OK;
+ goto cleanup;
+ }
+#endif
+
rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks);
if (rc != SQLITE_OK) {
// IMP: V06942_23781
@@ -8011,8 +8175,12 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid,
// Go insert the vector data into the vector chunk shadow tables
for (int i = 0; i < p->numVectorColumns; i++) {
#if SQLITE_VEC_ENABLE_RESCORE
- // Rescore columns store float vectors in _rescore_vectors instead
- if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
+ // Rescore and IVF columns don't use _vector_chunks
+ if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE
+#if SQLITE_VEC_ENABLE_IVF
+ || p->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF
+#endif
+ )
continue;
#endif
@@ -8425,6 +8593,18 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
}
#endif
+#if SQLITE_VEC_ENABLE_IVF
+ // Step #4: IVF index insert (if any vector column uses IVF)
+ for (int i = 0; i < p->numVectorColumns; i++) {
+ if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
+ int vecSize = (int)vector_column_byte_size(p->vector_columns[i]);
+ rc = ivf_insert(p, i, rowid, vectorDatas[i], vecSize);
+ if (rc != SQLITE_OK) {
+ goto cleanup;
+ }
+ }
+#endif
+
if(p->numAuxiliaryColumns > 0) {
sqlite3_stmt *stmt;
sqlite3_str * s = sqlite3_str_new(NULL);
@@ -8616,8 +8796,8 @@ int vec0Update_Delete_ClearVectors(vec0_vtab *p, i64 chunk_id,
int rc, brc;
for (int i = 0; i < p->numVectorColumns; i++) {
#if SQLITE_VEC_ENABLE_RESCORE
- // Rescore columns don't use _vector_chunks
- if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
+ // Non-FLAT columns don't use _vector_chunks
+ if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
continue;
#endif
sqlite3_blob *blobVectors = NULL;
@@ -8732,7 +8912,7 @@ int vec0Update_Delete_DeleteChunkIfEmpty(vec0_vtab *p, i64 chunk_id,
// Delete from each _vector_chunksNN
for (int i = 0; i < p->numVectorColumns; i++) {
#if SQLITE_VEC_ENABLE_RESCORE
- if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
+ if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
continue;
#endif
zSql = sqlite3_mprintf(
@@ -9009,6 +9189,15 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) {
}
}
+#if SQLITE_VEC_ENABLE_IVF
+ // 7. delete from IVF index
+ for (int i = 0; i < p->numVectorColumns; i++) {
+ if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
+ rc = ivf_delete(p, i, rowid);
+ if (rc != SQLITE_OK) return rc;
+ }
+#endif
+
return SQLITE_OK;
}
@@ -9284,6 +9473,18 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
}
// INSERT operation
else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) {
+#if SQLITE_VEC_ENABLE_IVF
+ // Check for IVF command inserts: INSERT INTO t(rowid) VALUES ('compute-centroids')
+ // The id column holds the command string.
+ sqlite3_value *idVal = argv[2 + VEC0_COLUMN_ID];
+ if (sqlite3_value_type(idVal) == SQLITE_TEXT) {
+ const char *cmd = (const char *)sqlite3_value_text(idVal);
+ vec0_vtab *p = (vec0_vtab *)pVTab;
+ int cmdRc = ivf_handle_command(p, cmd, argc, argv);
+ if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error)
+ // SQLITE_EMPTY means not an IVF command — fall through to normal insert
+ }
+#endif
return vec0Update_Insert(pVTab, argc, argv, pRowid);
}
// UPDATE operation
@@ -9431,9 +9632,15 @@ static sqlite3_module vec0Module = {
#define SQLITE_VEC_DEBUG_BUILD_RESCORE ""
#endif
+#if SQLITE_VEC_ENABLE_IVF
+#define SQLITE_VEC_DEBUG_BUILD_IVF "ivf"
+#else
+#define SQLITE_VEC_DEBUG_BUILD_IVF ""
+#endif
+
#define SQLITE_VEC_DEBUG_BUILD \
SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \
- SQLITE_VEC_DEBUG_BUILD_RESCORE
+ SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF
#define SQLITE_VEC_DEBUG_STRING \
"Version: " SQLITE_VEC_VERSION "\n" \
diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile
index 0030c2e..a3405a4 100644
--- a/tests/fuzz/Makefile
+++ b/tests/fuzz/Makefile
@@ -93,13 +93,40 @@ $(TARGET_DIR)/rescore_quantize_edge: rescore-quantize-edge.c $(FUZZ_SRCS) | $(TA
$(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR)
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
+$(TARGET_DIR)/ivf_create: ivf-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
+ $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/ivf_operations: ivf-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
+ $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/ivf_quantize: ivf-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR)
+ $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/ivf_kmeans: ivf-kmeans.c $(FUZZ_SRCS) | $(TARGET_DIR)
+ $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/ivf_shadow_corrupt: ivf-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
+ $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/ivf_knn_deep: ivf-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR)
+ $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/ivf_cell_overflow: ivf-cell-overflow.c $(FUZZ_SRCS) | $(TARGET_DIR)
+ $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR)
+ $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
FUZZ_TARGETS = vec0_create exec json numpy \
shadow_corrupt vec0_operations scalar_functions \
vec0_create_full metadata_columns vec_each vec_mismatch \
vec0_delete_completeness \
rescore_operations rescore_create rescore_quantize \
rescore_shadow_corrupt rescore_knn_deep \
- rescore_quantize_edge rescore_interleave
+ rescore_quantize_edge rescore_interleave \
+ ivf_create ivf_operations \
+ ivf_quantize ivf_kmeans ivf_shadow_corrupt \
+ ivf_knn_deep ivf_cell_overflow ivf_rescore
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
diff --git a/tests/fuzz/ivf-cell-overflow.c b/tests/fuzz/ivf-cell-overflow.c
new file mode 100644
index 0000000..4b18ba2
--- /dev/null
+++ b/tests/fuzz/ivf-cell-overflow.c
@@ -0,0 +1,192 @@
+/**
+ * Fuzz target: IVF cell overflow and boundary conditions.
+ *
+ * Pushes cells past VEC0_IVF_CELL_MAX_VECTORS (64) to trigger cell
+ * splitting, then exercises blob I/O at slot boundaries.
+ *
+ * Targets:
+ * - Cell splitting when n_vectors reaches cap (64)
+ * - Blob offset arithmetic: slot * vecSize, slot / 8, slot % 8
+ * - Validity bitmap at byte boundaries (slot 7->8, 15->16, etc.)
+ * - Insert into full cell -> create new cell path
+ * - Delete from various slot positions (first, last, middle)
+ * - Multiple cells per centroid
+ * - assign-vectors command with multi-cell centroids
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ if (size < 8) return 0;
+
+ int rc;
+ sqlite3 *db;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ // Use small dimensions for speed but enough vectors to overflow cells
+ int dim = (data[0] % 8) + 2; // 2..9
+ int nlist = (data[1] % 4) + 1; // 1..4
+ // We need >64 vectors to overflow a cell
+ int num_vecs = (data[2] % 64) + 65; // 65..128
+ int delete_pattern = data[3]; // Controls which vectors to delete
+
+ const uint8_t *payload = data + 4;
+ size_t payload_size = size - 4;
+
+ char sql[256];
+ snprintf(sql, sizeof(sql),
+ "CREATE VIRTUAL TABLE v USING vec0("
+ "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))",
+ dim, nlist, nlist);
+
+ rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
+ if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+ // Insert enough vectors to overflow at least one cell
+ sqlite3_stmt *stmtInsert = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+ if (!stmtInsert) { sqlite3_close(db); return 0; }
+
+ size_t offset = 0;
+ for (int i = 0; i < num_vecs; i++) {
+ float *vec = sqlite3_malloc(dim * sizeof(float));
+ if (!vec) break;
+ for (int d = 0; d < dim; d++) {
+ if (offset < payload_size) {
+ vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f;
+ } else {
+ // Cluster vectors near specific centroids to ensure some cells overflow
+ int cluster = i % nlist;
+ vec[d] = (float)cluster + (float)(i % 10) * 0.01f + d * 0.001f;
+ }
+ }
+ sqlite3_reset(stmtInsert);
+ sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
+ sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
+ sqlite3_step(stmtInsert);
+ sqlite3_free(vec);
+ }
+ sqlite3_finalize(stmtInsert);
+
+ // Train to assign vectors to centroids (triggers cell building)
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+
+ // Delete vectors at boundary positions based on fuzz data
+ // This tests validity bitmap manipulation at different slot positions
+ for (int i = 0; i < num_vecs; i++) {
+ int byte_idx = i / 8;
+ if (byte_idx < (int)payload_size && (payload[byte_idx] & (1 << (i % 8)))) {
+ // Use delete_pattern to thin deletions
+ if ((delete_pattern + i) % 3 == 0) {
+ char delsql[64];
+ snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i + 1);
+ sqlite3_exec(db, delsql, NULL, NULL, NULL);
+ }
+ }
+ }
+
+ // Insert more vectors after deletions (into cells with holes)
+ {
+ sqlite3_stmt *si = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
+ if (si) {
+ for (int i = 0; i < 10; i++) {
+ float *vec = sqlite3_malloc(dim * sizeof(float));
+ if (!vec) break;
+ for (int d = 0; d < dim; d++)
+ vec[d] = (float)(i + 200) * 0.01f;
+ sqlite3_reset(si);
+ sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1));
+ sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
+ sqlite3_step(si);
+ sqlite3_free(vec);
+ }
+ sqlite3_finalize(si);
+ }
+ }
+
+ // KNN query that must scan multiple cells per centroid
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) qvec[d] = 0.0f;
+ sqlite3_stmt *sk = NULL;
+ snprintf(sql, sizeof(sql),
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 20");
+ sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ // Test assign-vectors with multi-cell state
+ // First clear centroids
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('clear-centroids')",
+ NULL, NULL, NULL);
+
+ // Set centroids manually, then assign
+ for (int c = 0; c < nlist; c++) {
+ float *cvec = sqlite3_malloc(dim * sizeof(float));
+ if (!cvec) break;
+ for (int d = 0; d < dim; d++) cvec[d] = (float)c + d * 0.1f;
+
+ char cmd[128];
+ snprintf(cmd, sizeof(cmd),
+ "INSERT INTO v(rowid, emb) VALUES ('set-centroid:%d', ?)", c);
+ sqlite3_stmt *sc = NULL;
+ sqlite3_prepare_v2(db, cmd, -1, &sc, NULL);
+ if (sc) {
+ sqlite3_bind_blob(sc, 1, cvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ sqlite3_step(sc);
+ sqlite3_finalize(sc);
+ }
+ sqlite3_free(cvec);
+ }
+
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('assign-vectors')",
+ NULL, NULL, NULL);
+
+ // Final query after assign-vectors
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) qvec[d] = 1.0f;
+ sqlite3_stmt *sk = NULL;
+ sqlite3_prepare_v2(db,
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
+ -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ // Full scan
+ sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
+
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/fuzz/ivf-create.c b/tests/fuzz/ivf-create.c
new file mode 100644
index 0000000..222b67b
--- /dev/null
+++ b/tests/fuzz/ivf-create.c
@@ -0,0 +1,36 @@
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ int rc;
+ sqlite3 *db;
+ sqlite3_stmt *stmt;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ sqlite3_str *s = sqlite3_str_new(NULL);
+ assert(s);
+ sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(");
+ sqlite3_str_appendf(s, "%.*s", (int)size, data);
+ sqlite3_str_appendall(s, "))");
+ const char *zSql = sqlite3_str_finish(s);
+ assert(zSql);
+
+ rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
+ sqlite3_free((void *)zSql);
+ if (rc == SQLITE_OK) {
+ sqlite3_step(stmt);
+ }
+ sqlite3_finalize(stmt);
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/fuzz/ivf-create.dict b/tests/fuzz/ivf-create.dict
new file mode 100644
index 0000000..9a014e7
--- /dev/null
+++ b/tests/fuzz/ivf-create.dict
@@ -0,0 +1,16 @@
+"nlist"
+"nprobe"
+"quantizer"
+"oversample"
+"binary"
+"int8"
+"none"
+"="
+","
+"("
+")"
+"0"
+"1"
+"128"
+"65536"
+"65537"
diff --git a/tests/fuzz/ivf-kmeans.c b/tests/fuzz/ivf-kmeans.c
new file mode 100644
index 0000000..46804d0
--- /dev/null
+++ b/tests/fuzz/ivf-kmeans.c
@@ -0,0 +1,180 @@
+/**
+ * Fuzz target: IVF k-means clustering.
+ *
+ * Builds a table, inserts fuzz-controlled vectors, then runs
+ * compute-centroids with fuzz-controlled parameters (nlist, max_iter, seed).
+ * Targets:
+ * - kmeans with N < k (clamping), N == 1, k == 1
+ * - kmeans with duplicate/identical vectors (all distances zero)
+ * - kmeans with NaN/Inf vectors
+ * - Empty cluster reassignment path (farthest-point heuristic)
+ * - Large nlist relative to N
+ * - The compute-centroids:{json} command parsing
+ * - clear-centroids followed by compute-centroids (round-trip)
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ if (size < 10) return 0;
+
+ int rc;
+ sqlite3 *db;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ // Parse fuzz header
+ // Byte 0-1: dimension (1..128)
+ // Byte 2: nlist for CREATE (1..64)
+ // Byte 3: nlist override for compute-centroids (0 = use default)
+ // Byte 4: max_iter (1..50)
+ // Byte 5-8: seed
+ // Byte 9: num_vectors (1..64)
+ // Remaining: vector float data
+
+ int dim = (data[0] | (data[1] << 8)) % 128 + 1;
+ int nlist_create = (data[2] % 64) + 1;
+ int nlist_override = data[3] % 65; // 0 means use table default
+ int max_iter = (data[4] % 50) + 1;
+ uint32_t seed = (uint32_t)data[5] | ((uint32_t)data[6] << 8) |
+ ((uint32_t)data[7] << 16) | ((uint32_t)data[8] << 24);
+ int num_vecs = (data[9] % 64) + 1;
+
+ const uint8_t *payload = data + 10;
+ size_t payload_size = size - 10;
+
+ char sql[256];
+ snprintf(sql, sizeof(sql),
+ "CREATE VIRTUAL TABLE v USING vec0("
+ "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))",
+ dim, nlist_create, nlist_create);
+
+ rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
+ if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+ // Insert vectors
+ sqlite3_stmt *stmtInsert = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+ if (!stmtInsert) { sqlite3_close(db); return 0; }
+
+ size_t offset = 0;
+ for (int i = 0; i < num_vecs; i++) {
+ float *vec = sqlite3_malloc(dim * sizeof(float));
+ if (!vec) break;
+
+ for (int d = 0; d < dim; d++) {
+ if (offset + 4 <= payload_size) {
+ memcpy(&vec[d], payload + offset, sizeof(float));
+ offset += 4;
+ } else if (offset < payload_size) {
+ // Scale to interesting range including values > 1, < -1
+ vec[d] = ((float)(int8_t)payload[offset++]) / 5.0f;
+ } else {
+ // Reuse earlier bytes to fill remaining dimensions
+ vec[d] = (float)(i * dim + d) * 0.01f;
+ }
+ }
+
+ sqlite3_reset(stmtInsert);
+ sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
+ sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
+ sqlite3_step(stmtInsert);
+ sqlite3_free(vec);
+ }
+ sqlite3_finalize(stmtInsert);
+
+ // Exercise compute-centroids with JSON options
+ {
+ char cmd[256];
+ snprintf(cmd, sizeof(cmd),
+ "INSERT INTO v(rowid) VALUES "
+ "('compute-centroids:{\"nlist\":%d,\"max_iterations\":%d,\"seed\":%u}')",
+ nlist_override, max_iter, seed);
+ sqlite3_exec(db, cmd, NULL, NULL, NULL);
+ }
+
+ // KNN query after training
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) {
+ qvec[d] = (d < 3) ? 1.0f : 0.0f;
+ }
+ sqlite3_stmt *stmtKnn = NULL;
+ sqlite3_prepare_v2(db,
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
+ -1, &stmtKnn, NULL);
+ if (stmtKnn) {
+ sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
+ sqlite3_finalize(stmtKnn);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ // Clear centroids and re-compute to test round-trip
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('clear-centroids')",
+ NULL, NULL, NULL);
+
+ // Insert a few more vectors in untrained state
+ {
+ sqlite3_stmt *si = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
+ if (si) {
+ for (int i = 0; i < 3; i++) {
+ float *vec = sqlite3_malloc(dim * sizeof(float));
+ if (!vec) break;
+ for (int d = 0; d < dim; d++) vec[d] = (float)(i + 100) * 0.1f;
+ sqlite3_reset(si);
+ sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1));
+ sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
+ sqlite3_step(si);
+ sqlite3_free(vec);
+ }
+ sqlite3_finalize(si);
+ }
+ }
+
+ // Re-train
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+
+ // Delete some rows after training, then query
+ sqlite3_exec(db, "DELETE FROM v WHERE rowid = 1", NULL, NULL, NULL);
+ sqlite3_exec(db, "DELETE FROM v WHERE rowid = 2", NULL, NULL, NULL);
+
+ // Query after deletes
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
+ sqlite3_stmt *stmtKnn = NULL;
+ sqlite3_prepare_v2(db,
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 10",
+ -1, &stmtKnn, NULL);
+ if (stmtKnn) {
+ sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
+ sqlite3_finalize(stmtKnn);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/fuzz/ivf-knn-deep.c b/tests/fuzz/ivf-knn-deep.c
new file mode 100644
index 0000000..27d19a1
--- /dev/null
+++ b/tests/fuzz/ivf-knn-deep.c
@@ -0,0 +1,199 @@
+/**
+ * Fuzz target: IVF KNN search deep paths.
+ *
+ * Exercises the full KNN pipeline with fuzz-controlled:
+ * - nprobe values (including > nlist, =1, =nlist)
+ * - Query vectors (including adversarial floats)
+ * - Mix of trained/untrained state
+ * - Oversample + rescore path (quantizer=int8 with oversample>1)
+ * - Multiple interleaved KNN queries
+ * - Candidate array realloc path (many vectors in probed cells)
+ *
+ * Targets:
+ * - ivf_scan_cells_from_stmt: candidate realloc, distance computation
+ * - ivf_query_knn: centroid sorting, nprobe selection
+ * - Oversample rescore: re-ranking with full-precision vectors
+ * - qsort with NaN distances
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+static uint16_t read_u16(const uint8_t *p) {
+ return (uint16_t)(p[0] | (p[1] << 8));
+}
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ if (size < 16) return 0;
+
+ int rc;
+ sqlite3 *db;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ // Header
+ int dim = (data[0] % 32) + 2; // 2..33
+ int nlist = (data[1] % 16) + 1; // 1..16
+ int nprobe_initial = (data[2] % 20) + 1; // 1..20 (can be > nlist)
+ int quantizer_type = data[3] % 3; // 0=none, 1=int8, 2=binary
+ int oversample = (data[4] % 4) + 1; // 1..4
+ int num_vecs = (data[5] % 80) + 4; // 4..83
+ int num_queries = (data[6] % 8) + 1; // 1..8
+ int k_limit = (data[7] % 20) + 1; // 1..20
+
+ const uint8_t *payload = data + 8;
+ size_t payload_size = size - 8;
+
+ // For binary quantizer, dimension must be multiple of 8
+ if (quantizer_type == 2) {
+ dim = ((dim + 7) / 8) * 8;
+ if (dim == 0) dim = 8;
+ }
+
+ const char *qname;
+ switch (quantizer_type) {
+ case 1: qname = "int8"; break;
+ case 2: qname = "binary"; break;
+ default: qname = "none"; break;
+ }
+
+ // Oversample only valid with quantization
+ if (quantizer_type == 0) oversample = 1;
+
+ // Cap nprobe to nlist for CREATE (parser rejects nprobe > nlist)
+ int nprobe_create = nprobe_initial <= nlist ? nprobe_initial : nlist;
+
+ char sql[512];
+ snprintf(sql, sizeof(sql),
+ "CREATE VIRTUAL TABLE v USING vec0("
+ "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s%s))",
+ dim, nlist, nprobe_create, qname,
+ oversample > 1 ? ", oversample=2" : "");
+
+ // If that fails (e.g. oversample with none), try without oversample
+ rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
+ if (rc != SQLITE_OK) {
+ snprintf(sql, sizeof(sql),
+ "CREATE VIRTUAL TABLE v USING vec0("
+ "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))",
+ dim, nlist, nprobe_create, qname);
+ rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
+ if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+ }
+
+ // Insert vectors
+ sqlite3_stmt *stmtInsert = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+ if (!stmtInsert) { sqlite3_close(db); return 0; }
+
+ size_t offset = 0;
+ for (int i = 0; i < num_vecs; i++) {
+ float *vec = sqlite3_malloc(dim * sizeof(float));
+ if (!vec) break;
+ for (int d = 0; d < dim; d++) {
+ if (offset < payload_size) {
+ vec[d] = ((float)(int8_t)payload[offset++]) / 20.0f;
+ } else {
+ vec[d] = (float)((i * dim + d) % 256 - 128) / 128.0f;
+ }
+ }
+ sqlite3_reset(stmtInsert);
+ sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
+ sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
+ sqlite3_step(stmtInsert);
+ sqlite3_free(vec);
+ }
+ sqlite3_finalize(stmtInsert);
+
+ // Query BEFORE training (flat scan path)
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
+ sqlite3_stmt *sk = NULL;
+ snprintf(sql, sizeof(sql),
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
+ sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ // Train
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+
+ // Change nprobe at runtime (can exceed nlist -- tests clamping in query)
+ {
+ char cmd[64];
+ snprintf(cmd, sizeof(cmd),
+ "INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe_initial);
+ sqlite3_exec(db, cmd, NULL, NULL, NULL);
+ }
+
+ // Multiple KNN queries with different fuzz-derived query vectors
+ for (int q = 0; q < num_queries; q++) {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (!qvec) break;
+ for (int d = 0; d < dim; d++) {
+ if (offset < payload_size) {
+ qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
+ } else {
+ qvec[d] = (q == 0) ? 1.0f : 0.0f;
+ }
+ }
+
+ sqlite3_stmt *sk = NULL;
+ snprintf(sql, sizeof(sql),
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
+ sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ sqlite3_free(qvec);
+ }
+
+ // Delete half the vectors then query again
+ for (int i = 1; i <= num_vecs / 2; i++) {
+ char delsql[64];
+ snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i);
+ sqlite3_exec(db, delsql, NULL, NULL, NULL);
+ }
+
+ // Query after mass deletion
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) qvec[d] = -0.5f;
+ sqlite3_stmt *sk = NULL;
+ snprintf(sql, sizeof(sql),
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
+ sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/fuzz/ivf-operations.c b/tests/fuzz/ivf-operations.c
new file mode 100644
index 0000000..a955870
--- /dev/null
+++ b/tests/fuzz/ivf-operations.c
@@ -0,0 +1,121 @@
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ if (size < 6) return 0;
+
+ int rc;
+ sqlite3 *db;
+ sqlite3_stmt *stmtInsert = NULL;
+ sqlite3_stmt *stmtDelete = NULL;
+ sqlite3_stmt *stmtKnn = NULL;
+ sqlite3_stmt *stmtScan = NULL;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ rc = sqlite3_exec(db,
+ "CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(nlist=4, nprobe=4))",
+ NULL, NULL, NULL);
+ if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+ sqlite3_prepare_v2(db,
+ "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
+ sqlite3_prepare_v2(db,
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 3",
+ -1, &stmtKnn, NULL);
+ sqlite3_prepare_v2(db,
+ "SELECT rowid FROM v", -1, &stmtScan, NULL);
+
+ if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
+
+ size_t i = 0;
+ while (i + 2 <= size) {
+ uint8_t op = data[i++] % 7;
+ uint8_t rowid_byte = data[i++];
+ int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
+
+ switch (op) {
+ case 0: {
+ // INSERT: consume 16 bytes for 4 floats, or use what's left
+ float vec[4] = {0.0f, 0.0f, 0.0f, 0.0f};
+ for (int j = 0; j < 4 && i < size; j++, i++) {
+ vec[j] = (float)((int8_t)data[i]) / 10.0f;
+ }
+ sqlite3_reset(stmtInsert);
+ sqlite3_bind_int64(stmtInsert, 1, rowid);
+ sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
+ sqlite3_step(stmtInsert);
+ break;
+ }
+ case 1: {
+ // DELETE
+ sqlite3_reset(stmtDelete);
+ sqlite3_bind_int64(stmtDelete, 1, rowid);
+ sqlite3_step(stmtDelete);
+ break;
+ }
+ case 2: {
+ // KNN query with a fixed query vector
+ float qvec[4] = {1.0f, 0.0f, 0.0f, 0.0f};
+ sqlite3_reset(stmtKnn);
+ sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
+ while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
+ break;
+ }
+ case 3: {
+ // Full scan
+ sqlite3_reset(stmtScan);
+ while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
+ break;
+ }
+ case 4: {
+ // compute-centroids command
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+ break;
+ }
+ case 5: {
+ // clear-centroids command
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('clear-centroids')",
+ NULL, NULL, NULL);
+ break;
+ }
+ case 6: {
+ // nprobe=N command
+ if (i < size) {
+ uint8_t n = data[i++];
+ int nprobe = (n % 4) + 1;
+ char buf[64];
+ snprintf(buf, sizeof(buf),
+ "INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe);
+ sqlite3_exec(db, buf, NULL, NULL, NULL);
+ }
+ break;
+ }
+ }
+ }
+
+ // Final operations — must not crash regardless of prior state
+ sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
+
+cleanup:
+ sqlite3_finalize(stmtInsert);
+ sqlite3_finalize(stmtDelete);
+ sqlite3_finalize(stmtKnn);
+ sqlite3_finalize(stmtScan);
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/fuzz/ivf-quantize.c b/tests/fuzz/ivf-quantize.c
new file mode 100644
index 0000000..22149ee
--- /dev/null
+++ b/tests/fuzz/ivf-quantize.c
@@ -0,0 +1,129 @@
+/**
+ * Fuzz target: IVF quantization functions.
+ *
+ * Directly exercises ivf_quantize_int8 and ivf_quantize_binary with
+ * fuzz-controlled dimensions and float data. Targets:
+ * - ivf_quantize_int8: clamping, int8 overflow boundary
+ * - ivf_quantize_binary: D not divisible by 8, memset(D/8) undercount
+ * - Round-trip through CREATE TABLE + INSERT with quantized IVF
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ if (size < 8) return 0;
+
+ int rc;
+ sqlite3 *db;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ // Byte 0: quantizer type (0=int8, 1=binary)
+ // Byte 1: dimension (1..64, but we test edge cases)
+ // Byte 2: nlist (1..8)
+ // Byte 3: num_vectors to insert (1..32)
+ // Remaining: float data
+ int qtype = data[0] % 2;
+ int dim = (data[1] % 64) + 1;
+ int nlist = (data[2] % 8) + 1;
+ int num_vecs = (data[3] % 32) + 1;
+ const uint8_t *payload = data + 4;
+ size_t payload_size = size - 4;
+
+ // For binary quantizer, D must be multiple of 8 to avoid the D/8 bug
+ // in production. But we explicitly want to test non-multiples too to
+ // find the bug. Use dim as-is.
+ const char *quantizer = qtype ? "binary" : "int8";
+
+ // Binary quantizer needs D multiple of 8 in current code, but let's
+ // test both valid and invalid dimensions to see what happens.
+ // For binary with non-multiple-of-8, the code does memset(dst, 0, D/8)
+ // which underallocates when D%8 != 0.
+ char sql[256];
+ snprintf(sql, sizeof(sql),
+ "CREATE VIRTUAL TABLE v USING vec0("
+ "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))",
+ dim, nlist, nlist, quantizer);
+
+ rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
+ if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+ // Insert vectors with fuzz-controlled float values
+ sqlite3_stmt *stmtInsert = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+ if (!stmtInsert) { sqlite3_close(db); return 0; }
+
+ size_t offset = 0;
+ for (int i = 0; i < num_vecs && offset < payload_size; i++) {
+ // Build float vector from fuzz data
+ float *vec = sqlite3_malloc(dim * sizeof(float));
+ if (!vec) break;
+
+ for (int d = 0; d < dim; d++) {
+ if (offset + 4 <= payload_size) {
+ // Use raw bytes as float -- can produce NaN, Inf, denormals
+ memcpy(&vec[d], payload + offset, sizeof(float));
+ offset += 4;
+ } else if (offset < payload_size) {
+ // Partial: use byte as scaled value
+ vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f;
+ } else {
+ vec[d] = 0.0f;
+ }
+ }
+
+ sqlite3_reset(stmtInsert);
+ sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
+ sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
+ sqlite3_step(stmtInsert);
+ sqlite3_free(vec);
+ }
+ sqlite3_finalize(stmtInsert);
+
+ // Trigger compute-centroids to exercise kmeans + quantization together
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+
+ // KNN query with fuzz-derived query vector
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) {
+ if (offset < payload_size) {
+ qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
+ } else {
+ qvec[d] = 1.0f;
+ }
+ }
+
+ sqlite3_stmt *stmtKnn = NULL;
+ sqlite3_prepare_v2(db,
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
+ -1, &stmtKnn, NULL);
+ if (stmtKnn) {
+ sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
+ sqlite3_finalize(stmtKnn);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ // Full scan
+ sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
+
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/fuzz/ivf-rescore.c b/tests/fuzz/ivf-rescore.c
new file mode 100644
index 0000000..1c3f34a
--- /dev/null
+++ b/tests/fuzz/ivf-rescore.c
@@ -0,0 +1,182 @@
+/**
+ * Fuzz target: IVF oversample + rescore path.
+ *
+ * Specifically targets the code path where quantizer != none AND
+ * oversample > 1, which triggers:
+ * 1. Quantized KNN scan to collect oversample*k candidates
+ * 2. Full-precision vector lookup from _ivf_vectors table
+ * 3. Re-scoring with float32 distances
+ * 4. Re-sort and truncation
+ *
+ * This path has the most complex memory management in the KNN query:
+ * - Two separate distance computations (quantized + float)
+ * - Cross-table lookups (cells + vectors KV store)
+ * - Candidate array resizing
+ * - qsort over partially re-scored arrays
+ *
+ * Also tests the int8 + binary quantization round-trip fidelity
+ * under adversarial float inputs.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ if (size < 12) return 0;
+
+ int rc;
+ sqlite3 *db;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ // Header
+ int quantizer_type = (data[0] % 2) + 1; // 1=int8, 2=binary (never none)
+ int dim = (data[1] % 32) + 8; // 8..39
+ int nlist = (data[2] % 8) + 1; // 1..8
+ int oversample = (data[3] % 4) + 2; // 2..5 (always > 1)
+ int num_vecs = (data[4] % 60) + 8; // 8..67
+ int k_limit = (data[5] % 15) + 1; // 1..15
+
+ const uint8_t *payload = data + 6;
+ size_t payload_size = size - 6;
+
+ // Binary quantizer needs D multiple of 8
+ if (quantizer_type == 2) {
+ dim = ((dim + 7) / 8) * 8;
+ }
+
+ const char *qname = (quantizer_type == 1) ? "int8" : "binary";
+
+ char sql[512];
+ snprintf(sql, sizeof(sql),
+ "CREATE VIRTUAL TABLE v USING vec0("
+ "emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s, oversample=%d))",
+ dim, nlist, nlist, qname, oversample);
+
+ rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
+ if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+ // Insert vectors with diverse values
+ sqlite3_stmt *stmtInsert = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+ if (!stmtInsert) { sqlite3_close(db); return 0; }
+
+ size_t offset = 0;
+ for (int i = 0; i < num_vecs; i++) {
+ float *vec = sqlite3_malloc(dim * sizeof(float));
+ if (!vec) break;
+ for (int d = 0; d < dim; d++) {
+ if (offset + 4 <= payload_size) {
+ // Use raw bytes as float for adversarial values
+ memcpy(&vec[d], payload + offset, sizeof(float));
+ offset += 4;
+ // Sanitize: replace NaN/Inf with bounded values to avoid
+ // poisoning the entire computation. We want edge values,
+ // not complete nonsense.
+ if (isnan(vec[d]) || isinf(vec[d])) {
+ vec[d] = (vec[d] > 0) ? 1e6f : -1e6f;
+ if (isnan(vec[d])) vec[d] = 0.0f;
+ }
+ } else if (offset < payload_size) {
+ vec[d] = ((float)(int8_t)payload[offset++]) / 30.0f;
+ } else {
+ vec[d] = (float)(i * dim + d) * 0.001f;
+ }
+ }
+ sqlite3_reset(stmtInsert);
+ sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
+ sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
+ sqlite3_step(stmtInsert);
+ sqlite3_free(vec);
+ }
+ sqlite3_finalize(stmtInsert);
+
+ // Train
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+
+ // Multiple KNN queries to exercise rescore path
+ for (int q = 0; q < 4; q++) {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (!qvec) break;
+ for (int d = 0; d < dim; d++) {
+ if (offset < payload_size) {
+ qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
+ } else {
+ qvec[d] = (q == 0) ? 1.0f : (q == 1) ? -1.0f : 0.0f;
+ }
+ }
+
+ sqlite3_stmt *sk = NULL;
+ snprintf(sql, sizeof(sql),
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
+ sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ sqlite3_free(qvec);
+ }
+
+ // Delete some vectors, then query again (rescore with missing _ivf_vectors rows)
+ for (int i = 1; i <= num_vecs / 3; i++) {
+ char delsql[64];
+ snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i);
+ sqlite3_exec(db, delsql, NULL, NULL, NULL);
+ }
+
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
+ sqlite3_stmt *sk = NULL;
+ snprintf(sql, sizeof(sql),
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
+ sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ // Retrain after deletions
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+
+ // Query after retrain
+ {
+ float *qvec = sqlite3_malloc(dim * sizeof(float));
+ if (qvec) {
+ for (int d = 0; d < dim; d++) qvec[d] = -0.3f;
+ sqlite3_stmt *sk = NULL;
+ snprintf(sql, sizeof(sql),
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
+ sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ sqlite3_free(qvec);
+ }
+ }
+
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/fuzz/ivf-shadow-corrupt.c b/tests/fuzz/ivf-shadow-corrupt.c
new file mode 100644
index 0000000..1153ac9
--- /dev/null
+++ b/tests/fuzz/ivf-shadow-corrupt.c
@@ -0,0 +1,228 @@
+/**
+ * Fuzz target: IVF shadow table corruption.
+ *
+ * Creates a trained IVF table, then corrupts IVF shadow table blobs
+ * (centroids, cells validity/rowids/vectors, rowid_map) with fuzz data.
+ * Then exercises all read/write paths. Must not crash.
+ *
+ * Targets:
+ * - Cell validity bitmap with wrong size
+ * - Cell rowids blob with wrong size/alignment
+ * - Cell vectors blob with wrong size
+ * - Centroid blob with wrong size
+ * - n_vectors inconsistent with validity bitmap
+ * - Missing rowid_map entries
+ * - KNN scan over corrupted cells
+ * - Insert/delete with corrupted rowid_map
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ if (size < 4) return 0;
+
+ int rc;
+ sqlite3 *db;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ // Create IVF table and insert enough vectors to train
+ rc = sqlite3_exec(db,
+ "CREATE VIRTUAL TABLE v USING vec0("
+ "emb float[8] indexed by ivf(nlist=2, nprobe=2))",
+ NULL, NULL, NULL);
+ if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+ // Insert 10 vectors
+ {
+ sqlite3_stmt *si = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
+ if (!si) { sqlite3_close(db); return 0; }
+ for (int i = 0; i < 10; i++) {
+ float vec[8];
+ for (int d = 0; d < 8; d++) {
+ vec[d] = (float)(i * 8 + d) * 0.1f;
+ }
+ sqlite3_reset(si);
+ sqlite3_bind_int64(si, 1, i + 1);
+ sqlite3_bind_blob(si, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
+ sqlite3_step(si);
+ }
+ sqlite3_finalize(si);
+ }
+
+ // Train
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+
+ // Now corrupt shadow tables based on fuzz input
+ uint8_t target = data[0] % 10;
+ const uint8_t *payload = data + 1;
+ int payload_size = (int)(size - 1);
+
+ // Limit payload to avoid huge allocations
+ if (payload_size > 4096) payload_size = 4096;
+
+ sqlite3_stmt *stmt = NULL;
+
+ switch (target) {
+ case 0: {
+ // Corrupt cell validity blob
+ rc = sqlite3_prepare_v2(db,
+ "UPDATE v_ivf_cells00 SET validity = ? WHERE rowid = 1",
+ -1, &stmt, NULL);
+ if (rc == SQLITE_OK) {
+ sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
+ sqlite3_step(stmt); sqlite3_finalize(stmt);
+ }
+ break;
+ }
+ case 1: {
+ // Corrupt cell rowids blob
+ rc = sqlite3_prepare_v2(db,
+ "UPDATE v_ivf_cells00 SET rowids = ? WHERE rowid = 1",
+ -1, &stmt, NULL);
+ if (rc == SQLITE_OK) {
+ sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
+ sqlite3_step(stmt); sqlite3_finalize(stmt);
+ }
+ break;
+ }
+ case 2: {
+ // Corrupt cell vectors blob
+ rc = sqlite3_prepare_v2(db,
+ "UPDATE v_ivf_cells00 SET vectors = ? WHERE rowid = 1",
+ -1, &stmt, NULL);
+ if (rc == SQLITE_OK) {
+ sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
+ sqlite3_step(stmt); sqlite3_finalize(stmt);
+ }
+ break;
+ }
+ case 3: {
+ // Corrupt centroid blob
+ rc = sqlite3_prepare_v2(db,
+ "UPDATE v_ivf_centroids00 SET centroid = ? WHERE centroid_id = 0",
+ -1, &stmt, NULL);
+ if (rc == SQLITE_OK) {
+ sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
+ sqlite3_step(stmt); sqlite3_finalize(stmt);
+ }
+ break;
+ }
+ case 4: {
+ // Set n_vectors to a bogus value (larger than cell capacity)
+ int bogus_n = 99999;
+ if (payload_size >= 4) {
+ memcpy(&bogus_n, payload, 4);
+ bogus_n = abs(bogus_n) % 100000;
+ }
+ char sql[128];
+ snprintf(sql, sizeof(sql),
+ "UPDATE v_ivf_cells00 SET n_vectors = %d WHERE rowid = 1", bogus_n);
+ sqlite3_exec(db, sql, NULL, NULL, NULL);
+ break;
+ }
+ case 5: {
+ // Delete rowid_map entries (orphan vectors)
+ sqlite3_exec(db,
+ "DELETE FROM v_ivf_rowid_map00 WHERE rowid IN (1, 2, 3)",
+ NULL, NULL, NULL);
+ break;
+ }
+ case 6: {
+ // Corrupt rowid_map slot values
+ char sql[128];
+ int bogus_slot = payload_size > 0 ? (int)payload[0] * 1000 : 99999;
+ snprintf(sql, sizeof(sql),
+ "UPDATE v_ivf_rowid_map00 SET slot = %d WHERE rowid = 1", bogus_slot);
+ sqlite3_exec(db, sql, NULL, NULL, NULL);
+ break;
+ }
+ case 7: {
+ // Corrupt rowid_map cell_id values
+ sqlite3_exec(db,
+ "UPDATE v_ivf_rowid_map00 SET cell_id = 99999 WHERE rowid = 1",
+ NULL, NULL, NULL);
+ break;
+ }
+ case 8: {
+ // Delete all centroids (make trained but no centroids)
+ sqlite3_exec(db,
+ "DELETE FROM v_ivf_centroids00",
+ NULL, NULL, NULL);
+ break;
+ }
+ case 9: {
+ // Set validity to NULL
+ sqlite3_exec(db,
+ "UPDATE v_ivf_cells00 SET validity = NULL WHERE rowid = 1",
+ NULL, NULL, NULL);
+ break;
+ }
+ }
+
+ // Exercise all read paths over corrupted state — must not crash
+ float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+
+ // KNN query
+ {
+ sqlite3_stmt *sk = NULL;
+ sqlite3_prepare_v2(db,
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
+ -1, &sk, NULL);
+ if (sk) {
+ sqlite3_bind_blob(sk, 1, qvec, sizeof(qvec), SQLITE_STATIC);
+ while (sqlite3_step(sk) == SQLITE_ROW) {}
+ sqlite3_finalize(sk);
+ }
+ }
+
+ // Full scan
+ sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
+
+ // Point query
+ sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL);
+ sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 5", NULL, NULL, NULL);
+
+ // Delete
+ sqlite3_exec(db, "DELETE FROM v WHERE rowid = 3", NULL, NULL, NULL);
+
+ // Insert after corruption
+ {
+ float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
+ sqlite3_stmt *si = NULL;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
+ if (si) {
+ sqlite3_bind_int64(si, 1, 100);
+ sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC);
+ sqlite3_step(si);
+ sqlite3_finalize(si);
+ }
+ }
+
+ // compute-centroids over corrupted state
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('compute-centroids')",
+ NULL, NULL, NULL);
+
+ // clear-centroids
+ sqlite3_exec(db,
+ "INSERT INTO v(rowid) VALUES ('clear-centroids')",
+ NULL, NULL, NULL);
+
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h
index ca04b74..67f1370 100644
--- a/tests/sqlite-vec-internal.h
+++ b/tests/sqlite-vec-internal.h
@@ -5,6 +5,10 @@
#include
#include
+#ifndef SQLITE_VEC_ENABLE_IVF
+#define SQLITE_VEC_ENABLE_IVF 1
+#endif
+
int min_idx(
const float *distances,
int32_t n,
@@ -68,8 +72,36 @@ enum Vec0IndexType {
#ifdef SQLITE_VEC_ENABLE_RESCORE
VEC0_INDEX_TYPE_RESCORE = 2,
#endif
+ VEC0_INDEX_TYPE_IVF = 3,
};
+enum Vec0RescoreQuantizerType {
+ VEC0_RESCORE_QUANTIZER_BIT = 1,
+ VEC0_RESCORE_QUANTIZER_INT8 = 2,
+};
+
+struct Vec0RescoreConfig {
+ enum Vec0RescoreQuantizerType quantizer_type;
+ int oversample;
+};
+
+#if SQLITE_VEC_ENABLE_IVF
+enum Vec0IvfQuantizer {
+ VEC0_IVF_QUANTIZER_NONE = 0,
+ VEC0_IVF_QUANTIZER_INT8 = 1,
+ VEC0_IVF_QUANTIZER_BINARY = 2,
+};
+
+struct Vec0IvfConfig {
+ int nlist;
+ int nprobe;
+ int quantizer;
+ int oversample;
+};
+#else
+struct Vec0IvfConfig { char _unused; };
+#endif
+
#ifdef SQLITE_VEC_ENABLE_RESCORE
enum Vec0RescoreQuantizerType {
VEC0_RESCORE_QUANTIZER_BIT = 1,
@@ -93,6 +125,7 @@ struct VectorColumnDefinition {
#ifdef SQLITE_VEC_ENABLE_RESCORE
struct Vec0RescoreConfig rescore;
#endif
+ struct Vec0IvfConfig ivf;
};
int vec0_parse_vector_column(const char *source, int source_length,
@@ -114,6 +147,10 @@ void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t
size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
#endif
+#if SQLITE_VEC_ENABLE_IVF
+void ivf_quantize_int8(const float *src, int8_t *dst, int D);
+void ivf_quantize_binary(const float *src, uint8_t *dst, int D);
+#endif
#endif
#endif /* SQLITE_VEC_INTERNAL_H */
diff --git a/tests/test-ivf-mutations.py b/tests/test-ivf-mutations.py
new file mode 100644
index 0000000..5c61119
--- /dev/null
+++ b/tests/test-ivf-mutations.py
@@ -0,0 +1,575 @@
+"""
+Thorough IVF mutation tests: insert, delete, update, KNN correctness,
+error cases, edge cases, and cell overflow scenarios.
+"""
+import pytest
+import sqlite3
+import struct
+import math
+from helpers import _f32, exec
+
+
+@pytest.fixture()
+def db():
+ db = sqlite3.connect(":memory:")
+ db.row_factory = sqlite3.Row
+ db.enable_load_extension(True)
+ db.load_extension("dist/vec0")
+ db.enable_load_extension(False)
+ return db
+
+
+def ivf_total_vectors(db, table="t", col=0):
+ """Count total vectors across all IVF cells."""
+ return db.execute(
+ f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d}"
+ ).fetchone()[0]
+
+
+def ivf_unassigned_count(db, table="t", col=0):
+ return db.execute(
+ f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id = -1"
+ ).fetchone()[0]
+
+
+def ivf_assigned_count(db, table="t", col=0):
+ return db.execute(
+ f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id >= 0"
+ ).fetchone()[0]
+
+
+def knn(db, query, k, table="t", col="v"):
+ """Run a KNN query and return list of (rowid, distance) tuples."""
+ rows = db.execute(
+ f"SELECT rowid, distance FROM {table} WHERE {col} MATCH ? AND k = ?",
+ [_f32(query), k],
+ ).fetchall()
+ return [(r[0], r[1]) for r in rows]
+
+
+# ============================================================================
+# Single row insert + KNN
+# ============================================================================
+
+
+def test_insert_single_row_knn(db):
+ db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
+ db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
+ results = knn(db, [1, 0, 0, 0], 5)
+ assert len(results) == 1
+ assert results[0][0] == 1
+ assert results[0][1] < 0.001
+
+
+# ============================================================================
+# Batch insert + KNN recall
+# ============================================================================
+
+
+def test_batch_insert_knn_recall(db):
+ """Insert 200 vectors, train, verify KNN recall with nprobe=nlist."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
+ )
+ for i in range(200):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+ assert ivf_total_vectors(db) == 200
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+ assert ivf_assigned_count(db) == 200
+
+ # Query near 100 -- closest should be rowid 100
+ results = knn(db, [100.0, 0, 0, 0], 10)
+ assert len(results) == 10
+ assert results[0][0] == 100
+ assert results[0][1] < 0.01
+
+ # All results should be near 100
+ rowids = {r[0] for r in results}
+ assert all(95 <= r <= 105 for r in rowids)
+
+
+# ============================================================================
+# Delete rows, verify they're gone from KNN
+# ============================================================================
+
+
+def test_delete_rows_gone_from_knn(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
+ )
+ for i in range(20):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ # Delete rowid 10
+ db.execute("DELETE FROM t WHERE rowid = 10")
+
+ results = knn(db, [10.0, 0, 0, 0], 20)
+ rowids = {r[0] for r in results}
+ assert 10 not in rowids
+
+
+def test_delete_all_rows_empty_results(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
+ )
+ for i in range(10):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ for i in range(10):
+ db.execute("DELETE FROM t WHERE rowid = ?", [i])
+
+ assert ivf_total_vectors(db) == 0
+ results = knn(db, [5.0, 0, 0, 0], 10)
+ assert len(results) == 0
+
+
+# ============================================================================
+# Insert after delete (reuse rowids)
+# ============================================================================
+
+
+def test_insert_after_delete_reuse_rowid(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
+ )
+ for i in range(10):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ # Delete rowid 5
+ db.execute("DELETE FROM t WHERE rowid = 5")
+
+ # Re-insert rowid 5 with a very different vector
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (5, ?)", [_f32([999.0, 0, 0, 0])]
+ )
+
+ # KNN near 999 should find rowid 5
+ results = knn(db, [999.0, 0, 0, 0], 1)
+ assert len(results) >= 1
+ assert results[0][0] == 5
+
+
+# ============================================================================
+# Update vectors (INSERT OR REPLACE), verify KNN reflects new values
+# ============================================================================
+
+
+def test_update_vector_via_delete_insert(db):
+ """vec0 IVF update: delete then re-insert with new vector."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
+ )
+ for i in range(10):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ # "Update" rowid 3: delete and re-insert with new vector
+ db.execute("DELETE FROM t WHERE rowid = 3")
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (3, ?)",
+ [_f32([100.0, 0, 0, 0])],
+ )
+
+ # KNN near 100 should find rowid 3
+ results = knn(db, [100.0, 0, 0, 0], 1)
+ assert results[0][0] == 3
+
+
+# ============================================================================
+# Error cases: IVF + auxiliary/metadata/partition key columns
+# ============================================================================
+
+
+def test_error_ivf_with_auxiliary_column(db):
+ result = exec(
+ db,
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), +extra text)",
+ )
+ assert "error" in result
+ assert "auxiliary" in result.get("message", "").lower()
+
+
+def test_error_ivf_with_metadata_column(db):
+ result = exec(
+ db,
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), genre text)",
+ )
+ assert "error" in result
+ assert "metadata" in result.get("message", "").lower()
+
+
+def test_error_ivf_with_partition_key(db):
+ result = exec(
+ db,
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), user_id integer partition key)",
+ )
+ assert "error" in result
+ assert "partition" in result.get("message", "").lower()
+
+
+def test_flat_with_auxiliary_still_works(db):
+ """Regression guard: flat-indexed tables with aux columns should still work."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4], +extra text)"
+ )
+ db.execute(
+ "INSERT INTO t(rowid, v, extra) VALUES (1, ?, 'hello')",
+ [_f32([1, 0, 0, 0])],
+ )
+ row = db.execute("SELECT extra FROM t WHERE rowid = 1").fetchone()
+ assert row[0] == "hello"
+
+
+def test_flat_with_metadata_still_works(db):
+ """Regression guard: flat-indexed tables with metadata columns should still work."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4], genre text)"
+ )
+ db.execute(
+ "INSERT INTO t(rowid, v, genre) VALUES (1, ?, 'rock')",
+ [_f32([1, 0, 0, 0])],
+ )
+ row = db.execute("SELECT genre FROM t WHERE rowid = 1").fetchone()
+ assert row[0] == "rock"
+
+
+def test_flat_with_partition_key_still_works(db):
+ """Regression guard: flat-indexed tables with partition key should still work."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4], user_id integer partition key)"
+ )
+ db.execute(
+ "INSERT INTO t(rowid, v, user_id) VALUES (1, ?, 42)",
+ [_f32([1, 0, 0, 0])],
+ )
+ row = db.execute("SELECT user_id FROM t WHERE rowid = 1").fetchone()
+ assert row[0] == 42
+
+
+# ============================================================================
+# Edge cases
+# ============================================================================
+
+
+def test_zero_vectors(db):
+ """Insert zero vectors, verify KNN still works."""
+ db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
+ for i in range(5):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([0, 0, 0, 0])],
+ )
+ results = knn(db, [0, 0, 0, 0], 5)
+ assert len(results) == 5
+ # All distances should be 0
+ for _, dist in results:
+ assert dist < 0.001
+
+
+def test_large_values(db):
+ """Insert vectors with very large and small values."""
+ db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1e6, 1e6, 1e6, 1e6])]
+ )
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([1e-6, 1e-6, 1e-6, 1e-6])]
+ )
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([-1e6, -1e6, -1e6, -1e6])]
+ )
+
+ results = knn(db, [1e6, 1e6, 1e6, 1e6], 3)
+ assert results[0][0] == 1
+
+
+def test_single_row_compute_centroids(db):
+ """Single row table, compute-centroids should still work."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=1))"
+ )
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])]
+ )
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+ assert ivf_assigned_count(db) == 1
+
+ results = knn(db, [1, 2, 3, 4], 1)
+ assert len(results) == 1
+ assert results[0][0] == 1
+
+
+# ============================================================================
+# Cell overflow (many vectors in one cell)
+# ============================================================================
+
+
+def test_cell_overflow_many_vectors(db):
+ """Insert >64 vectors that all go to same centroid. Should create multiple cells."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
+ )
+ # Insert 100 very similar vectors
+ for i in range(100):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([1.0 + i * 0.001, 0, 0, 0])],
+ )
+
+ # Set a single centroid so all vectors go there
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
+ [_f32([1.0, 0, 0, 0])],
+ )
+ db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
+
+ assert ivf_assigned_count(db) == 100
+
+ # Should have more than 1 cell (64 max per cell)
+ cell_count = db.execute(
+ "SELECT count(*) FROM t_ivf_cells00 WHERE centroid_id = 0"
+ ).fetchone()[0]
+ assert cell_count >= 2 # 100 / 64 = 2 cells needed
+
+ # All vectors should be queryable
+ results = knn(db, [1.0, 0, 0, 0], 100)
+ assert len(results) == 100
+
+
+# ============================================================================
+# Large batch with training
+# ============================================================================
+
+
+def test_large_batch_with_training(db):
+ """Insert 500, train, insert 500 more, verify total is 1000."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=16, nprobe=16))"
+ )
+ for i in range(500):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ for i in range(500, 1000):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ assert ivf_total_vectors(db) == 1000
+
+ # KNN should still work
+ results = knn(db, [750.0, 0, 0, 0], 5)
+ assert len(results) == 5
+ assert results[0][0] == 750
+
+
+# ============================================================================
+# KNN after interleaved insert/delete
+# ============================================================================
+
+
+def test_knn_after_interleaved_insert_delete(db):
+ """Insert 20, train, delete 10 closest to query, verify remaining."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
+ )
+ for i in range(20):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ # Delete rowids 0-9 (closest to query at 5.0)
+ for i in range(10):
+ db.execute("DELETE FROM t WHERE rowid = ?", [i])
+
+ results = knn(db, [5.0, 0, 0, 0], 10)
+ rowids = {r[0] for r in results}
+ # None of the deleted rowids should appear
+ assert all(r >= 10 for r in rowids)
+ assert len(results) == 10
+
+
+def test_knn_empty_centroids_after_deletes(db):
+ """Some centroids may become empty after deletes. Should not crash."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=2))"
+ )
+ # Insert vectors clustered near 4 points
+ for i in range(40):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i % 10) * 10, 0, 0, 0])],
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ # Delete a bunch, potentially emptying some centroids
+ for i in range(30):
+ db.execute("DELETE FROM t WHERE rowid = ?", [i])
+
+ # Should not crash even with empty centroids
+ results = knn(db, [50.0, 0, 0, 0], 20)
+ assert len(results) <= 10 # only 10 left
+
+
+# ============================================================================
+# KNN returns correct distances
+# ============================================================================
+
+
+def test_knn_correct_distances(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
+ )
+ db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([0, 0, 0, 0])])
+ db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([3, 0, 0, 0])])
+ db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([0, 4, 0, 0])])
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ results = knn(db, [0, 0, 0, 0], 3)
+ result_map = {r[0]: r[1] for r in results}
+
+ # L2 distances (sqrt of sum of squared differences)
+ assert abs(result_map[1] - 0.0) < 0.01
+ assert abs(result_map[2] - 3.0) < 0.01 # sqrt(3^2) = 3
+ assert abs(result_map[3] - 4.0) < 0.01 # sqrt(4^2) = 4
+
+
+# ============================================================================
+# Delete in flat mode leaves no orphan rowid_map entries
+# ============================================================================
+
+
+def test_delete_flat_mode_rowid_map_count(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
+ )
+ for i in range(5):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ db.execute("DELETE FROM t WHERE rowid = 0")
+ db.execute("DELETE FROM t WHERE rowid = 2")
+ db.execute("DELETE FROM t WHERE rowid = 4")
+
+ assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 2
+ assert ivf_unassigned_count(db) == 2
+
+
+# ============================================================================
+# Duplicate rowid insert
+# ============================================================================
+
+
+def test_delete_reinsert_as_update(db):
+ """Simulate update via delete + insert on same rowid."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
+ )
+ db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
+
+ # Delete then re-insert as "update"
+ db.execute("DELETE FROM t WHERE rowid = 1")
+ db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([99, 0, 0, 0])])
+
+ results = knn(db, [99, 0, 0, 0], 1)
+ assert len(results) == 1
+ assert results[0][0] == 1
+ assert results[0][1] < 0.01
+
+
+def test_duplicate_rowid_insert_fails(db):
+ """Inserting a duplicate rowid should fail with a constraint error."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
+ )
+ db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
+
+ result = exec(
+ db,
+ "INSERT INTO t(rowid, v) VALUES (1, ?)",
+ [_f32([99, 0, 0, 0])],
+ )
+ assert "error" in result
+
+
+# ============================================================================
+# Interleaved insert/delete with KNN correctness
+# ============================================================================
+
+
+def test_interleaved_ops_correctness(db):
+ """Complex sequence of inserts and deletes, verify KNN is always correct."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
+ )
+
+ # Phase 1: Insert 50 vectors
+ for i in range(50):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ # Phase 2: Delete even-numbered rowids
+ for i in range(0, 50, 2):
+ db.execute("DELETE FROM t WHERE rowid = ?", [i])
+
+ # Phase 3: Insert new vectors with higher rowids
+ for i in range(50, 75):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(i), 0, 0, 0])],
+ )
+
+ # Phase 4: Delete some of the new ones
+ for i in range(60, 70):
+ db.execute("DELETE FROM t WHERE rowid = ?", [i])
+
+ # KNN query: should only find existing vectors
+ results = knn(db, [25.0, 0, 0, 0], 50)
+ rowids = {r[0] for r in results}
+
+ # Verify no deleted rowids appear
+ deleted = set(range(0, 50, 2)) | set(range(60, 70))
+ assert len(rowids & deleted) == 0
+
+ # Verify we get the right count (25 odd + 15 new - 10 deleted new = 30)
+ expected_alive = set(range(1, 50, 2)) | set(range(50, 60)) | set(range(70, 75))
+ assert rowids.issubset(expected_alive)
diff --git a/tests/test-ivf-quantization.py b/tests/test-ivf-quantization.py
new file mode 100644
index 0000000..9790680
--- /dev/null
+++ b/tests/test-ivf-quantization.py
@@ -0,0 +1,255 @@
+import pytest
+import sqlite3
+from helpers import _f32, exec
+
+
+@pytest.fixture()
+def db():
+ db = sqlite3.connect(":memory:")
+ db.row_factory = sqlite3.Row
+ db.enable_load_extension(True)
+ db.load_extension("dist/vec0")
+ db.enable_load_extension(False)
+ return db
+
+
+# ============================================================================
+# Parser tests — quantizer and oversample options
+# ============================================================================
+
+
+def test_ivf_quantizer_binary(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] indexed by ivf(nlist=64, quantizer=binary, oversample=10))"
+ )
+ tables = [
+ r[0]
+ for r in db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
+ ).fetchall()
+ ]
+ assert "t_ivf_centroids00" in tables
+ assert "t_ivf_cells00" in tables
+
+
+def test_ivf_quantizer_int8(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] indexed by ivf(nlist=64, quantizer=int8))"
+ )
+ tables = [
+ r[0]
+ for r in db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
+ ).fetchall()
+ ]
+ assert "t_ivf_centroids00" in tables
+
+
+def test_ivf_quantizer_none_explicit(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] indexed by ivf(quantizer=none))"
+ )
+ # Should work — same as no quantizer
+ tables = [
+ r[0]
+ for r in db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
+ ).fetchall()
+ ]
+ assert "t_ivf_centroids00" in tables
+
+
+def test_ivf_quantizer_all_params(db):
+ """All params together: nlist, nprobe, quantizer, oversample."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] distance_metric=cosine "
+ "indexed by ivf(nlist=128, nprobe=16, quantizer=int8, oversample=4))"
+ )
+ tables = [
+ r[0]
+ for r in db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
+ ).fetchall()
+ ]
+ assert "t_ivf_centroids00" in tables
+
+
+def test_ivf_error_oversample_without_quantizer(db):
+ """oversample > 1 without quantizer should error."""
+ result = exec(
+ db,
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] indexed by ivf(oversample=10))",
+ )
+ assert "error" in result
+
+
+def test_ivf_error_unknown_quantizer(db):
+ result = exec(
+ db,
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] indexed by ivf(quantizer=pq))",
+ )
+ assert "error" in result
+
+
+def test_ivf_oversample_1_without_quantizer_ok(db):
+ """oversample=1 (default) is fine without quantizer."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] indexed by ivf(nlist=64))"
+ )
+ # Should succeed — oversample defaults to 1
+
+
+# ============================================================================
+# Functional tests — insert, train, query with quantized IVF
+# ============================================================================
+
+
+def test_ivf_int8_insert_and_query(db):
+ """int8 quantized IVF: insert, train, query."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=4))"
+ )
+ for i in range(20):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ # Should be able to query
+ rows = db.execute(
+ "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
+ [_f32([10.0, 0, 0, 0])],
+ ).fetchall()
+ assert len(rows) == 5
+ # Top result should be close to 10
+ assert rows[0][0] in range(8, 13)
+
+ # Full vectors should be in _ivf_vectors table
+ fv_count = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0]
+ assert fv_count == 20
+
+
+def test_ivf_binary_insert_and_query(db):
+ """Binary quantized IVF: insert, train, query."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[8] indexed by ivf(nlist=2, quantizer=binary, oversample=4))"
+ )
+ for i in range(20):
+ # Vectors with varying sign patterns
+ v = [(i * 0.1 - 1.0) + j * 0.3 for j in range(8)]
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ rows = db.execute(
+ "SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
+ [_f32([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])],
+ ).fetchall()
+ assert len(rows) == 5
+
+ # Full vectors stored
+ fv_count = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0]
+ assert fv_count == 20
+
+
+def test_ivf_int8_cell_sizes_smaller(db):
+ """Cell blobs should be smaller with int8 quantization."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] indexed by ivf(nlist=2, quantizer=int8, oversample=1))"
+ )
+ for i in range(10):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(x) / 768 for x in range(768)])],
+ )
+
+ # Cell vectors blob: 10 vectors at int8 = 10 * 768 = 7680 bytes
+ # vs float32 = 10 * 768 * 4 = 30720 bytes
+ # But cells have capacity 64, so blob = 64 * 768 = 49152 (int8) vs 64*768*4=196608 (float32)
+ blob_size = db.execute(
+ "SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1"
+ ).fetchone()[0]
+ # int8: 64 slots * 768 bytes = 49152
+ assert blob_size == 64 * 768
+
+
+def test_ivf_binary_cell_sizes_smaller(db):
+ """Cell blobs should be much smaller with binary quantization."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[768] indexed by ivf(nlist=2, quantizer=binary, oversample=1))"
+ )
+ for i in range(10):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([float(x) / 768 for x in range(768)])],
+ )
+
+ blob_size = db.execute(
+ "SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1"
+ ).fetchone()[0]
+ # binary: 64 slots * 768/8 bytes = 6144
+ assert blob_size == 64 * (768 // 8)
+
+
+def test_ivf_int8_oversample_improves_recall(db):
+ """Oversample re-ranking should improve recall over oversample=1."""
+ # Create two tables: one with oversample=1, one with oversample=10
+ db.execute(
+ "CREATE VIRTUAL TABLE t1 USING vec0("
+ "v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=1))"
+ )
+ db.execute(
+ "CREATE VIRTUAL TABLE t2 USING vec0("
+ "v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=10))"
+ )
+ for i in range(100):
+ v = _f32([i * 0.1, (i * 0.1) % 3, (i * 0.3) % 5, i * 0.01])
+ db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v])
+ db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v])
+
+ db.execute("INSERT INTO t1(rowid) VALUES ('compute-centroids')")
+ db.execute("INSERT INTO t2(rowid) VALUES ('compute-centroids')")
+ db.execute("INSERT INTO t1(rowid) VALUES ('nprobe=4')")
+ db.execute("INSERT INTO t2(rowid) VALUES ('nprobe=4')")
+
+ query = _f32([5.0, 1.5, 2.5, 0.5])
+ r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall()
+ r2 = db.execute("SELECT rowid FROM t2 WHERE v MATCH ? AND k=10", [query]).fetchall()
+
+ # Both should return 10 results
+ assert len(r1) == 10
+ assert len(r2) == 10
+ # oversample=10 should have at least as good recall (same or better ordering)
+
+
+def test_ivf_quantized_delete(db):
+ """Delete should remove from both cells and _ivf_vectors."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0("
+ "v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=1))"
+ )
+ for i in range(10):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+ assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10
+
+ db.execute("DELETE FROM t WHERE rowid = 5")
+ # _ivf_vectors should have 9 rows
+ assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 9
diff --git a/tests/test-ivf.py b/tests/test-ivf.py
new file mode 100644
index 0000000..18a7532
--- /dev/null
+++ b/tests/test-ivf.py
@@ -0,0 +1,426 @@
+import pytest
+import sqlite3
+import struct
+import math
+from helpers import _f32, exec
+
+
+@pytest.fixture()
+def db():
+ db = sqlite3.connect(":memory:")
+ db.row_factory = sqlite3.Row
+ db.enable_load_extension(True)
+ db.load_extension("dist/vec0")
+ db.enable_load_extension(False)
+ return db
+
+
+def ivf_total_vectors(db, table="t", col=0):
+ """Count total vectors across all IVF cells."""
+ return db.execute(
+ f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d}"
+ ).fetchone()[0]
+
+
+def ivf_unassigned_count(db, table="t", col=0):
+ """Count vectors in unassigned cells (centroid_id=-1)."""
+ return db.execute(
+ f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id = -1"
+ ).fetchone()[0]
+
+
+def ivf_assigned_count(db, table="t", col=0):
+ """Count vectors in trained cells (centroid_id >= 0)."""
+ return db.execute(
+ f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id >= 0"
+ ).fetchone()[0]
+
+
+# ============================================================================
+# Parser tests
+# ============================================================================
+
+
+def test_ivf_create_defaults(db):
+ """ivf() with no args uses defaults."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())"
+ )
+ tables = [
+ r[0]
+ for r in db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
+ ).fetchall()
+ ]
+ assert "t_ivf_centroids00" in tables
+ assert "t_ivf_cells00" in tables
+ assert "t_ivf_rowid_map00" in tables
+
+
+def test_ivf_create_custom_params(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=64, nprobe=8))"
+ )
+ tables = [
+ r[0]
+ for r in db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
+ ).fetchall()
+ ]
+ assert "t_ivf_centroids00" in tables
+ assert "t_ivf_cells00" in tables
+
+
+def test_ivf_create_with_distance_metric(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] distance_metric=cosine indexed by ivf(nlist=16))"
+ )
+ tables = [
+ r[0]
+ for r in db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
+ ).fetchall()
+ ]
+ assert "t_ivf_centroids00" in tables
+
+
+def test_ivf_create_error_unknown_key(db):
+ result = exec(
+ db,
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(bogus=1))",
+ )
+ assert "error" in result
+
+
+def test_ivf_create_error_nprobe_gt_nlist(db):
+ result = exec(
+ db,
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=10))",
+ )
+ assert "error" in result
+
+
+# ============================================================================
+# Shadow table tests
+# ============================================================================
+
+
+def test_ivf_shadow_tables_created(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8))"
+ )
+ trained = db.execute(
+ "SELECT value FROM t_info WHERE key = 'ivf_trained_0'"
+ ).fetchone()[0]
+ assert str(trained) == "0"
+
+ # No cells yet (created lazily on first insert)
+ count = db.execute(
+ "SELECT count(*) FROM t_ivf_cells00"
+ ).fetchone()[0]
+ assert count == 0
+
+
+def test_ivf_drop_cleans_up(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
+ )
+ db.execute("DROP TABLE t")
+ tables = [
+ r[0]
+ for r in db.execute(
+ "SELECT name FROM sqlite_master WHERE type='table'"
+ ).fetchall()
+ ]
+ assert not any("ivf" in t for t in tables)
+
+
+# ============================================================================
+# Insert tests (flat mode)
+# ============================================================================
+
+
+def test_ivf_insert_flat_mode(db):
+ """Before training, vectors go to unassigned cell."""
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
+ )
+ db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
+ db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])])
+
+ assert ivf_unassigned_count(db) == 2
+ assert ivf_assigned_count(db) == 0
+
+ # rowid_map should have 2 entries
+ assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 2
+
+
+def test_ivf_delete_flat_mode(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
+ )
+ db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
+ db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])])
+ db.execute("DELETE FROM t WHERE rowid = 1")
+
+ assert ivf_unassigned_count(db) == 1
+ assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 1
+
+
+# ============================================================================
+# KNN flat mode tests
+# ============================================================================
+
+
+def test_ivf_knn_flat_mode(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
+ )
+ db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
+ db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([2, 0, 0, 0])])
+ db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([9, 0, 0, 0])])
+
+ rows = db.execute(
+ "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 2",
+ [_f32([1.5, 0, 0, 0])],
+ ).fetchall()
+ assert len(rows) == 2
+ rowids = {r[0] for r in rows}
+ assert rowids == {1, 2}
+
+
+def test_ivf_knn_flat_empty(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
+ )
+ rows = db.execute(
+ "SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
+ [_f32([1, 0, 0, 0])],
+ ).fetchall()
+ assert len(rows) == 0
+
+
+# ============================================================================
+# compute-centroids tests
+# ============================================================================
+
+
+def test_compute_centroids(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
+ )
+ for i in range(40):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)",
+ [i, _f32([i % 10, i // 10, 0, 0])],
+ )
+
+ assert ivf_unassigned_count(db) == 40
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ # After training: unassigned cell should be gone (or empty), vectors in trained cells
+ assert ivf_unassigned_count(db) == 0
+ assert ivf_assigned_count(db) == 40
+ assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 4
+ trained = db.execute(
+ "SELECT value FROM t_info WHERE key='ivf_trained_0'"
+ ).fetchone()[0]
+ assert str(trained) == "1"
+
+
+def test_compute_centroids_recompute(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
+ )
+ for i in range(20):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+ assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+ assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
+ assert ivf_assigned_count(db) == 20
+
+
+# ============================================================================
+# Insert after training (assigned mode)
+# ============================================================================
+
+
+def test_ivf_insert_after_training(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
+ )
+ for i in range(20):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])]
+ )
+
+ # Should be in a trained cell, not unassigned
+ row = db.execute(
+ "SELECT m.cell_id, c.centroid_id FROM t_ivf_rowid_map00 m "
+ "JOIN t_ivf_cells00 c ON c.rowid = m.cell_id "
+ "WHERE m.rowid = 100"
+ ).fetchone()
+ assert row is not None
+ assert row[1] >= 0 # centroid_id >= 0 means trained cell
+
+
+# ============================================================================
+# KNN after training (IVF probe mode)
+# ============================================================================
+
+
+def test_ivf_knn_after_training(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
+ )
+ for i in range(100):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ rows = db.execute(
+ "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
+ [_f32([50.0, 0, 0, 0])],
+ ).fetchall()
+ assert len(rows) == 5
+ assert rows[0][0] == 50
+ assert rows[0][1] < 0.01
+
+
+def test_ivf_knn_k_larger_than_n(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
+ )
+ for i in range(5):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ rows = db.execute(
+ "SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
+ [_f32([0, 0, 0, 0])],
+ ).fetchall()
+ assert len(rows) == 5
+
+
+# ============================================================================
+# Manual centroid import (set-centroid, assign-vectors)
+# ============================================================================
+
+
+def test_set_centroid_and_assign(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
+ )
+ for i in range(20):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
+ [_f32([5, 0, 0, 0])],
+ )
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES ('set-centroid:1', ?)",
+ [_f32([15, 0, 0, 0])],
+ )
+
+ assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
+
+ db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
+
+ assert ivf_unassigned_count(db) == 0
+ assert ivf_assigned_count(db) == 20
+
+
+# ============================================================================
+# clear-centroids
+# ============================================================================
+
+
+def test_clear_centroids(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
+ )
+ for i in range(20):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+ assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
+
+ db.execute("INSERT INTO t(rowid) VALUES ('clear-centroids')")
+ assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0
+ assert ivf_unassigned_count(db) == 20
+ trained = db.execute(
+ "SELECT value FROM t_info WHERE key='ivf_trained_0'"
+ ).fetchone()[0]
+ assert str(trained) == "0"
+
+
+# ============================================================================
+# Delete after training
+# ============================================================================
+
+
+def test_ivf_delete_after_training(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
+ )
+ for i in range(10):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+ assert ivf_assigned_count(db) == 10
+
+ db.execute("DELETE FROM t WHERE rowid = 5")
+ assert ivf_assigned_count(db) == 9
+ assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 9
+
+
+# ============================================================================
+# Recall test
+# ============================================================================
+
+
+def test_ivf_recall_nprobe_equals_nlist(db):
+ db.execute(
+ "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
+ )
+ for i in range(100):
+ db.execute(
+ "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
+ )
+
+ db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
+
+ rows = db.execute(
+ "SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
+ [_f32([50.0, 0, 0, 0])],
+ ).fetchall()
+ rowids = {r[0] for r in rows}
+
+ # 45 and 55 are equidistant from 50, so either may appear in top 10
+ assert 50 in rowids
+ assert len(rowids) == 10
+ assert all(45 <= r <= 55 for r in rowids)
diff --git a/tests/test-unit.c b/tests/test-unit.c
index b180625..27a469d 100644
--- a/tests/test-unit.c
+++ b/tests/test-unit.c
@@ -577,6 +577,182 @@ void test_vec0_parse_vector_column() {
assert(rc == SQLITE_ERROR);
}
+#if SQLITE_VEC_ENABLE_IVF
+ // IVF: indexed by ivf() — defaults
+ {
+ const char *input = "v float[4] indexed by ivf()";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.dimensions == 4);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.ivf.nlist == 128); // default
+ assert(col.ivf.nprobe == 10); // default
+ sqlite3_free(col.name);
+ }
+
+ // IVF: indexed by ivf(nlist=8) — nprobe auto-clamped to 8
+ {
+ const char *input = "v float[4] indexed by ivf(nlist=8)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.ivf.nlist == 8);
+ assert(col.ivf.nprobe == 8); // clamped from default 10
+ sqlite3_free(col.name);
+ }
+
+ // IVF: indexed by ivf(nlist=64, nprobe=8)
+ {
+ const char *input = "v float[4] indexed by ivf(nlist=64, nprobe=8)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.ivf.nlist == 64);
+ assert(col.ivf.nprobe == 8);
+ sqlite3_free(col.name);
+ }
+
+ // IVF: with distance_metric before indexed by
+ {
+ const char *input = "v float[4] distance_metric=cosine indexed by ivf(nlist=16)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.ivf.nlist == 16);
+ sqlite3_free(col.name);
+ }
+
+ // IVF: nlist=0 (deferred)
+ {
+ const char *input = "v float[4] indexed by ivf(nlist=0)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.nlist == 0);
+ sqlite3_free(col.name);
+ }
+
+ // IVF error: nprobe > nlist
+ {
+ const char *input = "v float[4] indexed by ivf(nlist=4, nprobe=10)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // IVF error: unknown key
+ {
+ const char *input = "v float[4] indexed by ivf(bogus=1)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // IVF error: unknown index type (hnsw not supported)
+ {
+ const char *input = "v float[4] indexed by hnsw()";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Not IVF: no ivf config
+ {
+ const char *input = "v float[4]";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_FLAT);
+ sqlite3_free(col.name);
+ }
+
+ // IVF: quantizer=binary
+ {
+ const char *input = "v float[768] indexed by ivf(nlist=128, quantizer=binary)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.ivf.nlist == 128);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY);
+ assert(col.ivf.oversample == 1);
+ sqlite3_free(col.name);
+ }
+
+ // IVF: quantizer=int8
+ {
+ const char *input = "v float[768] indexed by ivf(nlist=64, quantizer=int8)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
+ sqlite3_free(col.name);
+ }
+
+ // IVF: quantizer=none (explicit)
+ {
+ const char *input = "v float[768] indexed by ivf(quantizer=none)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_NONE);
+ sqlite3_free(col.name);
+ }
+
+ // IVF: oversample=10 with quantizer
+ {
+ const char *input = "v float[768] indexed by ivf(nlist=128, quantizer=binary, oversample=10)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY);
+ assert(col.ivf.oversample == 10);
+ assert(col.ivf.nlist == 128);
+ sqlite3_free(col.name);
+ }
+
+ // IVF: all params
+ {
+ const char *input = "v float[768] distance_metric=cosine indexed by ivf(nlist=256, nprobe=16, quantizer=int8, oversample=4)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
+ assert(col.ivf.nlist == 256);
+ assert(col.ivf.nprobe == 16);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
+ assert(col.ivf.oversample == 4);
+ sqlite3_free(col.name);
+ }
+
+ // IVF error: oversample > 1 without quantizer
+ {
+ const char *input = "v float[768] indexed by ivf(oversample=10)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // IVF error: unknown quantizer value
+ {
+ const char *input = "v float[768] indexed by ivf(quantizer=pq)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // IVF: quantizer with defaults (nlist=128 default, nprobe=10 default)
+ {
+ const char *input = "v float[768] indexed by ivf(quantizer=binary, oversample=5)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.nlist == 128);
+ assert(col.ivf.nprobe == 10);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY);
+ assert(col.ivf.oversample == 5);
+ sqlite3_free(col.name);
+ }
+#else
+ // When IVF is disabled, parsing "ivf" should fail
+ {
+ const char *input = "v float[4] indexed by ivf()";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+#endif /* SQLITE_VEC_ENABLE_IVF */
+
printf(" All vec0_parse_vector_column tests passed.\n");
}
@@ -821,6 +997,38 @@ void test_rescore_quantize_float_to_int8() {
float src[8] = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f};
_test_rescore_quantize_float_to_int8(src, dst, 8);
for (int i = 0; i < 8; i++) {
+#if SQLITE_VEC_ENABLE_IVF
+void test_ivf_quantize_int8() {
+ printf("Starting %s...\n", __func__);
+
+ // Basic values in [-1, 1] range
+ {
+ float src[] = {0.0f, 1.0f, -1.0f, 0.5f};
+ int8_t dst[4];
+ ivf_quantize_int8(src, dst, 4);
+ assert(dst[0] == 0);
+ assert(dst[1] == 127);
+ assert(dst[2] == -127);
+ assert(dst[3] == 63); // 0.5 * 127 = 63.5, truncated to 63
+ }
+
+ // Clamping: values beyond [-1, 1]
+ {
+ float src[] = {2.0f, -3.0f, 100.0f, -0.01f};
+ int8_t dst[4];
+ ivf_quantize_int8(src, dst, 4);
+ assert(dst[0] == 127); // clamped to 1.0
+ assert(dst[1] == -127); // clamped to -1.0
+ assert(dst[2] == 127); // clamped to 1.0
+ assert(dst[3] == (int8_t)(-0.01f * 127.0f));
+ }
+
+ // Zero vector
+ {
+ float src[] = {0.0f, 0.0f, 0.0f, 0.0f};
+ int8_t dst[4];
+ ivf_quantize_int8(src, dst, 4);
+ for (int i = 0; i < 4; i++) {
assert(dst[i] == 0);
}
}
@@ -882,6 +1090,103 @@ void test_rescore_quantized_byte_size() {
}
void test_vec0_parse_vector_column_rescore() {
+ // Negative zero
+ {
+ float src[] = {-0.0f};
+ int8_t dst[1];
+ ivf_quantize_int8(src, dst, 1);
+ assert(dst[0] == 0);
+ }
+
+ // Single element
+ {
+ float src[] = {0.75f};
+ int8_t dst[1];
+ ivf_quantize_int8(src, dst, 1);
+ assert(dst[0] == (int8_t)(0.75f * 127.0f));
+ }
+
+ // Boundary: exactly 1.0 and -1.0
+ {
+ float src[] = {1.0f, -1.0f};
+ int8_t dst[2];
+ ivf_quantize_int8(src, dst, 2);
+ assert(dst[0] == 127);
+ assert(dst[1] == -127);
+ }
+
+ printf(" All ivf_quantize_int8 tests passed.\n");
+}
+
+void test_ivf_quantize_binary() {
+ printf("Starting %s...\n", __func__);
+
+ // Basic sign-bit quantization: positive -> 1, negative/zero -> 0
+ {
+ float src[] = {1.0f, -1.0f, 0.5f, -0.5f, 0.0f, 0.1f, -0.1f, 2.0f};
+ uint8_t dst[1];
+ ivf_quantize_binary(src, dst, 8);
+ // bit 0: 1.0 > 0 -> 1 (LSB)
+ // bit 1: -1.0 -> 0
+ // bit 2: 0.5 > 0 -> 1
+ // bit 3: -0.5 -> 0
+ // bit 4: 0.0 -> 0 (not > 0)
+ // bit 5: 0.1 > 0 -> 1
+ // bit 6: -0.1 -> 0
+ // bit 7: 2.0 > 0 -> 1
+ // Expected: bits 0,2,5,7 = 0b10100101 = 0xA5
+ assert(dst[0] == 0xA5);
+ }
+
+ // All positive
+ {
+ float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+ uint8_t dst[1];
+ ivf_quantize_binary(src, dst, 8);
+ assert(dst[0] == 0xFF);
+ }
+
+ // All negative
+ {
+ float src[] = {-1.0f, -2.0f, -3.0f, -4.0f, -5.0f, -6.0f, -7.0f, -8.0f};
+ uint8_t dst[1];
+ ivf_quantize_binary(src, dst, 8);
+ assert(dst[0] == 0x00);
+ }
+
+ // All zero (zero is NOT > 0, so all bits should be 0)
+ {
+ float src[] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+ uint8_t dst[1];
+ ivf_quantize_binary(src, dst, 8);
+ assert(dst[0] == 0x00);
+ }
+
+ // Multi-byte: 16 dimensions -> 2 bytes
+ {
+ float src[16];
+ for (int i = 0; i < 16; i++) src[i] = (i % 2 == 0) ? 1.0f : -1.0f;
+ uint8_t dst[2];
+ ivf_quantize_binary(src, dst, 16);
+ // Even indices are positive: bits 0,2,4,6 in each byte
+ // byte 0: bits 0,2,4,6 = 0b01010101 = 0x55
+ // byte 1: same pattern = 0x55
+ assert(dst[0] == 0x55);
+ assert(dst[1] == 0x55);
+ }
+
+ // Single byte, only first bit set
+ {
+ float src[] = {0.1f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
+ uint8_t dst[1];
+ ivf_quantize_binary(src, dst, 8);
+ assert(dst[0] == 0x01);
+ }
+
+ printf(" All ivf_quantize_binary tests passed.\n");
+}
+
+void test_ivf_config_parsing() {
printf("Starting %s...\n", __func__);
struct VectorColumnDefinition col;
int rc;
@@ -955,6 +1260,116 @@ void test_vec0_parse_vector_column_rescore() {
}
#endif /* SQLITE_VEC_ENABLE_RESCORE */
+ // Default IVF config
+ {
+ const char *s = "v float[4] indexed by ivf()";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_IVF);
+ assert(col.ivf.nlist == 128); // default
+ assert(col.ivf.nprobe == 10); // default
+ assert(col.ivf.quantizer == 0); // VEC0_IVF_QUANTIZER_NONE
+ sqlite3_free(col.name);
+ }
+
+ // Custom nlist and nprobe
+ {
+ const char *s = "v float[4] indexed by ivf(nlist=64, nprobe=8)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.nlist == 64);
+ assert(col.ivf.nprobe == 8);
+ sqlite3_free(col.name);
+ }
+
+ // nlist=0 (deferred)
+ {
+ const char *s = "v float[4] indexed by ivf(nlist=0)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.nlist == 0);
+ sqlite3_free(col.name);
+ }
+
+ // Quantizer options
+ {
+ const char *s = "v float[8] indexed by ivf(quantizer=int8)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
+ sqlite3_free(col.name);
+ }
+
+ {
+ const char *s = "v float[8] indexed by ivf(quantizer=binary)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY);
+ sqlite3_free(col.name);
+ }
+
+ // nprobe > nlist (explicit) should fail
+ {
+ const char *s = "v float[4] indexed by ivf(nlist=4, nprobe=10)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Unknown key
+ {
+ const char *s = "v float[4] indexed by ivf(bogus=1)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // nlist > max (65536) should fail
+ {
+ const char *s = "v float[4] indexed by ivf(nlist=65537)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // nlist at max boundary (65536) should succeed
+ {
+ const char *s = "v float[4] indexed by ivf(nlist=65536)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.nlist == 65536);
+ sqlite3_free(col.name);
+ }
+
+ // oversample > 1 without quantization should fail
+ {
+ const char *s = "v float[4] indexed by ivf(oversample=4)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // oversample with quantizer should succeed
+ {
+ const char *s = "v float[8] indexed by ivf(quantizer=int8, oversample=4)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.oversample == 4);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
+ sqlite3_free(col.name);
+ }
+
+ // All options combined
+ {
+ const char *s = "v float[8] indexed by ivf(nlist=32, nprobe=4, quantizer=int8, oversample=2)";
+ rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.ivf.nlist == 32);
+ assert(col.ivf.nprobe == 4);
+ assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
+ assert(col.ivf.oversample == 2);
+ sqlite3_free(col.name);
+ }
+
+ printf(" All ivf_config_parsing tests passed.\n");
+}
+#endif /* SQLITE_VEC_ENABLE_IVF */
int main() {
printf("Starting unit tests...\n");
@@ -982,6 +1397,10 @@ int main() {
test_rescore_quantize_float_to_int8();
test_rescore_quantized_byte_size();
test_vec0_parse_vector_column_rescore();
+#if SQLITE_VEC_ENABLE_IVF
+ test_ivf_quantize_int8();
+ test_ivf_quantize_binary();
+ test_ivf_config_parsing();
#endif
printf("All unit tests passed.\n");
}