mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 00:36:56 +02:00
Add IVF index for vec0 virtual table
Add inverted file (IVF) index type: partitions vectors into clusters via k-means, quantizes to int8, and scans only the nearest nprobe partitions at query time. Includes shadow table management, insert/delete, KNN integration, compile flag (SQLITE_VEC_ENABLE_IVF), fuzz targets, and tests. Removes superseded ivf-benchmarks/ directory.
This commit is contained in:
parent
43982c144b
commit
3358e127f6
22 changed files with 5237 additions and 28 deletions
264
IVF_PLAN.md
Normal file
264
IVF_PLAN.md
Normal file
|
|
@ -0,0 +1,264 @@
|
||||||
|
# IVF Index for sqlite-vec
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
IVF (Inverted File Index) is an approximate nearest neighbor index for
|
||||||
|
sqlite-vec's `vec0` virtual table. It partitions vectors into clusters via
|
||||||
|
k-means, then at query time only scans the nearest clusters instead of all
|
||||||
|
vectors. Combined with scalar or binary quantization, this gives 5-20x query
|
||||||
|
speedups over brute-force with tunable recall.
|
||||||
|
|
||||||
|
## SQL API
|
||||||
|
|
||||||
|
### Table Creation
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE VIRTUAL TABLE vec_items USING vec0(
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
embedding float[768] distance_metric=cosine
|
||||||
|
INDEXED BY ivf(nlist=128, nprobe=16)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- With quantization (4x smaller cells, rescore for recall)
|
||||||
|
CREATE VIRTUAL TABLE vec_items USING vec0(
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
embedding float[768] distance_metric=cosine
|
||||||
|
INDEXED BY ivf(nlist=128, nprobe=16, quantizer=int8, oversample=4)
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parameters
|
||||||
|
|
||||||
|
| Parameter | Values | Default | Description |
|
||||||
|
|-----------|--------|---------|-------------|
|
||||||
|
| `nlist` | 1-65536, or 0 | 128 | Number of k-means clusters. Rule of thumb: `sqrt(N)` |
|
||||||
|
| `nprobe` | 1-nlist | 10 | Clusters to search at query time. More = better recall, slower |
|
||||||
|
| `quantizer` | `none`, `int8`, `binary` | `none` | How vectors are stored in cells |
|
||||||
|
| `oversample` | >= 1 | 1 | Re-rank `oversample * k` candidates with full-precision distance |
|
||||||
|
|
||||||
|
### Inserting Vectors
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Works immediately, even before training
|
||||||
|
INSERT INTO vec_items(id, embedding) VALUES (1, :vector);
|
||||||
|
```
|
||||||
|
|
||||||
|
Before centroids exist, vectors go to an "unassigned" partition and queries do
|
||||||
|
brute-force. After training, new inserts are assigned to the nearest centroid.
|
||||||
|
|
||||||
|
### Training (Computing Centroids)
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Run built-in k-means on all vectors
|
||||||
|
INSERT INTO vec_items(id) VALUES ('compute-centroids');
|
||||||
|
```
|
||||||
|
|
||||||
|
This loads all vectors into memory, runs k-means++ with Lloyd's algorithm,
|
||||||
|
creates quantized centroids, and redistributes all vectors into cluster cells.
|
||||||
|
It's a blocking operation — run it once after bulk insert.
|
||||||
|
|
||||||
|
### Manual Centroid Import
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Import externally-computed centroids
|
||||||
|
INSERT INTO vec_items(id, embedding) VALUES ('set-centroid:0', :centroid_0);
|
||||||
|
INSERT INTO vec_items(id, embedding) VALUES ('set-centroid:1', :centroid_1);
|
||||||
|
|
||||||
|
-- Assign vectors to imported centroids
|
||||||
|
INSERT INTO vec_items(id) VALUES ('assign-vectors');
|
||||||
|
```
|
||||||
|
|
||||||
|
### Runtime Parameter Tuning
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Change nprobe without rebuilding the index
|
||||||
|
INSERT INTO vec_items(id) VALUES ('nprobe=32');
|
||||||
|
```
|
||||||
|
|
||||||
|
### KNN Queries
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Same syntax as standard vec0
|
||||||
|
SELECT id, distance
|
||||||
|
FROM vec_items
|
||||||
|
WHERE embedding MATCH :query AND k = 10;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Other Commands
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Remove centroids, move all vectors back to unassigned
|
||||||
|
INSERT INTO vec_items(id) VALUES ('clear-centroids');
|
||||||
|
```
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
User vector (float32)
|
||||||
|
→ quantize to int8/binary (if quantizer != none)
|
||||||
|
→ find nearest centroid (quantized distance)
|
||||||
|
→ store quantized vector in cell blob
|
||||||
|
→ store full vector in KV table (if quantizer != none)
|
||||||
|
→ query:
|
||||||
|
1. quantize query vector
|
||||||
|
2. find top nprobe centroids by quantized distance
|
||||||
|
3. scan cell blobs: quantized distance (fast, small I/O)
|
||||||
|
4. if oversample > 1: re-score top N*k with full vectors
|
||||||
|
5. return top k
|
||||||
|
```
|
||||||
|
|
||||||
|
### Shadow Tables
|
||||||
|
|
||||||
|
For a table `vec_items` with vector column index 0:
|
||||||
|
|
||||||
|
| Table | Schema | Purpose |
|
||||||
|
|-------|--------|---------|
|
||||||
|
| `vec_items_ivf_centroids00` | `centroid_id PK, centroid BLOB` | K-means centroids (quantized) |
|
||||||
|
| `vec_items_ivf_cells00` | `centroid_id, n_vectors, validity BLOB, rowids BLOB, vectors BLOB` | Packed vector cells, 64 vectors max per row. Multiple rows per centroid. Index on centroid_id. |
|
||||||
|
| `vec_items_ivf_rowid_map00` | `rowid PK, cell_id, slot` | Maps vector rowid → cell location for O(1) delete |
|
||||||
|
| `vec_items_ivf_vectors00` | `rowid PK, vector BLOB` | Full-precision vectors (only when quantizer != none) |
|
||||||
|
|
||||||
|
### Cell Storage
|
||||||
|
|
||||||
|
Cells use packed blob storage identical to vec0's chunk layout:
|
||||||
|
- **validity**: bitmap (1 bit per slot) marking live vectors
|
||||||
|
- **rowids**: packed i64 array
|
||||||
|
- **vectors**: packed array of quantized vectors
|
||||||
|
|
||||||
|
Cells are capped at 64 vectors (~200KB at 768-dim float32, ~48KB for int8,
|
||||||
|
~6KB for binary). When a cell fills, a new row is created for the same
|
||||||
|
centroid. This avoids SQLite overflow page traversal which was a 110x
|
||||||
|
performance bottleneck with unbounded cells.
|
||||||
|
|
||||||
|
### Quantization
|
||||||
|
|
||||||
|
**int8**: Each float32 dimension clamped to [-1,1] and scaled to int8
|
||||||
|
[-127,127]. 4x storage reduction. Distance computed via int8 L2.
|
||||||
|
|
||||||
|
**binary**: Sign-bit quantization — each bit is 1 if the float is positive.
|
||||||
|
32x storage reduction. Distance computed via hamming distance.
|
||||||
|
|
||||||
|
**Oversample re-ranking**: When `oversample > 1`, the quantized scan collects
|
||||||
|
`oversample * k` candidates, then looks up each candidate's full-precision
|
||||||
|
vector from the KV table and re-computes exact distance. This recovers nearly
|
||||||
|
all recall lost from quantization. At oversample=4 with int8, recall matches
|
||||||
|
non-quantized IVF exactly.
|
||||||
|
|
||||||
|
### K-Means
|
||||||
|
|
||||||
|
Uses Lloyd's algorithm with k-means++ initialization:
|
||||||
|
1. K-means++ picks initial centroids weighted by distance
|
||||||
|
2. Lloyd's iterations: assign vectors to nearest centroid, recompute centroids as cluster means
|
||||||
|
3. Empty cluster handling: reassign to farthest point
|
||||||
|
4. K-means runs in float32; centroids are quantized before storage
|
||||||
|
|
||||||
|
Training data: recommend 16× nlist vectors. At nlist=1000, that's 16k
|
||||||
|
vectors — k-means takes ~140s on 768-dim data.
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
### 100k vectors (COHERE 768-dim cosine)
|
||||||
|
|
||||||
|
```
|
||||||
|
name qry(ms) recall
|
||||||
|
───────────────────────────────────────────────
|
||||||
|
ivf(q=int8,os=4),p=8 5.3ms 0.934 ← 6x faster than flat
|
||||||
|
ivf(q=int8,os=4),p=16 5.4ms 0.968
|
||||||
|
ivf(q=none),p=8 5.3ms 0.934
|
||||||
|
ivf(q=binary,os=10),p=16 1.3ms 0.832 ← 26x faster than flat
|
||||||
|
ivf(q=int8,os=4),p=32 7.4ms 0.990
|
||||||
|
ivf(q=none),p=32 15.5ms 0.992
|
||||||
|
int8(os=4) 18.7ms 0.996
|
||||||
|
bit(os=8) 18.7ms 0.884
|
||||||
|
flat 33.7ms 1.000
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1M vectors (COHERE 768-dim cosine)
|
||||||
|
|
||||||
|
```
|
||||||
|
name insert train MB qry(ms) recall
|
||||||
|
──────────────────────────────────────────────────────────────────────
|
||||||
|
ivf(q=int8,os=4),p=8 163s 142s 4725 16.3ms 0.892
|
||||||
|
ivf(q=binary,os=10),p=16 118s 144s 4073 17.7ms 0.830
|
||||||
|
ivf(q=int8,os=4),p=16 163s 142s 4725 24.3ms 0.950
|
||||||
|
ivf(q=int8,os=4),p=32 163s 142s 4725 41.6ms 0.980
|
||||||
|
ivf(q=none),p=8 497s 144s 3101 52.1ms 0.890
|
||||||
|
ivf(q=none),p=16 497s 144s 3101 56.6ms 0.950
|
||||||
|
bit(os=8) 18s - 3048 83.5ms 0.918
|
||||||
|
ivf(q=none),p=32 497s 144s 3101 103.9ms 0.980
|
||||||
|
int8(os=4) 19s - 3689 169.1ms 0.994
|
||||||
|
flat 20s - 2955 338.0ms 1.000
|
||||||
|
```
|
||||||
|
|
||||||
|
**Best config at 1M: `ivf(quantizer=int8, oversample=4, nprobe=16)`** —
|
||||||
|
24ms query, 0.95 recall, 14x faster than flat, 7x faster than int8 baseline.
|
||||||
|
|
||||||
|
### Scaling Characteristics
|
||||||
|
|
||||||
|
| Metric | 100k | 1M | Scaling |
|
||||||
|
|--------|------|-----|---------|
|
||||||
|
| Flat query | 34ms | 338ms | 10x (linear) |
|
||||||
|
| IVF int8 p=16 | 5.4ms | 24.3ms | 4.5x (sublinear) |
|
||||||
|
| IVF insert rate | ~10k/s | ~6k/s | Slight degradation |
|
||||||
|
| Training (nlist=1000) | 13s | 142s | ~11x |
|
||||||
|
|
||||||
|
## Implementation
|
||||||
|
|
||||||
|
### File Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
sqlite-vec-ivf-kmeans.c K-means++ algorithm (pure C, no SQLite deps)
|
||||||
|
sqlite-vec-ivf.c All IVF logic: parser, shadow tables, insert,
|
||||||
|
delete, query, centroid commands, quantization
|
||||||
|
sqlite-vec.c ~50 lines of additions: struct fields, #includes,
|
||||||
|
dispatch hooks in parse/create/insert/delete/filter
|
||||||
|
```
|
||||||
|
|
||||||
|
Both IVF files are `#include`d into `sqlite-vec.c`. No Makefile changes needed.
|
||||||
|
|
||||||
|
### Key Design Decisions
|
||||||
|
|
||||||
|
1. **Fixed-size cells (64 vectors)** instead of one blob per centroid. Avoids
|
||||||
|
SQLite overflow page traversal which caused 110x insert slowdown.
|
||||||
|
|
||||||
|
2. **Multiple cell rows per centroid** with an index on centroid_id. When a
|
||||||
|
cell fills, a new row is created. Query scans all rows for probed centroids
|
||||||
|
via `WHERE centroid_id IN (...)`.
|
||||||
|
|
||||||
|
3. **Always store full vectors** when quantizer != none (in `_ivf_vectors` KV
|
||||||
|
table). Enables oversample re-ranking and point queries returning original
|
||||||
|
precision.
|
||||||
|
|
||||||
|
4. **K-means in float32, quantize after**. Simpler than quantized k-means,
|
||||||
|
and assignment accuracy doesn't suffer much since nprobe compensates.
|
||||||
|
|
||||||
|
5. **NEON SIMD for cosine distance**. Added `cosine_float_neon()` with 4-wide
|
||||||
|
FMA for dot product + magnitudes. Benefits all vec0 queries, not just IVF.
|
||||||
|
|
||||||
|
6. **Runtime nprobe tuning**. `INSERT INTO t(id) VALUES ('nprobe=N')` changes
|
||||||
|
the probe count without rebuilding — enables fast parameter sweeps.
|
||||||
|
|
||||||
|
### Optimization History
|
||||||
|
|
||||||
|
| Optimization | Impact |
|
||||||
|
|-------------|--------|
|
||||||
|
| Fixed-size cells (64 max) | 110x insert speedup |
|
||||||
|
| Skip chunk writes for IVF | 2x DB size reduction |
|
||||||
|
| NEON cosine distance | 2x query speedup + 13% recall improvement (correct metric) |
|
||||||
|
| Cached prepared statements | Eliminated per-insert prepare/finalize |
|
||||||
|
| Batched cell reads (IN clause) | Fewer SQLite queries per KNN |
|
||||||
|
| int8 quantization | 2.5x query speedup at same recall |
|
||||||
|
| Binary quantization | 32x less cell I/O |
|
||||||
|
| Oversample re-ranking | Recovers quantization recall loss |
|
||||||
|
|
||||||
|
## Remaining Work
|
||||||
|
|
||||||
|
See `ivf-benchmarks/TODO.md` for the full list. Key items:
|
||||||
|
|
||||||
|
- **Cache centroids in memory** — each insert re-reads all centroids from SQLite
|
||||||
|
- **Runtime oversample** — same pattern as nprobe runtime command
|
||||||
|
- **SIMD k-means** — training uses scalar distance, could be 4x faster
|
||||||
|
- **Top-k heap** — replace qsort with min-heap for large nprobe
|
||||||
|
- **IVF-PQ** — product quantization for better compression/recall tradeoff
|
||||||
15
Makefile
15
Makefile
|
|
@ -206,6 +206,21 @@ test-loadable-watch:
|
||||||
test-unit:
|
test-unit:
|
||||||
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
|
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
|
||||||
|
|
||||||
|
# Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking,
|
||||||
|
# profiling (has debug symbols), and scripting without .load_extension.
|
||||||
|
# make cli
|
||||||
|
# dist/sqlite3 :memory: "SELECT vec_version()"
|
||||||
|
# dist/sqlite3 < script.sql
|
||||||
|
cli: sqlite-vec.h $(prefix)
|
||||||
|
$(CC) -O2 -g \
|
||||||
|
-DSQLITE_CORE \
|
||||||
|
-DSQLITE_EXTRA_INIT=core_init \
|
||||||
|
-DSQLITE_THREADSAFE=0 \
|
||||||
|
-Ivendor/ -I./ \
|
||||||
|
$(CFLAGS) \
|
||||||
|
vendor/sqlite3.c vendor/shell.c sqlite-vec.c examples/sqlite3-cli/core_init.c \
|
||||||
|
-ldl -lm -o $(prefix)/sqlite3
|
||||||
|
|
||||||
fuzz-build:
|
fuzz-build:
|
||||||
$(MAKE) -C tests/fuzz all
|
$(MAKE) -C tests/fuzz all
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,27 +8,20 @@ BASELINES = \
|
||||||
"brute-int8:type=baseline,variant=int8" \
|
"brute-int8:type=baseline,variant=int8" \
|
||||||
"brute-bit:type=baseline,variant=bit"
|
"brute-bit:type=baseline,variant=bit"
|
||||||
|
|
||||||
# --- Index-specific configs ---
|
# --- IVF configs ---
|
||||||
# Each index branch should add its own configs here. Example:
|
IVF_CONFIGS = \
|
||||||
#
|
"ivf-n32-p8:type=ivf,nlist=32,nprobe=8" \
|
||||||
# DISKANN_CONFIGS = \
|
"ivf-n128-p16:type=ivf,nlist=128,nprobe=16" \
|
||||||
# "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
|
"ivf-n512-p32:type=ivf,nlist=512,nprobe=32"
|
||||||
# "diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8"
|
|
||||||
#
|
|
||||||
# IVF_CONFIGS = \
|
|
||||||
# "ivf-n128-p16:type=ivf,nlist=128,nprobe=16"
|
|
||||||
#
|
|
||||||
# ANNOY_CONFIGS = \
|
|
||||||
# "annoy-t50:type=annoy,n_trees=50"
|
|
||||||
|
|
||||||
RESCORE_CONFIGS = \
|
RESCORE_CONFIGS = \
|
||||||
"rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" \
|
"rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" \
|
||||||
"rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \
|
"rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \
|
||||||
"rescore-int8-os8:type=rescore,quantizer=int8,oversample=8"
|
"rescore-int8-os8:type=rescore,quantizer=int8,oversample=8"
|
||||||
|
|
||||||
ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS)
|
ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS)
|
||||||
|
|
||||||
.PHONY: seed ground-truth bench-smoke bench-rescore bench-10k bench-50k bench-100k bench-all \
|
.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-10k bench-50k bench-100k bench-all \
|
||||||
report clean
|
report clean
|
||||||
|
|
||||||
# --- Data preparation ---
|
# --- Data preparation ---
|
||||||
|
|
@ -43,7 +36,8 @@ ground-truth: seed
|
||||||
# --- Quick smoke test ---
|
# --- Quick smoke test ---
|
||||||
bench-smoke: seed
|
bench-smoke: seed
|
||||||
$(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \
|
$(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \
|
||||||
$(BASELINES)
|
"brute-float:type=baseline,variant=float" \
|
||||||
|
"ivf-quick:type=ivf,nlist=16,nprobe=4"
|
||||||
|
|
||||||
bench-rescore: seed
|
bench-rescore: seed
|
||||||
$(BENCH) --subset-size 10000 -k 10 -o runs/rescore \
|
$(BENCH) --subset-size 10000 -k 10 -o runs/rescore \
|
||||||
|
|
@ -62,6 +56,12 @@ bench-100k: seed
|
||||||
|
|
||||||
bench-all: bench-10k bench-50k bench-100k
|
bench-all: bench-10k bench-50k bench-100k
|
||||||
|
|
||||||
|
# --- IVF across sizes ---
|
||||||
|
bench-ivf: seed
|
||||||
|
$(BENCH) --subset-size 10000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
||||||
|
$(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
||||||
|
$(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
|
||||||
|
|
||||||
# --- Report ---
|
# --- Report ---
|
||||||
report:
|
report:
|
||||||
@echo "Use: sqlite3 runs/<dir>/results.db 'SELECT * FROM bench_results ORDER BY recall DESC'"
|
@echo "Use: sqlite3 runs/<dir>/results.db 'SELECT * FROM bench_results ORDER BY recall DESC'"
|
||||||
|
|
|
||||||
|
|
@ -173,6 +173,48 @@ INDEX_REGISTRY["rescore"] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# IVF implementation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _ivf_create_table_sql(params):
|
||||||
|
return (
|
||||||
|
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||||
|
f" id integer primary key,"
|
||||||
|
f" embedding float[768] distance_metric=cosine"
|
||||||
|
f" indexed by ivf("
|
||||||
|
f" nlist={params['nlist']},"
|
||||||
|
f" nprobe={params['nprobe']}"
|
||||||
|
f" )"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ivf_post_insert_hook(conn, params):
|
||||||
|
print(" Training k-means centroids...", flush=True)
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')")
|
||||||
|
conn.commit()
|
||||||
|
elapsed = time.perf_counter() - t0
|
||||||
|
print(f" Training done in {elapsed:.1f}s", flush=True)
|
||||||
|
return elapsed
|
||||||
|
|
||||||
|
|
||||||
|
def _ivf_describe(params):
|
||||||
|
return f"ivf nlist={params['nlist']:<4} nprobe={params['nprobe']}"
|
||||||
|
|
||||||
|
|
||||||
|
INDEX_REGISTRY["ivf"] = {
|
||||||
|
"defaults": {"nlist": 128, "nprobe": 16},
|
||||||
|
"create_table_sql": _ivf_create_table_sql,
|
||||||
|
"insert_sql": None,
|
||||||
|
"post_insert_hook": _ivf_post_insert_hook,
|
||||||
|
"run_query": None,
|
||||||
|
"describe": _ivf_describe,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Config parsing
|
# Config parsing
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
|
||||||
214
sqlite-vec-ivf-kmeans.c
Normal file
214
sqlite-vec-ivf-kmeans.c
Normal file
|
|
@ -0,0 +1,214 @@
|
||||||
|
/**
|
||||||
|
* sqlite-vec-ivf-kmeans.c — Pure k-means clustering algorithm.
|
||||||
|
*
|
||||||
|
* No SQLite dependency. Operates on float arrays in memory.
|
||||||
|
* #include'd into sqlite-vec.c after struct definitions.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef SQLITE_VEC_IVF_KMEANS_C
|
||||||
|
#define SQLITE_VEC_IVF_KMEANS_C
|
||||||
|
|
||||||
|
// When opened standalone in an editor, pull in types so the LSP is happy.
|
||||||
|
// When #include'd from sqlite-vec.c, SQLITE_VEC_H is already defined.
|
||||||
|
#ifndef SQLITE_VEC_H
|
||||||
|
#include "sqlite-vec.c" // IWYU pragma: keep
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <float.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define VEC0_IVF_KMEANS_MAX_ITER 25
|
||||||
|
#define VEC0_IVF_KMEANS_DEFAULT_SEED 0
|
||||||
|
|
||||||
|
// Simple xorshift32 PRNG
|
||||||
|
static uint32_t ivf_xorshift32(uint32_t *state) {
|
||||||
|
uint32_t x = *state;
|
||||||
|
x ^= x << 13;
|
||||||
|
x ^= x >> 17;
|
||||||
|
x ^= x << 5;
|
||||||
|
*state = x;
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// L2 squared distance between two float vectors
|
||||||
|
static float ivf_l2_dist(const float *a, const float *b, int D) {
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int d = 0; d < D; d++) {
|
||||||
|
float diff = a[d] - b[d];
|
||||||
|
sum += diff * diff;
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find nearest centroid for a single vector. Returns centroid index.
|
||||||
|
static int ivf_nearest_centroid(const float *vec, const float *centroids,
|
||||||
|
int D, int k) {
|
||||||
|
float min_dist = FLT_MAX;
|
||||||
|
int best = 0;
|
||||||
|
for (int c = 0; c < k; c++) {
|
||||||
|
float dist = ivf_l2_dist(vec, ¢roids[c * D], D);
|
||||||
|
if (dist < min_dist) {
|
||||||
|
min_dist = dist;
|
||||||
|
best = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return best;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* K-means++ initialization.
|
||||||
|
* Picks k initial centroids from the data with probability proportional
|
||||||
|
* to squared distance from nearest existing centroid.
|
||||||
|
*/
|
||||||
|
static int ivf_kmeans_init_plusplus(const float *vectors, int N, int D,
|
||||||
|
int k, uint32_t seed, float *centroids) {
|
||||||
|
if (N <= 0 || k <= 0 || D <= 0)
|
||||||
|
return -1;
|
||||||
|
if (seed == 0)
|
||||||
|
seed = 42;
|
||||||
|
|
||||||
|
// Pick first centroid randomly
|
||||||
|
int first = ivf_xorshift32(&seed) % N;
|
||||||
|
memcpy(centroids, &vectors[first * D], D * sizeof(float));
|
||||||
|
|
||||||
|
if (k == 1)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
// Allocate distance array
|
||||||
|
float *dists = sqlite3_malloc64((i64)N * sizeof(float));
|
||||||
|
if (!dists)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
for (int c = 1; c < k; c++) {
|
||||||
|
// Compute D(x) = distance to nearest existing centroid
|
||||||
|
double total = 0.0;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
float d = ivf_l2_dist(&vectors[i * D], ¢roids[(c - 1) * D], D);
|
||||||
|
if (c == 1 || d < dists[i]) {
|
||||||
|
dists[i] = d;
|
||||||
|
}
|
||||||
|
total += dists[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Weighted random selection
|
||||||
|
if (total <= 0.0) {
|
||||||
|
// All distances zero — pick randomly
|
||||||
|
int pick = ivf_xorshift32(&seed) % N;
|
||||||
|
memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float));
|
||||||
|
} else {
|
||||||
|
double threshold = ((double)ivf_xorshift32(&seed) / (double)0xFFFFFFFF) * total;
|
||||||
|
double cumulative = 0.0;
|
||||||
|
int pick = N - 1;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
cumulative += dists[i];
|
||||||
|
if (cumulative >= threshold) {
|
||||||
|
pick = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_free(dists);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lloyd's k-means algorithm.
|
||||||
|
*
|
||||||
|
* @param vectors N*D float array (row-major)
|
||||||
|
* @param N number of vectors
|
||||||
|
* @param D dimensionality
|
||||||
|
* @param k number of clusters
|
||||||
|
* @param max_iter maximum iterations
|
||||||
|
* @param seed PRNG seed for initialization
|
||||||
|
* @param out_centroids output: k*D float array (caller-allocated)
|
||||||
|
* @return 0 on success, -1 on error
|
||||||
|
*/
|
||||||
|
static int ivf_kmeans(const float *vectors, int N, int D, int k,
|
||||||
|
int max_iter, uint32_t seed, float *out_centroids) {
|
||||||
|
if (N <= 0 || D <= 0 || k <= 0)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
// Clamp k to N
|
||||||
|
if (k > N)
|
||||||
|
k = N;
|
||||||
|
|
||||||
|
// Allocate working memory
|
||||||
|
int *assignments = sqlite3_malloc64((i64)N * sizeof(int));
|
||||||
|
float *new_centroids = sqlite3_malloc64((i64)k * D * sizeof(float));
|
||||||
|
int *counts = sqlite3_malloc64((i64)k * sizeof(int));
|
||||||
|
|
||||||
|
if (!assignments || !new_centroids || !counts) {
|
||||||
|
sqlite3_free(assignments);
|
||||||
|
sqlite3_free(new_centroids);
|
||||||
|
sqlite3_free(counts);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
memset(assignments, -1, N * sizeof(int));
|
||||||
|
|
||||||
|
// Initialize centroids via k-means++
|
||||||
|
if (ivf_kmeans_init_plusplus(vectors, N, D, k, seed, out_centroids) != 0) {
|
||||||
|
sqlite3_free(assignments);
|
||||||
|
sqlite3_free(new_centroids);
|
||||||
|
sqlite3_free(counts);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int iter = 0; iter < max_iter; iter++) {
|
||||||
|
// Assignment step
|
||||||
|
int changed = 0;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
int nearest = ivf_nearest_centroid(&vectors[i * D], out_centroids, D, k);
|
||||||
|
if (nearest != assignments[i]) {
|
||||||
|
assignments[i] = nearest;
|
||||||
|
changed++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (changed == 0)
|
||||||
|
break;
|
||||||
|
|
||||||
|
// Update step
|
||||||
|
memset(new_centroids, 0, (size_t)k * D * sizeof(float));
|
||||||
|
memset(counts, 0, k * sizeof(int));
|
||||||
|
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
int c = assignments[i];
|
||||||
|
counts[c]++;
|
||||||
|
for (int d = 0; d < D; d++) {
|
||||||
|
new_centroids[c * D + d] += vectors[i * D + d];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int c = 0; c < k; c++) {
|
||||||
|
if (counts[c] == 0) {
|
||||||
|
// Empty cluster: reassign to farthest point from its nearest centroid
|
||||||
|
float max_dist = -1.0f;
|
||||||
|
int farthest = 0;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
float d = ivf_l2_dist(&vectors[i * D],
|
||||||
|
&out_centroids[assignments[i] * D], D);
|
||||||
|
if (d > max_dist) {
|
||||||
|
max_dist = d;
|
||||||
|
farthest = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
memcpy(&out_centroids[c * D], &vectors[farthest * D],
|
||||||
|
D * sizeof(float));
|
||||||
|
} else {
|
||||||
|
for (int d = 0; d < D; d++) {
|
||||||
|
out_centroids[c * D + d] = new_centroids[c * D + d] / counts[c];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_free(assignments);
|
||||||
|
sqlite3_free(new_centroids);
|
||||||
|
sqlite3_free(counts);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif /* SQLITE_VEC_IVF_KMEANS_C */
|
||||||
1445
sqlite-vec-ivf.c
Normal file
1445
sqlite-vec-ivf.c
Normal file
File diff suppressed because it is too large
Load diff
231
sqlite-vec.c
231
sqlite-vec.c
|
|
@ -93,6 +93,10 @@ typedef size_t usize;
|
||||||
#define COMPILER_SUPPORTS_VTAB_IN 1
|
#define COMPILER_SUPPORTS_VTAB_IN 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifndef SQLITE_VEC_ENABLE_IVF
|
||||||
|
#define SQLITE_VEC_ENABLE_IVF 1
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifndef SQLITE_SUBTYPE
|
#ifndef SQLITE_SUBTYPE
|
||||||
#define SQLITE_SUBTYPE 0x000100000
|
#define SQLITE_SUBTYPE 0x000100000
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -2539,6 +2543,7 @@ enum Vec0IndexType {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
VEC0_INDEX_TYPE_RESCORE = 2,
|
VEC0_INDEX_TYPE_RESCORE = 2,
|
||||||
#endif
|
#endif
|
||||||
|
VEC0_INDEX_TYPE_IVF = 3,
|
||||||
};
|
};
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
|
|
@ -2553,6 +2558,22 @@ struct Vec0RescoreConfig {
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
enum Vec0IvfQuantizer {
|
||||||
|
VEC0_IVF_QUANTIZER_NONE = 0,
|
||||||
|
VEC0_IVF_QUANTIZER_INT8 = 1,
|
||||||
|
VEC0_IVF_QUANTIZER_BINARY = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Vec0IvfConfig {
|
||||||
|
int nlist; // number of centroids (0 = deferred)
|
||||||
|
int nprobe; // cells to probe at query time
|
||||||
|
int quantizer; // VEC0_IVF_QUANTIZER_NONE / INT8 / BINARY
|
||||||
|
int oversample; // >= 1 (1 = no oversampling)
|
||||||
|
};
|
||||||
|
#else
|
||||||
|
struct Vec0IvfConfig { char _unused; };
|
||||||
|
#endif
|
||||||
|
|
||||||
struct VectorColumnDefinition {
|
struct VectorColumnDefinition {
|
||||||
char *name;
|
char *name;
|
||||||
|
|
@ -2564,6 +2585,7 @@ struct VectorColumnDefinition {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
struct Vec0RescoreConfig rescore;
|
struct Vec0RescoreConfig rescore;
|
||||||
#endif
|
#endif
|
||||||
|
struct Vec0IvfConfig ivf;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Vec0PartitionColumnDefinition {
|
struct Vec0PartitionColumnDefinition {
|
||||||
|
|
@ -2715,6 +2737,12 @@ static int vec0_parse_rescore_options(struct Vec0Scanner *scanner,
|
||||||
* @return int SQLITE_OK on success, SQLITE_EMPTY is it's not a vector column
|
* @return int SQLITE_OK on success, SQLITE_EMPTY is it's not a vector column
|
||||||
* definition, SQLITE_ERROR on error.
|
* definition, SQLITE_ERROR on error.
|
||||||
*/
|
*/
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// Forward declaration — defined in sqlite-vec-ivf.c
|
||||||
|
static int vec0_parse_ivf_options(struct Vec0Scanner *scanner,
|
||||||
|
struct Vec0IvfConfig *config);
|
||||||
|
#endif
|
||||||
|
|
||||||
int vec0_parse_vector_column(const char *source, int source_length,
|
int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
struct VectorColumnDefinition *outColumn) {
|
struct VectorColumnDefinition *outColumn) {
|
||||||
// parses a vector column definition like so:
|
// parses a vector column definition like so:
|
||||||
|
|
@ -2733,6 +2761,8 @@ int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
struct Vec0RescoreConfig rescoreConfig;
|
struct Vec0RescoreConfig rescoreConfig;
|
||||||
memset(&rescoreConfig, 0, sizeof(rescoreConfig));
|
memset(&rescoreConfig, 0, sizeof(rescoreConfig));
|
||||||
#endif
|
#endif
|
||||||
|
struct Vec0IvfConfig ivfConfig;
|
||||||
|
memset(&ivfConfig, 0, sizeof(ivfConfig));
|
||||||
int dimensions;
|
int dimensions;
|
||||||
|
|
||||||
vec0_scanner_init(&scanner, source, source_length);
|
vec0_scanner_init(&scanner, source, source_length);
|
||||||
|
|
@ -2891,7 +2921,18 @@ int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
else {
|
else if (sqlite3_strnicmp(token.start, "ivf", indexNameLen) == 0) {
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
indexType = VEC0_INDEX_TYPE_IVF;
|
||||||
|
memset(&ivfConfig, 0, sizeof(ivfConfig));
|
||||||
|
rc = vec0_parse_ivf_options(&scanner, &ivfConfig);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
return SQLITE_ERROR; // IVF not compiled in
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
// unknown index type
|
// unknown index type
|
||||||
return SQLITE_ERROR;
|
return SQLITE_ERROR;
|
||||||
}
|
}
|
||||||
|
|
@ -2914,6 +2955,7 @@ int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
outColumn->rescore = rescoreConfig;
|
outColumn->rescore = rescoreConfig;
|
||||||
#endif
|
#endif
|
||||||
|
outColumn->ivf = ivfConfig;
|
||||||
return SQLITE_OK;
|
return SQLITE_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3279,6 +3321,18 @@ struct vec0_vtab {
|
||||||
|
|
||||||
int chunk_size;
|
int chunk_size;
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// IVF cached state per vector column
|
||||||
|
char *shadowIvfCellsNames[VEC0_MAX_VECTOR_COLUMNS]; // table name for blob_open
|
||||||
|
int ivfTrainedCache[VEC0_MAX_VECTOR_COLUMNS]; // -1=unknown, 0=no, 1=yes
|
||||||
|
sqlite3_stmt *stmtIvfCellMeta[VEC0_MAX_VECTOR_COLUMNS]; // SELECT n_vectors, length(validity)*8 FROM cells WHERE cell_id=?
|
||||||
|
sqlite3_stmt *stmtIvfCellUpdateN[VEC0_MAX_VECTOR_COLUMNS]; // UPDATE cells SET n_vectors=n_vectors+? WHERE cell_id=?
|
||||||
|
sqlite3_stmt *stmtIvfRowidMapInsert[VEC0_MAX_VECTOR_COLUMNS]; // INSERT INTO rowid_map(rowid,cell_id,slot) VALUES(?,?,?)
|
||||||
|
sqlite3_stmt *stmtIvfRowidMapLookup[VEC0_MAX_VECTOR_COLUMNS]; // SELECT cell_id,slot FROM rowid_map WHERE rowid=?
|
||||||
|
sqlite3_stmt *stmtIvfRowidMapDelete[VEC0_MAX_VECTOR_COLUMNS]; // DELETE FROM rowid_map WHERE rowid=?
|
||||||
|
sqlite3_stmt *stmtIvfCentroidsAll[VEC0_MAX_VECTOR_COLUMNS]; // SELECT centroid_id,centroid FROM centroids
|
||||||
|
#endif
|
||||||
|
|
||||||
// select latest chunk from _chunks, getting chunk_id
|
// select latest chunk from _chunks, getting chunk_id
|
||||||
sqlite3_stmt *stmtLatestChunk;
|
sqlite3_stmt *stmtLatestChunk;
|
||||||
|
|
||||||
|
|
@ -3364,6 +3418,17 @@ void vec0_free_resources(vec0_vtab *p) {
|
||||||
p->stmtRowidsUpdatePosition = NULL;
|
p->stmtRowidsUpdatePosition = NULL;
|
||||||
sqlite3_finalize(p->stmtRowidsGetChunkPosition);
|
sqlite3_finalize(p->stmtRowidsGetChunkPosition);
|
||||||
p->stmtRowidsGetChunkPosition = NULL;
|
p->stmtRowidsGetChunkPosition = NULL;
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
for (int i = 0; i < VEC0_MAX_VECTOR_COLUMNS; i++) {
|
||||||
|
sqlite3_finalize(p->stmtIvfCellMeta[i]); p->stmtIvfCellMeta[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtIvfCellUpdateN[i]); p->stmtIvfCellUpdateN[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtIvfRowidMapInsert[i]); p->stmtIvfRowidMapInsert[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtIvfRowidMapLookup[i]); p->stmtIvfRowidMapLookup[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtIvfRowidMapDelete[i]); p->stmtIvfRowidMapDelete[i] = NULL;
|
||||||
|
sqlite3_finalize(p->stmtIvfCentroidsAll[i]); p->stmtIvfCentroidsAll[i] = NULL;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -3386,6 +3451,10 @@ void vec0_free(vec0_vtab *p) {
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
sqlite3_free(p->shadowVectorChunksNames[i]);
|
sqlite3_free(p->shadowVectorChunksNames[i]);
|
||||||
p->shadowVectorChunksNames[i] = NULL;
|
p->shadowVectorChunksNames[i] = NULL;
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
sqlite3_free(p->shadowIvfCellsNames[i]);
|
||||||
|
p->shadowIvfCellsNames[i] = NULL;
|
||||||
|
#endif
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
sqlite3_free(p->shadowRescoreChunksNames[i]);
|
sqlite3_free(p->shadowRescoreChunksNames[i]);
|
||||||
|
|
@ -3674,12 +3743,25 @@ int vec0_result_id(vec0_vtab *p, sqlite3_context *context, i64 rowid) {
|
||||||
* will be stored.
|
* will be stored.
|
||||||
* @return int SQLITE_OK on success.
|
* @return int SQLITE_OK on success.
|
||||||
*/
|
*/
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// Forward declaration — defined in sqlite-vec-ivf.c (included later)
|
||||||
|
static int ivf_get_vector_data(vec0_vtab *p, i64 rowid, int col_idx,
|
||||||
|
void **outVector, int *outVectorSize);
|
||||||
|
#endif
|
||||||
|
|
||||||
int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
|
int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
|
||||||
void **outVector, int *outVectorSize) {
|
void **outVector, int *outVectorSize) {
|
||||||
vec0_vtab *p = pVtab;
|
vec0_vtab *p = pVtab;
|
||||||
int rc, brc;
|
int rc, brc;
|
||||||
i64 chunk_id;
|
i64 chunk_id;
|
||||||
i64 chunk_offset;
|
i64 chunk_offset;
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// IVF-indexed columns store vectors in _ivf_cells, not _vector_chunks
|
||||||
|
if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_IVF) {
|
||||||
|
return ivf_get_vector_data(p, rowid, vector_column_idx, outVector, outVectorSize);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
size_t size;
|
size_t size;
|
||||||
void *buf = NULL;
|
void *buf = NULL;
|
||||||
int blobOffset;
|
int blobOffset;
|
||||||
|
|
@ -4327,8 +4409,12 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_value ** partitionKeyValues, i64 *chunk
|
||||||
int vector_column_idx = p->user_column_idxs[i];
|
int vector_column_idx = p->user_column_idxs[i];
|
||||||
|
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
// Rescore columns don't use _vector_chunks for float storage
|
// Rescore and IVF columns don't use _vector_chunks for float storage
|
||||||
if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE) {
|
if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
|| p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_IVF
|
||||||
|
#endif
|
||||||
|
) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -4500,6 +4586,12 @@ void vec0_cursor_clear(vec0_cursor *pCur) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IVF index implementation — #include'd here after all struct/helper definitions
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
#include "sqlite-vec-ivf-kmeans.c"
|
||||||
|
#include "sqlite-vec-ivf.c"
|
||||||
|
#endif
|
||||||
|
|
||||||
#define VEC_CONSTRUCTOR_ERROR "vec0 constructor error: "
|
#define VEC_CONSTRUCTOR_ERROR "vec0 constructor error: "
|
||||||
static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
sqlite3_vtab **ppVtab, char **pzErr, bool isCreate) {
|
sqlite3_vtab **ppVtab, char **pzErr, bool isCreate) {
|
||||||
|
|
@ -4761,6 +4853,34 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// IVF indexes do not support auxiliary, metadata, or partition key columns.
|
||||||
|
{
|
||||||
|
int has_ivf = 0;
|
||||||
|
for (int i = 0; i < numVectorColumns; i++) {
|
||||||
|
if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF) {
|
||||||
|
has_ivf = 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (has_ivf) {
|
||||||
|
if (numPartitionColumns > 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
||||||
|
"partition key columns are not supported with IVF indexes");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
if (numAuxiliaryColumns > 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
||||||
|
"auxiliary columns are not supported with IVF indexes");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
if (numMetadataColumns > 0) {
|
||||||
|
*pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR
|
||||||
|
"metadata columns are not supported with IVF indexes");
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sqlite3_str *createStr = sqlite3_str_new(NULL);
|
sqlite3_str *createStr = sqlite3_str_new(NULL);
|
||||||
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
|
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
|
||||||
if (pkColumnName) {
|
if (pkColumnName) {
|
||||||
|
|
@ -4866,6 +4986,15 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
||||||
|
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
|
||||||
|
pNew->shadowIvfCellsNames[i] =
|
||||||
|
sqlite3_mprintf("%s_ivf_cells%02d", tableName, i);
|
||||||
|
if (!pNew->shadowIvfCellsNames[i]) goto error;
|
||||||
|
pNew->ivfTrainedCache[i] = -1; // unknown
|
||||||
|
}
|
||||||
|
#endif
|
||||||
for (int i = 0; i < pNew->numMetadataColumns; i++) {
|
for (int i = 0; i < pNew->numMetadataColumns; i++) {
|
||||||
pNew->shadowMetadataChunksNames[i] =
|
pNew->shadowMetadataChunksNames[i] =
|
||||||
sqlite3_mprintf("%s_metadatachunks%02d", tableName, i);
|
sqlite3_mprintf("%s_metadatachunks%02d", tableName, i);
|
||||||
|
|
@ -4989,8 +5118,8 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
|
|
||||||
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
// Rescore columns don't use _vector_chunks
|
// Rescore and IVF columns don't use _vector_chunks
|
||||||
if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
|
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
#endif
|
||||||
char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE,
|
char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE,
|
||||||
|
|
@ -5018,6 +5147,18 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// Create IVF shadow tables for IVF-indexed vector columns
|
||||||
|
for (int i = 0; i < pNew->numVectorColumns; i++) {
|
||||||
|
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
|
||||||
|
rc = ivf_create_shadow_tables(pNew, i);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
*pzErr = sqlite3_mprintf("Could not create IVF shadow tables for column %d", i);
|
||||||
|
goto error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY"
|
// See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY"
|
||||||
// without INTEGER type issue applies here.
|
// without INTEGER type issue applies here.
|
||||||
for (int i = 0; i < pNew->numMetadataColumns; i++) {
|
for (int i = 0; i < pNew->numMetadataColumns; i++) {
|
||||||
|
|
@ -5153,7 +5294,7 @@ static int vec0Destroy(sqlite3_vtab *pVtab) {
|
||||||
|
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
#endif
|
||||||
zSql = sqlite3_mprintf("DROP TABLE \"%w\".\"%w\"", p->schemaName,
|
zSql = sqlite3_mprintf("DROP TABLE \"%w\".\"%w\"", p->schemaName,
|
||||||
|
|
@ -5174,6 +5315,14 @@ static int vec0Destroy(sqlite3_vtab *pVtab) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// Drop IVF shadow tables
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
|
||||||
|
ivf_drop_shadow_tables(p, i);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if(p->numAuxiliaryColumns > 0) {
|
if(p->numAuxiliaryColumns > 0) {
|
||||||
zSql = sqlite3_mprintf("DROP TABLE " VEC0_SHADOW_AUXILIARY_NAME, p->schemaName, p->tableName);
|
zSql = sqlite3_mprintf("DROP TABLE " VEC0_SHADOW_AUXILIARY_NAME, p->schemaName, p->tableName);
|
||||||
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
|
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
|
||||||
|
|
@ -7186,6 +7335,21 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// IVF dispatch: if vector column has IVF, use IVF query instead of chunk scan
|
||||||
|
if (vector_column->index_type == VEC0_INDEX_TYPE_IVF) {
|
||||||
|
rc = ivf_query_knn(p, vectorColumnIdx, queryVector,
|
||||||
|
(int)vector_column_byte_size(*vector_column), k, knn_data);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
pCur->knn_data = knn_data;
|
||||||
|
pCur->query_plan = VEC0_QUERY_PLAN_KNN;
|
||||||
|
rc = SQLITE_OK;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks);
|
rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks);
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
// IMP: V06942_23781
|
// IMP: V06942_23781
|
||||||
|
|
@ -8011,8 +8175,12 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid,
|
||||||
// Go insert the vector data into the vector chunk shadow tables
|
// Go insert the vector data into the vector chunk shadow tables
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
// Rescore columns store float vectors in _rescore_vectors instead
|
// Rescore and IVF columns don't use _vector_chunks
|
||||||
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
|
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
|| p->vector_columns[i].index_type == VEC0_INDEX_TYPE_IVF
|
||||||
|
#endif
|
||||||
|
)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
@ -8425,6 +8593,18 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// Step #4: IVF index insert (if any vector column uses IVF)
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
|
||||||
|
int vecSize = (int)vector_column_byte_size(p->vector_columns[i]);
|
||||||
|
rc = ivf_insert(p, i, rowid, vectorDatas[i], vecSize);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if(p->numAuxiliaryColumns > 0) {
|
if(p->numAuxiliaryColumns > 0) {
|
||||||
sqlite3_stmt *stmt;
|
sqlite3_stmt *stmt;
|
||||||
sqlite3_str * s = sqlite3_str_new(NULL);
|
sqlite3_str * s = sqlite3_str_new(NULL);
|
||||||
|
|
@ -8616,8 +8796,8 @@ int vec0Update_Delete_ClearVectors(vec0_vtab *p, i64 chunk_id,
|
||||||
int rc, brc;
|
int rc, brc;
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
// Rescore columns don't use _vector_chunks
|
// Non-FLAT columns don't use _vector_chunks
|
||||||
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
#endif
|
||||||
sqlite3_blob *blobVectors = NULL;
|
sqlite3_blob *blobVectors = NULL;
|
||||||
|
|
@ -8732,7 +8912,7 @@ int vec0Update_Delete_DeleteChunkIfEmpty(vec0_vtab *p, i64 chunk_id,
|
||||||
// Delete from each _vector_chunksNN
|
// Delete from each _vector_chunksNN
|
||||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
#if SQLITE_VEC_ENABLE_RESCORE
|
#if SQLITE_VEC_ENABLE_RESCORE
|
||||||
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE)
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
|
||||||
continue;
|
continue;
|
||||||
#endif
|
#endif
|
||||||
zSql = sqlite3_mprintf(
|
zSql = sqlite3_mprintf(
|
||||||
|
|
@ -9009,6 +9189,15 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// 7. delete from IVF index
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_IVF) continue;
|
||||||
|
rc = ivf_delete(p, i, rowid);
|
||||||
|
if (rc != SQLITE_OK) return rc;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
return SQLITE_OK;
|
return SQLITE_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -9284,6 +9473,18 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
|
||||||
}
|
}
|
||||||
// INSERT operation
|
// INSERT operation
|
||||||
else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) {
|
else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) {
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// Check for IVF command inserts: INSERT INTO t(rowid) VALUES ('compute-centroids')
|
||||||
|
// The id column holds the command string.
|
||||||
|
sqlite3_value *idVal = argv[2 + VEC0_COLUMN_ID];
|
||||||
|
if (sqlite3_value_type(idVal) == SQLITE_TEXT) {
|
||||||
|
const char *cmd = (const char *)sqlite3_value_text(idVal);
|
||||||
|
vec0_vtab *p = (vec0_vtab *)pVTab;
|
||||||
|
int cmdRc = ivf_handle_command(p, cmd, argc, argv);
|
||||||
|
if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error)
|
||||||
|
// SQLITE_EMPTY means not an IVF command — fall through to normal insert
|
||||||
|
}
|
||||||
|
#endif
|
||||||
return vec0Update_Insert(pVTab, argc, argv, pRowid);
|
return vec0Update_Insert(pVTab, argc, argv, pRowid);
|
||||||
}
|
}
|
||||||
// UPDATE operation
|
// UPDATE operation
|
||||||
|
|
@ -9431,9 +9632,15 @@ static sqlite3_module vec0Module = {
|
||||||
#define SQLITE_VEC_DEBUG_BUILD_RESCORE ""
|
#define SQLITE_VEC_DEBUG_BUILD_RESCORE ""
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
#define SQLITE_VEC_DEBUG_BUILD_IVF "ivf"
|
||||||
|
#else
|
||||||
|
#define SQLITE_VEC_DEBUG_BUILD_IVF ""
|
||||||
|
#endif
|
||||||
|
|
||||||
#define SQLITE_VEC_DEBUG_BUILD \
|
#define SQLITE_VEC_DEBUG_BUILD \
|
||||||
SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \
|
SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \
|
||||||
SQLITE_VEC_DEBUG_BUILD_RESCORE
|
SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF
|
||||||
|
|
||||||
#define SQLITE_VEC_DEBUG_STRING \
|
#define SQLITE_VEC_DEBUG_STRING \
|
||||||
"Version: " SQLITE_VEC_VERSION "\n" \
|
"Version: " SQLITE_VEC_VERSION "\n" \
|
||||||
|
|
|
||||||
|
|
@ -93,13 +93,40 @@ $(TARGET_DIR)/rescore_quantize_edge: rescore-quantize-edge.c $(FUZZ_SRCS) | $(TA
|
||||||
$(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
$(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/ivf_create: ivf-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/ivf_operations: ivf-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/ivf_quantize: ivf-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/ivf_kmeans: ivf-kmeans.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/ivf_shadow_corrupt: ivf-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/ivf_knn_deep: ivf-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/ivf_cell_overflow: ivf-cell-overflow.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
$(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
FUZZ_TARGETS = vec0_create exec json numpy \
|
FUZZ_TARGETS = vec0_create exec json numpy \
|
||||||
shadow_corrupt vec0_operations scalar_functions \
|
shadow_corrupt vec0_operations scalar_functions \
|
||||||
vec0_create_full metadata_columns vec_each vec_mismatch \
|
vec0_create_full metadata_columns vec_each vec_mismatch \
|
||||||
vec0_delete_completeness \
|
vec0_delete_completeness \
|
||||||
rescore_operations rescore_create rescore_quantize \
|
rescore_operations rescore_create rescore_quantize \
|
||||||
rescore_shadow_corrupt rescore_knn_deep \
|
rescore_shadow_corrupt rescore_knn_deep \
|
||||||
rescore_quantize_edge rescore_interleave
|
rescore_quantize_edge rescore_interleave \
|
||||||
|
ivf_create ivf_operations \
|
||||||
|
ivf_quantize ivf_kmeans ivf_shadow_corrupt \
|
||||||
|
ivf_knn_deep ivf_cell_overflow ivf_rescore
|
||||||
|
|
||||||
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
|
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
|
||||||
|
|
||||||
|
|
|
||||||
192
tests/fuzz/ivf-cell-overflow.c
Normal file
192
tests/fuzz/ivf-cell-overflow.c
Normal file
|
|
@ -0,0 +1,192 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF cell overflow and boundary conditions.
|
||||||
|
*
|
||||||
|
* Pushes cells past VEC0_IVF_CELL_MAX_VECTORS (64) to trigger cell
|
||||||
|
* splitting, then exercises blob I/O at slot boundaries.
|
||||||
|
*
|
||||||
|
* Targets:
|
||||||
|
* - Cell splitting when n_vectors reaches cap (64)
|
||||||
|
* - Blob offset arithmetic: slot * vecSize, slot / 8, slot % 8
|
||||||
|
* - Validity bitmap at byte boundaries (slot 7->8, 15->16, etc.)
|
||||||
|
* - Insert into full cell -> create new cell path
|
||||||
|
* - Delete from various slot positions (first, last, middle)
|
||||||
|
* - Multiple cells per centroid
|
||||||
|
* - assign-vectors command with multi-cell centroids
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 8) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Use small dimensions for speed but enough vectors to overflow cells
|
||||||
|
int dim = (data[0] % 8) + 2; // 2..9
|
||||||
|
int nlist = (data[1] % 4) + 1; // 1..4
|
||||||
|
// We need >64 vectors to overflow a cell
|
||||||
|
int num_vecs = (data[2] % 64) + 65; // 65..128
|
||||||
|
int delete_pattern = data[3]; // Controls which vectors to delete
|
||||||
|
|
||||||
|
const uint8_t *payload = data + 4;
|
||||||
|
size_t payload_size = size - 4;
|
||||||
|
|
||||||
|
char sql[256];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))",
|
||||||
|
dim, nlist, nlist);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert enough vectors to overflow at least one cell
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f;
|
||||||
|
} else {
|
||||||
|
// Cluster vectors near specific centroids to ensure some cells overflow
|
||||||
|
int cluster = i % nlist;
|
||||||
|
vec[d] = (float)cluster + (float)(i % 10) * 0.01f + d * 0.001f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Train to assign vectors to centroids (triggers cell building)
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Delete vectors at boundary positions based on fuzz data
|
||||||
|
// This tests validity bitmap manipulation at different slot positions
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
int byte_idx = i / 8;
|
||||||
|
if (byte_idx < (int)payload_size && (payload[byte_idx] & (1 << (i % 8)))) {
|
||||||
|
// Use delete_pattern to thin deletions
|
||||||
|
if ((delete_pattern + i) % 3 == 0) {
|
||||||
|
char delsql[64];
|
||||||
|
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i + 1);
|
||||||
|
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert more vectors after deletions (into cells with holes)
|
||||||
|
{
|
||||||
|
sqlite3_stmt *si = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
|
if (si) {
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++)
|
||||||
|
vec[d] = (float)(i + 200) * 0.01f;
|
||||||
|
sqlite3_reset(si);
|
||||||
|
sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1));
|
||||||
|
sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(si);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(si);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// KNN query that must scan multiple cells per centroid
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 0.0f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 20");
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test assign-vectors with multi-cell state
|
||||||
|
// First clear centroids
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Set centroids manually, then assign
|
||||||
|
for (int c = 0; c < nlist; c++) {
|
||||||
|
float *cvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!cvec) break;
|
||||||
|
for (int d = 0; d < dim; d++) cvec[d] = (float)c + d * 0.1f;
|
||||||
|
|
||||||
|
char cmd[128];
|
||||||
|
snprintf(cmd, sizeof(cmd),
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES ('set-centroid:%d', ?)", c);
|
||||||
|
sqlite3_stmt *sc = NULL;
|
||||||
|
sqlite3_prepare_v2(db, cmd, -1, &sc, NULL);
|
||||||
|
if (sc) {
|
||||||
|
sqlite3_bind_blob(sc, 1, cvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(sc);
|
||||||
|
sqlite3_finalize(sc);
|
||||||
|
}
|
||||||
|
sqlite3_free(cvec);
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('assign-vectors')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Final query after assign-vectors
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 1.0f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||||
|
-1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full scan
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
36
tests/fuzz/ivf-create.c
Normal file
36
tests/fuzz/ivf-create.c
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||||
|
assert(s);
|
||||||
|
sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(");
|
||||||
|
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||||
|
sqlite3_str_appendall(s, "))");
|
||||||
|
const char *zSql = sqlite3_str_finish(s);
|
||||||
|
assert(zSql);
|
||||||
|
|
||||||
|
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free((void *)zSql);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
16
tests/fuzz/ivf-create.dict
Normal file
16
tests/fuzz/ivf-create.dict
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
"nlist"
|
||||||
|
"nprobe"
|
||||||
|
"quantizer"
|
||||||
|
"oversample"
|
||||||
|
"binary"
|
||||||
|
"int8"
|
||||||
|
"none"
|
||||||
|
"="
|
||||||
|
","
|
||||||
|
"("
|
||||||
|
")"
|
||||||
|
"0"
|
||||||
|
"1"
|
||||||
|
"128"
|
||||||
|
"65536"
|
||||||
|
"65537"
|
||||||
180
tests/fuzz/ivf-kmeans.c
Normal file
180
tests/fuzz/ivf-kmeans.c
Normal file
|
|
@ -0,0 +1,180 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF k-means clustering.
|
||||||
|
*
|
||||||
|
* Builds a table, inserts fuzz-controlled vectors, then runs
|
||||||
|
* compute-centroids with fuzz-controlled parameters (nlist, max_iter, seed).
|
||||||
|
* Targets:
|
||||||
|
* - kmeans with N < k (clamping), N == 1, k == 1
|
||||||
|
* - kmeans with duplicate/identical vectors (all distances zero)
|
||||||
|
* - kmeans with NaN/Inf vectors
|
||||||
|
* - Empty cluster reassignment path (farthest-point heuristic)
|
||||||
|
* - Large nlist relative to N
|
||||||
|
* - The compute-centroids:{json} command parsing
|
||||||
|
* - clear-centroids followed by compute-centroids (round-trip)
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 10) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Parse fuzz header
|
||||||
|
// Byte 0-1: dimension (1..128)
|
||||||
|
// Byte 2: nlist for CREATE (1..64)
|
||||||
|
// Byte 3: nlist override for compute-centroids (0 = use default)
|
||||||
|
// Byte 4: max_iter (1..50)
|
||||||
|
// Byte 5-8: seed
|
||||||
|
// Byte 9: num_vectors (1..64)
|
||||||
|
// Remaining: vector float data
|
||||||
|
|
||||||
|
int dim = (data[0] | (data[1] << 8)) % 128 + 1;
|
||||||
|
int nlist_create = (data[2] % 64) + 1;
|
||||||
|
int nlist_override = data[3] % 65; // 0 means use table default
|
||||||
|
int max_iter = (data[4] % 50) + 1;
|
||||||
|
uint32_t seed = (uint32_t)data[5] | ((uint32_t)data[6] << 8) |
|
||||||
|
((uint32_t)data[7] << 16) | ((uint32_t)data[8] << 24);
|
||||||
|
int num_vecs = (data[9] % 64) + 1;
|
||||||
|
|
||||||
|
const uint8_t *payload = data + 10;
|
||||||
|
size_t payload_size = size - 10;
|
||||||
|
|
||||||
|
char sql[256];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))",
|
||||||
|
dim, nlist_create, nlist_create);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert vectors
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset + 4 <= payload_size) {
|
||||||
|
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||||
|
offset += 4;
|
||||||
|
} else if (offset < payload_size) {
|
||||||
|
// Scale to interesting range including values > 1, < -1
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 5.0f;
|
||||||
|
} else {
|
||||||
|
// Reuse earlier bytes to fill remaining dimensions
|
||||||
|
vec[d] = (float)(i * dim + d) * 0.01f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Exercise compute-centroids with JSON options
|
||||||
|
{
|
||||||
|
char cmd[256];
|
||||||
|
snprintf(cmd, sizeof(cmd),
|
||||||
|
"INSERT INTO v(rowid) VALUES "
|
||||||
|
"('compute-centroids:{\"nlist\":%d,\"max_iterations\":%d,\"seed\":%u}')",
|
||||||
|
nlist_override, max_iter, seed);
|
||||||
|
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
// KNN query after training
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
qvec[d] = (d < 3) ? 1.0f : 0.0f;
|
||||||
|
}
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
if (stmtKnn) {
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear centroids and re-compute to test round-trip
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Insert a few more vectors in untrained state
|
||||||
|
{
|
||||||
|
sqlite3_stmt *si = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
|
if (si) {
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++) vec[d] = (float)(i + 100) * 0.1f;
|
||||||
|
sqlite3_reset(si);
|
||||||
|
sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1));
|
||||||
|
sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(si);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(si);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-train
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Delete some rows after training, then query
|
||||||
|
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||||
|
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 2", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Query after deletes
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 10",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
if (stmtKnn) {
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
199
tests/fuzz/ivf-knn-deep.c
Normal file
199
tests/fuzz/ivf-knn-deep.c
Normal file
|
|
@ -0,0 +1,199 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF KNN search deep paths.
|
||||||
|
*
|
||||||
|
* Exercises the full KNN pipeline with fuzz-controlled:
|
||||||
|
* - nprobe values (including > nlist, =1, =nlist)
|
||||||
|
* - Query vectors (including adversarial floats)
|
||||||
|
* - Mix of trained/untrained state
|
||||||
|
* - Oversample + rescore path (quantizer=int8 with oversample>1)
|
||||||
|
* - Multiple interleaved KNN queries
|
||||||
|
* - Candidate array realloc path (many vectors in probed cells)
|
||||||
|
*
|
||||||
|
* Targets:
|
||||||
|
* - ivf_scan_cells_from_stmt: candidate realloc, distance computation
|
||||||
|
* - ivf_query_knn: centroid sorting, nprobe selection
|
||||||
|
* - Oversample rescore: re-ranking with full-precision vectors
|
||||||
|
* - qsort with NaN distances
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
static uint16_t read_u16(const uint8_t *p) {
|
||||||
|
return (uint16_t)(p[0] | (p[1] << 8));
|
||||||
|
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 16) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Header
|
||||||
|
int dim = (data[0] % 32) + 2; // 2..33
|
||||||
|
int nlist = (data[1] % 16) + 1; // 1..16
|
||||||
|
int nprobe_initial = (data[2] % 20) + 1; // 1..20 (can be > nlist)
|
||||||
|
int quantizer_type = data[3] % 3; // 0=none, 1=int8, 2=binary
|
||||||
|
int oversample = (data[4] % 4) + 1; // 1..4
|
||||||
|
int num_vecs = (data[5] % 80) + 4; // 4..83
|
||||||
|
int num_queries = (data[6] % 8) + 1; // 1..8
|
||||||
|
int k_limit = (data[7] % 20) + 1; // 1..20
|
||||||
|
|
||||||
|
const uint8_t *payload = data + 8;
|
||||||
|
size_t payload_size = size - 8;
|
||||||
|
|
||||||
|
// For binary quantizer, dimension must be multiple of 8
|
||||||
|
if (quantizer_type == 2) {
|
||||||
|
dim = ((dim + 7) / 8) * 8;
|
||||||
|
if (dim == 0) dim = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *qname;
|
||||||
|
switch (quantizer_type) {
|
||||||
|
case 1: qname = "int8"; break;
|
||||||
|
case 2: qname = "binary"; break;
|
||||||
|
default: qname = "none"; break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Oversample only valid with quantization
|
||||||
|
if (quantizer_type == 0) oversample = 1;
|
||||||
|
|
||||||
|
// Cap nprobe to nlist for CREATE (parser rejects nprobe > nlist)
|
||||||
|
int nprobe_create = nprobe_initial <= nlist ? nprobe_initial : nlist;
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s%s))",
|
||||||
|
dim, nlist, nprobe_create, qname,
|
||||||
|
oversample > 1 ? ", oversample=2" : "");
|
||||||
|
|
||||||
|
// If that fails (e.g. oversample with none), try without oversample
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))",
|
||||||
|
dim, nlist, nprobe_create, qname);
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert vectors
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 20.0f;
|
||||||
|
} else {
|
||||||
|
vec[d] = (float)((i * dim + d) % 256 - 128) / 128.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Query BEFORE training (flat scan path)
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Train
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Change nprobe at runtime (can exceed nlist -- tests clamping in query)
|
||||||
|
{
|
||||||
|
char cmd[64];
|
||||||
|
snprintf(cmd, sizeof(cmd),
|
||||||
|
"INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe_initial);
|
||||||
|
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple KNN queries with different fuzz-derived query vectors
|
||||||
|
for (int q = 0; q < num_queries; q++) {
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!qvec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||||
|
} else {
|
||||||
|
qvec[d] = (q == 0) ? 1.0f : 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete half the vectors then query again
|
||||||
|
for (int i = 1; i <= num_vecs / 2; i++) {
|
||||||
|
char delsql[64];
|
||||||
|
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i);
|
||||||
|
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query after mass deletion
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = -0.5f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
121
tests/fuzz/ivf-operations.c
Normal file
121
tests/fuzz/ivf-operations.c
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 6) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_stmt *stmtDelete = NULL;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_stmt *stmtScan = NULL;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(nlist=4, nprobe=4))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 3",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i + 2 <= size) {
|
||||||
|
uint8_t op = data[i++] % 7;
|
||||||
|
uint8_t rowid_byte = data[i++];
|
||||||
|
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: {
|
||||||
|
// INSERT: consume 16 bytes for 4 floats, or use what's left
|
||||||
|
float vec[4] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
for (int j = 0; j < 4 && i < size; j++, i++) {
|
||||||
|
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
// DELETE
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
// KNN query with a fixed query vector
|
||||||
|
float qvec[4] = {1.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
// Full scan
|
||||||
|
sqlite3_reset(stmtScan);
|
||||||
|
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: {
|
||||||
|
// compute-centroids command
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: {
|
||||||
|
// clear-centroids command
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 6: {
|
||||||
|
// nprobe=N command
|
||||||
|
if (i < size) {
|
||||||
|
uint8_t n = data[i++];
|
||||||
|
int nprobe = (n % 4) + 1;
|
||||||
|
char buf[64];
|
||||||
|
snprintf(buf, sizeof(buf),
|
||||||
|
"INSERT INTO v(rowid) VALUES ('nprobe=%d')", nprobe);
|
||||||
|
sqlite3_exec(db, buf, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final operations — must not crash regardless of prior state
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_finalize(stmtScan);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
129
tests/fuzz/ivf-quantize.c
Normal file
129
tests/fuzz/ivf-quantize.c
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF quantization functions.
|
||||||
|
*
|
||||||
|
* Directly exercises ivf_quantize_int8 and ivf_quantize_binary with
|
||||||
|
* fuzz-controlled dimensions and float data. Targets:
|
||||||
|
* - ivf_quantize_int8: clamping, int8 overflow boundary
|
||||||
|
* - ivf_quantize_binary: D not divisible by 8, memset(D/8) undercount
|
||||||
|
* - Round-trip through CREATE TABLE + INSERT with quantized IVF
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 8) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Byte 0: quantizer type (0=int8, 1=binary)
|
||||||
|
// Byte 1: dimension (1..64, but we test edge cases)
|
||||||
|
// Byte 2: nlist (1..8)
|
||||||
|
// Byte 3: num_vectors to insert (1..32)
|
||||||
|
// Remaining: float data
|
||||||
|
int qtype = data[0] % 2;
|
||||||
|
int dim = (data[1] % 64) + 1;
|
||||||
|
int nlist = (data[2] % 8) + 1;
|
||||||
|
int num_vecs = (data[3] % 32) + 1;
|
||||||
|
const uint8_t *payload = data + 4;
|
||||||
|
size_t payload_size = size - 4;
|
||||||
|
|
||||||
|
// For binary quantizer, D must be multiple of 8 to avoid the D/8 bug
|
||||||
|
// in production. But we explicitly want to test non-multiples too to
|
||||||
|
// find the bug. Use dim as-is.
|
||||||
|
const char *quantizer = qtype ? "binary" : "int8";
|
||||||
|
|
||||||
|
// Binary quantizer needs D multiple of 8 in current code, but let's
|
||||||
|
// test both valid and invalid dimensions to see what happens.
|
||||||
|
// For binary with non-multiple-of-8, the code does memset(dst, 0, D/8)
|
||||||
|
// which underallocates when D%8 != 0.
|
||||||
|
char sql[256];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))",
|
||||||
|
dim, nlist, nlist, quantizer);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert vectors with fuzz-controlled float values
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs && offset < payload_size; i++) {
|
||||||
|
// Build float vector from fuzz data
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset + 4 <= payload_size) {
|
||||||
|
// Use raw bytes as float -- can produce NaN, Inf, denormals
|
||||||
|
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||||
|
offset += 4;
|
||||||
|
} else if (offset < payload_size) {
|
||||||
|
// Partial: use byte as scaled value
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f;
|
||||||
|
} else {
|
||||||
|
vec[d] = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Trigger compute-centroids to exercise kmeans + quantization together
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// KNN query with fuzz-derived query vector
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||||
|
} else {
|
||||||
|
qvec[d] = 1.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
if (stmtKnn) {
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full scan
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
182
tests/fuzz/ivf-rescore.c
Normal file
182
tests/fuzz/ivf-rescore.c
Normal file
|
|
@ -0,0 +1,182 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF oversample + rescore path.
|
||||||
|
*
|
||||||
|
* Specifically targets the code path where quantizer != none AND
|
||||||
|
* oversample > 1, which triggers:
|
||||||
|
* 1. Quantized KNN scan to collect oversample*k candidates
|
||||||
|
* 2. Full-precision vector lookup from _ivf_vectors table
|
||||||
|
* 3. Re-scoring with float32 distances
|
||||||
|
* 4. Re-sort and truncation
|
||||||
|
*
|
||||||
|
* This path has the most complex memory management in the KNN query:
|
||||||
|
* - Two separate distance computations (quantized + float)
|
||||||
|
* - Cross-table lookups (cells + vectors KV store)
|
||||||
|
* - Candidate array resizing
|
||||||
|
* - qsort over partially re-scored arrays
|
||||||
|
*
|
||||||
|
* Also tests the int8 + binary quantization round-trip fidelity
|
||||||
|
* under adversarial float inputs.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 12) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Header
|
||||||
|
int quantizer_type = (data[0] % 2) + 1; // 1=int8, 2=binary (never none)
|
||||||
|
int dim = (data[1] % 32) + 8; // 8..39
|
||||||
|
int nlist = (data[2] % 8) + 1; // 1..8
|
||||||
|
int oversample = (data[3] % 4) + 2; // 2..5 (always > 1)
|
||||||
|
int num_vecs = (data[4] % 60) + 8; // 8..67
|
||||||
|
int k_limit = (data[5] % 15) + 1; // 1..15
|
||||||
|
|
||||||
|
const uint8_t *payload = data + 6;
|
||||||
|
size_t payload_size = size - 6;
|
||||||
|
|
||||||
|
// Binary quantizer needs D multiple of 8
|
||||||
|
if (quantizer_type == 2) {
|
||||||
|
dim = ((dim + 7) / 8) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *qname = (quantizer_type == 1) ? "int8" : "binary";
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s, oversample=%d))",
|
||||||
|
dim, nlist, nlist, qname, oversample);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert vectors with diverse values
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset + 4 <= payload_size) {
|
||||||
|
// Use raw bytes as float for adversarial values
|
||||||
|
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||||
|
offset += 4;
|
||||||
|
// Sanitize: replace NaN/Inf with bounded values to avoid
|
||||||
|
// poisoning the entire computation. We want edge values,
|
||||||
|
// not complete nonsense.
|
||||||
|
if (isnan(vec[d]) || isinf(vec[d])) {
|
||||||
|
vec[d] = (vec[d] > 0) ? 1e6f : -1e6f;
|
||||||
|
if (isnan(vec[d])) vec[d] = 0.0f;
|
||||||
|
}
|
||||||
|
} else if (offset < payload_size) {
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 30.0f;
|
||||||
|
} else {
|
||||||
|
vec[d] = (float)(i * dim + d) * 0.001f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Train
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Multiple KNN queries to exercise rescore path
|
||||||
|
for (int q = 0; q < 4; q++) {
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!qvec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||||
|
} else {
|
||||||
|
qvec[d] = (q == 0) ? 1.0f : (q == 1) ? -1.0f : 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete some vectors, then query again (rescore with missing _ivf_vectors rows)
|
||||||
|
for (int i = 1; i <= num_vecs / 3; i++) {
|
||||||
|
char delsql[64];
|
||||||
|
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i);
|
||||||
|
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrain after deletions
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Query after retrain
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = -0.3f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
228
tests/fuzz/ivf-shadow-corrupt.c
Normal file
228
tests/fuzz/ivf-shadow-corrupt.c
Normal file
|
|
@ -0,0 +1,228 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF shadow table corruption.
|
||||||
|
*
|
||||||
|
* Creates a trained IVF table, then corrupts IVF shadow table blobs
|
||||||
|
* (centroids, cells validity/rowids/vectors, rowid_map) with fuzz data.
|
||||||
|
* Then exercises all read/write paths. Must not crash.
|
||||||
|
*
|
||||||
|
* Targets:
|
||||||
|
* - Cell validity bitmap with wrong size
|
||||||
|
* - Cell rowids blob with wrong size/alignment
|
||||||
|
* - Cell vectors blob with wrong size
|
||||||
|
* - Centroid blob with wrong size
|
||||||
|
* - n_vectors inconsistent with validity bitmap
|
||||||
|
* - Missing rowid_map entries
|
||||||
|
* - KNN scan over corrupted cells
|
||||||
|
* - Insert/delete with corrupted rowid_map
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 4) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Create IVF table and insert enough vectors to train
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] indexed by ivf(nlist=2, nprobe=2))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert 10 vectors
|
||||||
|
{
|
||||||
|
sqlite3_stmt *si = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
|
if (!si) { sqlite3_close(db); return 0; }
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
float vec[8];
|
||||||
|
for (int d = 0; d < 8; d++) {
|
||||||
|
vec[d] = (float)(i * 8 + d) * 0.1f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(si);
|
||||||
|
sqlite3_bind_int64(si, 1, i + 1);
|
||||||
|
sqlite3_bind_blob(si, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(si);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(si);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Train
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Now corrupt shadow tables based on fuzz input
|
||||||
|
uint8_t target = data[0] % 10;
|
||||||
|
const uint8_t *payload = data + 1;
|
||||||
|
int payload_size = (int)(size - 1);
|
||||||
|
|
||||||
|
// Limit payload to avoid huge allocations
|
||||||
|
if (payload_size > 4096) payload_size = 4096;
|
||||||
|
|
||||||
|
sqlite3_stmt *stmt = NULL;
|
||||||
|
|
||||||
|
switch (target) {
|
||||||
|
case 0: {
|
||||||
|
// Corrupt cell validity blob
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_ivf_cells00 SET validity = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
// Corrupt cell rowids blob
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_ivf_cells00 SET rowids = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
// Corrupt cell vectors blob
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_ivf_cells00 SET vectors = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
// Corrupt centroid blob
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_ivf_centroids00 SET centroid = ? WHERE centroid_id = 0",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: {
|
||||||
|
// Set n_vectors to a bogus value (larger than cell capacity)
|
||||||
|
int bogus_n = 99999;
|
||||||
|
if (payload_size >= 4) {
|
||||||
|
memcpy(&bogus_n, payload, 4);
|
||||||
|
bogus_n = abs(bogus_n) % 100000;
|
||||||
|
}
|
||||||
|
char sql[128];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"UPDATE v_ivf_cells00 SET n_vectors = %d WHERE rowid = 1", bogus_n);
|
||||||
|
sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: {
|
||||||
|
// Delete rowid_map entries (orphan vectors)
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"DELETE FROM v_ivf_rowid_map00 WHERE rowid IN (1, 2, 3)",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 6: {
|
||||||
|
// Corrupt rowid_map slot values
|
||||||
|
char sql[128];
|
||||||
|
int bogus_slot = payload_size > 0 ? (int)payload[0] * 1000 : 99999;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"UPDATE v_ivf_rowid_map00 SET slot = %d WHERE rowid = 1", bogus_slot);
|
||||||
|
sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 7: {
|
||||||
|
// Corrupt rowid_map cell_id values
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"UPDATE v_ivf_rowid_map00 SET cell_id = 99999 WHERE rowid = 1",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 8: {
|
||||||
|
// Delete all centroids (make trained but no centroids)
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"DELETE FROM v_ivf_centroids00",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 9: {
|
||||||
|
// Set validity to NULL
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"UPDATE v_ivf_cells00 SET validity = NULL WHERE rowid = 1",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exercise all read paths over corrupted state — must not crash
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
|
||||||
|
// KNN query
|
||||||
|
{
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||||
|
-1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full scan
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Point query
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 5", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 3", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Insert after corruption
|
||||||
|
{
|
||||||
|
float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
||||||
|
sqlite3_stmt *si = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
|
if (si) {
|
||||||
|
sqlite3_bind_int64(si, 1, 100);
|
||||||
|
sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC);
|
||||||
|
sqlite3_step(si);
|
||||||
|
sqlite3_finalize(si);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute-centroids over corrupted state
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// clear-centroids
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(rowid) VALUES ('clear-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
@ -5,6 +5,10 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#ifndef SQLITE_VEC_ENABLE_IVF
|
||||||
|
#define SQLITE_VEC_ENABLE_IVF 1
|
||||||
|
#endif
|
||||||
|
|
||||||
int min_idx(
|
int min_idx(
|
||||||
const float *distances,
|
const float *distances,
|
||||||
int32_t n,
|
int32_t n,
|
||||||
|
|
@ -68,8 +72,36 @@ enum Vec0IndexType {
|
||||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||||
VEC0_INDEX_TYPE_RESCORE = 2,
|
VEC0_INDEX_TYPE_RESCORE = 2,
|
||||||
#endif
|
#endif
|
||||||
|
VEC0_INDEX_TYPE_IVF = 3,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum Vec0RescoreQuantizerType {
|
||||||
|
VEC0_RESCORE_QUANTIZER_BIT = 1,
|
||||||
|
VEC0_RESCORE_QUANTIZER_INT8 = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Vec0RescoreConfig {
|
||||||
|
enum Vec0RescoreQuantizerType quantizer_type;
|
||||||
|
int oversample;
|
||||||
|
};
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
enum Vec0IvfQuantizer {
|
||||||
|
VEC0_IVF_QUANTIZER_NONE = 0,
|
||||||
|
VEC0_IVF_QUANTIZER_INT8 = 1,
|
||||||
|
VEC0_IVF_QUANTIZER_BINARY = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Vec0IvfConfig {
|
||||||
|
int nlist;
|
||||||
|
int nprobe;
|
||||||
|
int quantizer;
|
||||||
|
int oversample;
|
||||||
|
};
|
||||||
|
#else
|
||||||
|
struct Vec0IvfConfig { char _unused; };
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||||
enum Vec0RescoreQuantizerType {
|
enum Vec0RescoreQuantizerType {
|
||||||
VEC0_RESCORE_QUANTIZER_BIT = 1,
|
VEC0_RESCORE_QUANTIZER_BIT = 1,
|
||||||
|
|
@ -93,6 +125,7 @@ struct VectorColumnDefinition {
|
||||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||||
struct Vec0RescoreConfig rescore;
|
struct Vec0RescoreConfig rescore;
|
||||||
#endif
|
#endif
|
||||||
|
struct Vec0IvfConfig ivf;
|
||||||
};
|
};
|
||||||
|
|
||||||
int vec0_parse_vector_column(const char *source, int source_length,
|
int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
|
|
@ -114,6 +147,10 @@ void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t
|
||||||
size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
|
size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
|
||||||
size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
|
size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
|
||||||
#endif
|
#endif
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
void ivf_quantize_int8(const float *src, int8_t *dst, int D);
|
||||||
|
void ivf_quantize_binary(const float *src, uint8_t *dst, int D);
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif /* SQLITE_VEC_INTERNAL_H */
|
#endif /* SQLITE_VEC_INTERNAL_H */
|
||||||
|
|
|
||||||
575
tests/test-ivf-mutations.py
Normal file
575
tests/test-ivf-mutations.py
Normal file
|
|
@ -0,0 +1,575 @@
|
||||||
|
"""
|
||||||
|
Thorough IVF mutation tests: insert, delete, update, KNN correctness,
|
||||||
|
error cases, edge cases, and cell overflow scenarios.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import math
|
||||||
|
from helpers import _f32, exec
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def db():
|
||||||
|
db = sqlite3.connect(":memory:")
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension("dist/vec0")
|
||||||
|
db.enable_load_extension(False)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_total_vectors(db, table="t", col=0):
|
||||||
|
"""Count total vectors across all IVF cells."""
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d}"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_unassigned_count(db, table="t", col=0):
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id = -1"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_assigned_count(db, table="t", col=0):
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id >= 0"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def knn(db, query, k, table="t", col="v"):
|
||||||
|
"""Run a KNN query and return list of (rowid, distance) tuples."""
|
||||||
|
rows = db.execute(
|
||||||
|
f"SELECT rowid, distance FROM {table} WHERE {col} MATCH ? AND k = ?",
|
||||||
|
[_f32(query), k],
|
||||||
|
).fetchall()
|
||||||
|
return [(r[0], r[1]) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Single row insert + KNN
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_single_row_knn(db):
|
||||||
|
db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
|
||||||
|
results = knn(db, [1, 0, 0, 0], 5)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0][0] == 1
|
||||||
|
assert results[0][1] < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Batch insert + KNN recall
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_insert_knn_recall(db):
|
||||||
|
"""Insert 200 vectors, train, verify KNN recall with nprobe=nlist."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
|
||||||
|
)
|
||||||
|
for i in range(200):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
assert ivf_total_vectors(db) == 200
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
assert ivf_assigned_count(db) == 200
|
||||||
|
|
||||||
|
# Query near 100 -- closest should be rowid 100
|
||||||
|
results = knn(db, [100.0, 0, 0, 0], 10)
|
||||||
|
assert len(results) == 10
|
||||||
|
assert results[0][0] == 100
|
||||||
|
assert results[0][1] < 0.01
|
||||||
|
|
||||||
|
# All results should be near 100
|
||||||
|
rowids = {r[0] for r in results}
|
||||||
|
assert all(95 <= r <= 105 for r in rowids)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Delete rows, verify they're gone from KNN
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_rows_gone_from_knn(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# Delete rowid 10
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 10")
|
||||||
|
|
||||||
|
results = knn(db, [10.0, 0, 0, 0], 20)
|
||||||
|
rowids = {r[0] for r in results}
|
||||||
|
assert 10 not in rowids
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_rows_empty_results(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
||||||
|
|
||||||
|
assert ivf_total_vectors(db) == 0
|
||||||
|
results = knn(db, [5.0, 0, 0, 0], 10)
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Insert after delete (reuse rowids)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_after_delete_reuse_rowid(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# Delete rowid 5
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
|
||||||
|
# Re-insert rowid 5 with a very different vector
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (5, ?)", [_f32([999.0, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
# KNN near 999 should find rowid 5
|
||||||
|
results = knn(db, [999.0, 0, 0, 0], 1)
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert results[0][0] == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Update vectors (INSERT OR REPLACE), verify KNN reflects new values
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_vector_via_delete_insert(db):
|
||||||
|
"""vec0 IVF update: delete then re-insert with new vector."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# "Update" rowid 3: delete and re-insert with new vector
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 3")
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (3, ?)",
|
||||||
|
[_f32([100.0, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# KNN near 100 should find rowid 3
|
||||||
|
results = knn(db, [100.0, 0, 0, 0], 1)
|
||||||
|
assert results[0][0] == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Error cases: IVF + auxiliary/metadata/partition key columns
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_ivf_with_auxiliary_column(db):
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), +extra text)",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
assert "auxiliary" in result.get("message", "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_ivf_with_metadata_column(db):
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), genre text)",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
assert "metadata" in result.get("message", "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_ivf_with_partition_key(db):
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), user_id integer partition key)",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
assert "partition" in result.get("message", "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_flat_with_auxiliary_still_works(db):
|
||||||
|
"""Regression guard: flat-indexed tables with aux columns should still work."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4], +extra text)"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v, extra) VALUES (1, ?, 'hello')",
|
||||||
|
[_f32([1, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
row = db.execute("SELECT extra FROM t WHERE rowid = 1").fetchone()
|
||||||
|
assert row[0] == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flat_with_metadata_still_works(db):
|
||||||
|
"""Regression guard: flat-indexed tables with metadata columns should still work."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4], genre text)"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v, genre) VALUES (1, ?, 'rock')",
|
||||||
|
[_f32([1, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
row = db.execute("SELECT genre FROM t WHERE rowid = 1").fetchone()
|
||||||
|
assert row[0] == "rock"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flat_with_partition_key_still_works(db):
|
||||||
|
"""Regression guard: flat-indexed tables with partition key should still work."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4], user_id integer partition key)"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v, user_id) VALUES (1, ?, 42)",
|
||||||
|
[_f32([1, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
row = db.execute("SELECT user_id FROM t WHERE rowid = 1").fetchone()
|
||||||
|
assert row[0] == 42
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Edge cases
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_vectors(db):
|
||||||
|
"""Insert zero vectors, verify KNN still works."""
|
||||||
|
db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
|
||||||
|
for i in range(5):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([0, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
results = knn(db, [0, 0, 0, 0], 5)
|
||||||
|
assert len(results) == 5
|
||||||
|
# All distances should be 0
|
||||||
|
for _, dist in results:
|
||||||
|
assert dist < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
def test_large_values(db):
|
||||||
|
"""Insert vectors with very large and small values."""
|
||||||
|
db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1e6, 1e6, 1e6, 1e6])]
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([1e-6, 1e-6, 1e-6, 1e-6])]
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([-1e6, -1e6, -1e6, -1e6])]
|
||||||
|
)
|
||||||
|
|
||||||
|
results = knn(db, [1e6, 1e6, 1e6, 1e6], 3)
|
||||||
|
assert results[0][0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_row_compute_centroids(db):
|
||||||
|
"""Single row table, compute-centroids should still work."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=1))"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])]
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
assert ivf_assigned_count(db) == 1
|
||||||
|
|
||||||
|
results = knn(db, [1, 2, 3, 4], 1)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0][0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Cell overflow (many vectors in one cell)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_cell_overflow_many_vectors(db):
|
||||||
|
"""Insert >64 vectors that all go to same centroid. Should create multiple cells."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
|
||||||
|
)
|
||||||
|
# Insert 100 very similar vectors
|
||||||
|
for i in range(100):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([1.0 + i * 0.001, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set a single centroid so all vectors go there
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
|
||||||
|
[_f32([1.0, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
|
||||||
|
|
||||||
|
assert ivf_assigned_count(db) == 100
|
||||||
|
|
||||||
|
# Should have more than 1 cell (64 max per cell)
|
||||||
|
cell_count = db.execute(
|
||||||
|
"SELECT count(*) FROM t_ivf_cells00 WHERE centroid_id = 0"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert cell_count >= 2 # 100 / 64 = 2 cells needed
|
||||||
|
|
||||||
|
# All vectors should be queryable
|
||||||
|
results = knn(db, [1.0, 0, 0, 0], 100)
|
||||||
|
assert len(results) == 100
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Large batch with training
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_large_batch_with_training(db):
|
||||||
|
"""Insert 500, train, insert 500 more, verify total is 1000."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=16, nprobe=16))"
|
||||||
|
)
|
||||||
|
for i in range(500):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
for i in range(500, 1000):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ivf_total_vectors(db) == 1000
|
||||||
|
|
||||||
|
# KNN should still work
|
||||||
|
results = knn(db, [750.0, 0, 0, 0], 5)
|
||||||
|
assert len(results) == 5
|
||||||
|
assert results[0][0] == 750
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# KNN after interleaved insert/delete
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_knn_after_interleaved_insert_delete(db):
|
||||||
|
"""Insert 20, train, delete 10 closest to query, verify remaining."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# Delete rowids 0-9 (closest to query at 5.0)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
||||||
|
|
||||||
|
results = knn(db, [5.0, 0, 0, 0], 10)
|
||||||
|
rowids = {r[0] for r in results}
|
||||||
|
# None of the deleted rowids should appear
|
||||||
|
assert all(r >= 10 for r in rowids)
|
||||||
|
assert len(results) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_knn_empty_centroids_after_deletes(db):
|
||||||
|
"""Some centroids may become empty after deletes. Should not crash."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=2))"
|
||||||
|
)
|
||||||
|
# Insert vectors clustered near 4 points
|
||||||
|
for i in range(40):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i % 10) * 10, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# Delete a bunch, potentially emptying some centroids
|
||||||
|
for i in range(30):
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
||||||
|
|
||||||
|
# Should not crash even with empty centroids
|
||||||
|
results = knn(db, [50.0, 0, 0, 0], 20)
|
||||||
|
assert len(results) <= 10 # only 10 left
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# KNN returns correct distances
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_knn_correct_distances(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([0, 0, 0, 0])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([3, 0, 0, 0])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([0, 4, 0, 0])])
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
results = knn(db, [0, 0, 0, 0], 3)
|
||||||
|
result_map = {r[0]: r[1] for r in results}
|
||||||
|
|
||||||
|
# L2 distances (sqrt of sum of squared differences)
|
||||||
|
assert abs(result_map[1] - 0.0) < 0.01
|
||||||
|
assert abs(result_map[2] - 3.0) < 0.01 # sqrt(3^2) = 3
|
||||||
|
assert abs(result_map[3] - 4.0) < 0.01 # sqrt(4^2) = 4
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Delete in flat mode leaves no orphan rowid_map entries
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_flat_mode_rowid_map_count(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
for i in range(5):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 0")
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 2")
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 4")
|
||||||
|
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 2
|
||||||
|
assert ivf_unassigned_count(db) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Duplicate rowid insert
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_reinsert_as_update(db):
|
||||||
|
"""Simulate update via delete + insert on same rowid."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
|
||||||
|
|
||||||
|
# Delete then re-insert as "update"
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 1")
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([99, 0, 0, 0])])
|
||||||
|
|
||||||
|
results = knn(db, [99, 0, 0, 0], 1)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0][0] == 1
|
||||||
|
assert results[0][1] < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_rowid_insert_fails(db):
|
||||||
|
"""Inserting a duplicate rowid should fail with a constraint error."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
|
||||||
|
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (1, ?)",
|
||||||
|
[_f32([99, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Interleaved insert/delete with KNN correctness
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_interleaved_ops_correctness(db):
|
||||||
|
"""Complex sequence of inserts and deletes, verify KNN is always correct."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 1: Insert 50 vectors
|
||||||
|
for i in range(50):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# Phase 2: Delete even-numbered rowids
|
||||||
|
for i in range(0, 50, 2):
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
||||||
|
|
||||||
|
# Phase 3: Insert new vectors with higher rowids
|
||||||
|
for i in range(50, 75):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(i), 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 4: Delete some of the new ones
|
||||||
|
for i in range(60, 70):
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = ?", [i])
|
||||||
|
|
||||||
|
# KNN query: should only find existing vectors
|
||||||
|
results = knn(db, [25.0, 0, 0, 0], 50)
|
||||||
|
rowids = {r[0] for r in results}
|
||||||
|
|
||||||
|
# Verify no deleted rowids appear
|
||||||
|
deleted = set(range(0, 50, 2)) | set(range(60, 70))
|
||||||
|
assert len(rowids & deleted) == 0
|
||||||
|
|
||||||
|
# Verify we get the right count (25 odd + 15 new - 10 deleted new = 30)
|
||||||
|
expected_alive = set(range(1, 50, 2)) | set(range(50, 60)) | set(range(70, 75))
|
||||||
|
assert rowids.issubset(expected_alive)
|
||||||
255
tests/test-ivf-quantization.py
Normal file
255
tests/test-ivf-quantization.py
Normal file
|
|
@ -0,0 +1,255 @@
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
from helpers import _f32, exec
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def db():
|
||||||
|
db = sqlite3.connect(":memory:")
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension("dist/vec0")
|
||||||
|
db.enable_load_extension(False)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Parser tests — quantizer and oversample options
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantizer_binary(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(nlist=64, quantizer=binary, oversample=10))"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
assert "t_ivf_cells00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantizer_int8(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(nlist=64, quantizer=int8))"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantizer_none_explicit(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(quantizer=none))"
|
||||||
|
)
|
||||||
|
# Should work — same as no quantizer
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantizer_all_params(db):
|
||||||
|
"""All params together: nlist, nprobe, quantizer, oversample."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] distance_metric=cosine "
|
||||||
|
"indexed by ivf(nlist=128, nprobe=16, quantizer=int8, oversample=4))"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_error_oversample_without_quantizer(db):
|
||||||
|
"""oversample > 1 without quantizer should error."""
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(oversample=10))",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_error_unknown_quantizer(db):
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(quantizer=pq))",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_oversample_1_without_quantizer_ok(db):
|
||||||
|
"""oversample=1 (default) is fine without quantizer."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(nlist=64))"
|
||||||
|
)
|
||||||
|
# Should succeed — oversample defaults to 1
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Functional tests — insert, train, query with quantized IVF
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_int8_insert_and_query(db):
|
||||||
|
"""int8 quantized IVF: insert, train, query."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=4))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# Should be able to query
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
[_f32([10.0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
# Top result should be close to 10
|
||||||
|
assert rows[0][0] in range(8, 13)
|
||||||
|
|
||||||
|
# Full vectors should be in _ivf_vectors table
|
||||||
|
fv_count = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0]
|
||||||
|
assert fv_count == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_binary_insert_and_query(db):
|
||||||
|
"""Binary quantized IVF: insert, train, query."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[8] indexed by ivf(nlist=2, quantizer=binary, oversample=4))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
# Vectors with varying sign patterns
|
||||||
|
v = [(i * 0.1 - 1.0) + j * 0.3 for j in range(8)]
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
[_f32([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
|
||||||
|
# Full vectors stored
|
||||||
|
fv_count = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0]
|
||||||
|
assert fv_count == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_int8_cell_sizes_smaller(db):
|
||||||
|
"""Cell blobs should be smaller with int8 quantization."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(nlist=2, quantizer=int8, oversample=1))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(x) / 768 for x in range(768)])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cell vectors blob: 10 vectors at int8 = 10 * 768 = 7680 bytes
|
||||||
|
# vs float32 = 10 * 768 * 4 = 30720 bytes
|
||||||
|
# But cells have capacity 64, so blob = 64 * 768 = 49152 (int8) vs 64*768*4=196608 (float32)
|
||||||
|
blob_size = db.execute(
|
||||||
|
"SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1"
|
||||||
|
).fetchone()[0]
|
||||||
|
# int8: 64 slots * 768 bytes = 49152
|
||||||
|
assert blob_size == 64 * 768
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_binary_cell_sizes_smaller(db):
|
||||||
|
"""Cell blobs should be much smaller with binary quantization."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(nlist=2, quantizer=binary, oversample=1))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(x) / 768 for x in range(768)])],
|
||||||
|
)
|
||||||
|
|
||||||
|
blob_size = db.execute(
|
||||||
|
"SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1"
|
||||||
|
).fetchone()[0]
|
||||||
|
# binary: 64 slots * 768/8 bytes = 6144
|
||||||
|
assert blob_size == 64 * (768 // 8)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_int8_oversample_improves_recall(db):
|
||||||
|
"""Oversample re-ranking should improve recall over oversample=1."""
|
||||||
|
# Create two tables: one with oversample=1, one with oversample=10
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t1 USING vec0("
|
||||||
|
"v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=1))"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t2 USING vec0("
|
||||||
|
"v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=10))"
|
||||||
|
)
|
||||||
|
for i in range(100):
|
||||||
|
v = _f32([i * 0.1, (i * 0.1) % 3, (i * 0.3) % 5, i * 0.01])
|
||||||
|
db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v])
|
||||||
|
db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v])
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t1(rowid) VALUES ('compute-centroids')")
|
||||||
|
db.execute("INSERT INTO t2(rowid) VALUES ('compute-centroids')")
|
||||||
|
db.execute("INSERT INTO t1(rowid) VALUES ('nprobe=4')")
|
||||||
|
db.execute("INSERT INTO t2(rowid) VALUES ('nprobe=4')")
|
||||||
|
|
||||||
|
query = _f32([5.0, 1.5, 2.5, 0.5])
|
||||||
|
r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall()
|
||||||
|
r2 = db.execute("SELECT rowid FROM t2 WHERE v MATCH ? AND k=10", [query]).fetchall()
|
||||||
|
|
||||||
|
# Both should return 10 results
|
||||||
|
assert len(r1) == 10
|
||||||
|
assert len(r2) == 10
|
||||||
|
# oversample=10 should have at least as good recall (same or better ordering)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantized_delete(db):
|
||||||
|
"""Delete should remove from both cells and _ivf_vectors."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=1))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10
|
||||||
|
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
# _ivf_vectors should have 9 rows
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 9
|
||||||
426
tests/test-ivf.py
Normal file
426
tests/test-ivf.py
Normal file
|
|
@ -0,0 +1,426 @@
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import math
|
||||||
|
from helpers import _f32, exec
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def db():
|
||||||
|
db = sqlite3.connect(":memory:")
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension("dist/vec0")
|
||||||
|
db.enable_load_extension(False)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_total_vectors(db, table="t", col=0):
|
||||||
|
"""Count total vectors across all IVF cells."""
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d}"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_unassigned_count(db, table="t", col=0):
|
||||||
|
"""Count vectors in unassigned cells (centroid_id=-1)."""
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id = -1"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_assigned_count(db, table="t", col=0):
|
||||||
|
"""Count vectors in trained cells (centroid_id >= 0)."""
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id >= 0"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Parser tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_defaults(db):
|
||||||
|
"""ivf() with no args uses defaults."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
assert "t_ivf_cells00" in tables
|
||||||
|
assert "t_ivf_rowid_map00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_custom_params(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=64, nprobe=8))"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
assert "t_ivf_cells00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_with_distance_metric(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] distance_metric=cosine indexed by ivf(nlist=16))"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_error_unknown_key(db):
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(bogus=1))",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_error_nprobe_gt_nlist(db):
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=10))",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Shadow table tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_shadow_tables_created(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8))"
|
||||||
|
)
|
||||||
|
trained = db.execute(
|
||||||
|
"SELECT value FROM t_info WHERE key = 'ivf_trained_0'"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert str(trained) == "0"
|
||||||
|
|
||||||
|
# No cells yet (created lazily on first insert)
|
||||||
|
count = db.execute(
|
||||||
|
"SELECT count(*) FROM t_ivf_cells00"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_drop_cleans_up(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
db.execute("DROP TABLE t")
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert not any("ivf" in t for t in tables)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Insert tests (flat mode)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_insert_flat_mode(db):
|
||||||
|
"""Before training, vectors go to unassigned cell."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])])
|
||||||
|
|
||||||
|
assert ivf_unassigned_count(db) == 2
|
||||||
|
assert ivf_assigned_count(db) == 0
|
||||||
|
|
||||||
|
# rowid_map should have 2 entries
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_delete_flat_mode(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])])
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 1")
|
||||||
|
|
||||||
|
assert ivf_unassigned_count(db) == 1
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# KNN flat mode tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_knn_flat_mode(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([2, 0, 0, 0])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([9, 0, 0, 0])])
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 2",
|
||||||
|
[_f32([1.5, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 2
|
||||||
|
rowids = {r[0] for r in rows}
|
||||||
|
assert rowids == {1, 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_knn_flat_empty(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
[_f32([1, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# compute-centroids tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_centroids(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
for i in range(40):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([i % 10, i // 10, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ivf_unassigned_count(db) == 40
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# After training: unassigned cell should be gone (or empty), vectors in trained cells
|
||||||
|
assert ivf_unassigned_count(db) == 0
|
||||||
|
assert ivf_assigned_count(db) == 40
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 4
|
||||||
|
trained = db.execute(
|
||||||
|
"SELECT value FROM t_info WHERE key='ivf_trained_0'"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert str(trained) == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_centroids_recompute(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
assert ivf_assigned_count(db) == 20
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Insert after training (assigned mode)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_insert_after_training(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should be in a trained cell, not unassigned
|
||||||
|
row = db.execute(
|
||||||
|
"SELECT m.cell_id, c.centroid_id FROM t_ivf_rowid_map00 m "
|
||||||
|
"JOIN t_ivf_cells00 c ON c.rowid = m.cell_id "
|
||||||
|
"WHERE m.rowid = 100"
|
||||||
|
).fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row[1] >= 0 # centroid_id >= 0 means trained cell
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# KNN after training (IVF probe mode)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_knn_after_training(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
|
||||||
|
)
|
||||||
|
for i in range(100):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
[_f32([50.0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
assert rows[0][0] == 50
|
||||||
|
assert rows[0][1] < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_knn_k_larger_than_n(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
for i in range(5):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
|
||||||
|
[_f32([0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Manual centroid import (set-centroid, assign-vectors)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_centroid_and_assign(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES ('set-centroid:0', ?)",
|
||||||
|
[_f32([5, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES ('set-centroid:1', ?)",
|
||||||
|
[_f32([15, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('assign-vectors')")
|
||||||
|
|
||||||
|
assert ivf_unassigned_count(db) == 0
|
||||||
|
assert ivf_assigned_count(db) == 20
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# clear-centroids
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_centroids(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('clear-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0
|
||||||
|
assert ivf_unassigned_count(db) == 20
|
||||||
|
trained = db.execute(
|
||||||
|
"SELECT value FROM t_info WHERE key='ivf_trained_0'"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert str(trained) == "0"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Delete after training
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_delete_after_training(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
assert ivf_assigned_count(db) == 10
|
||||||
|
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
assert ivf_assigned_count(db) == 9
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 9
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Recall test
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_recall_nprobe_equals_nlist(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
|
||||||
|
)
|
||||||
|
for i in range(100):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(rowid) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
|
||||||
|
[_f32([50.0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
rowids = {r[0] for r in rows}
|
||||||
|
|
||||||
|
# 45 and 55 are equidistant from 50, so either may appear in top 10
|
||||||
|
assert 50 in rowids
|
||||||
|
assert len(rowids) == 10
|
||||||
|
assert all(45 <= r <= 55 for r in rowids)
|
||||||
|
|
@ -577,6 +577,182 @@ void test_vec0_parse_vector_column() {
|
||||||
assert(rc == SQLITE_ERROR);
|
assert(rc == SQLITE_ERROR);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
// IVF: indexed by ivf() — defaults
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] indexed by ivf()";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.dimensions == 4);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.ivf.nlist == 128); // default
|
||||||
|
assert(col.ivf.nprobe == 10); // default
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: indexed by ivf(nlist=8) — nprobe auto-clamped to 8
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] indexed by ivf(nlist=8)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.ivf.nlist == 8);
|
||||||
|
assert(col.ivf.nprobe == 8); // clamped from default 10
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: indexed by ivf(nlist=64, nprobe=8)
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] indexed by ivf(nlist=64, nprobe=8)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.ivf.nlist == 64);
|
||||||
|
assert(col.ivf.nprobe == 8);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: with distance_metric before indexed by
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] distance_metric=cosine indexed by ivf(nlist=16)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.ivf.nlist == 16);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: nlist=0 (deferred)
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] indexed by ivf(nlist=0)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.nlist == 0);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF error: nprobe > nlist
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] indexed by ivf(nlist=4, nprobe=10)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF error: unknown key
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] indexed by ivf(bogus=1)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF error: unknown index type (hnsw not supported)
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] indexed by hnsw()";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not IVF: no ivf config
|
||||||
|
{
|
||||||
|
const char *input = "v float[4]";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_FLAT);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: quantizer=binary
|
||||||
|
{
|
||||||
|
const char *input = "v float[768] indexed by ivf(nlist=128, quantizer=binary)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.ivf.nlist == 128);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY);
|
||||||
|
assert(col.ivf.oversample == 1);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: quantizer=int8
|
||||||
|
{
|
||||||
|
const char *input = "v float[768] indexed by ivf(nlist=64, quantizer=int8)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: quantizer=none (explicit)
|
||||||
|
{
|
||||||
|
const char *input = "v float[768] indexed by ivf(quantizer=none)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_NONE);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: oversample=10 with quantizer
|
||||||
|
{
|
||||||
|
const char *input = "v float[768] indexed by ivf(nlist=128, quantizer=binary, oversample=10)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY);
|
||||||
|
assert(col.ivf.oversample == 10);
|
||||||
|
assert(col.ivf.nlist == 128);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: all params
|
||||||
|
{
|
||||||
|
const char *input = "v float[768] distance_metric=cosine indexed by ivf(nlist=256, nprobe=16, quantizer=int8, oversample=4)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
|
||||||
|
assert(col.ivf.nlist == 256);
|
||||||
|
assert(col.ivf.nprobe == 16);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
|
||||||
|
assert(col.ivf.oversample == 4);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF error: oversample > 1 without quantizer
|
||||||
|
{
|
||||||
|
const char *input = "v float[768] indexed by ivf(oversample=10)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF error: unknown quantizer value
|
||||||
|
{
|
||||||
|
const char *input = "v float[768] indexed by ivf(quantizer=pq)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// IVF: quantizer with defaults (nlist=128 default, nprobe=10 default)
|
||||||
|
{
|
||||||
|
const char *input = "v float[768] indexed by ivf(quantizer=binary, oversample=5)";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.nlist == 128);
|
||||||
|
assert(col.ivf.nprobe == 10);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY);
|
||||||
|
assert(col.ivf.oversample == 5);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// When IVF is disabled, parsing "ivf" should fail
|
||||||
|
{
|
||||||
|
const char *input = "v float[4] indexed by ivf()";
|
||||||
|
rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
#endif /* SQLITE_VEC_ENABLE_IVF */
|
||||||
|
|
||||||
printf(" All vec0_parse_vector_column tests passed.\n");
|
printf(" All vec0_parse_vector_column tests passed.\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -821,6 +997,38 @@ void test_rescore_quantize_float_to_int8() {
|
||||||
float src[8] = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f};
|
float src[8] = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f};
|
||||||
_test_rescore_quantize_float_to_int8(src, dst, 8);
|
_test_rescore_quantize_float_to_int8(src, dst, 8);
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
void test_ivf_quantize_int8() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
// Basic values in [-1, 1] range
|
||||||
|
{
|
||||||
|
float src[] = {0.0f, 1.0f, -1.0f, 0.5f};
|
||||||
|
int8_t dst[4];
|
||||||
|
ivf_quantize_int8(src, dst, 4);
|
||||||
|
assert(dst[0] == 0);
|
||||||
|
assert(dst[1] == 127);
|
||||||
|
assert(dst[2] == -127);
|
||||||
|
assert(dst[3] == 63); // 0.5 * 127 = 63.5, truncated to 63
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clamping: values beyond [-1, 1]
|
||||||
|
{
|
||||||
|
float src[] = {2.0f, -3.0f, 100.0f, -0.01f};
|
||||||
|
int8_t dst[4];
|
||||||
|
ivf_quantize_int8(src, dst, 4);
|
||||||
|
assert(dst[0] == 127); // clamped to 1.0
|
||||||
|
assert(dst[1] == -127); // clamped to -1.0
|
||||||
|
assert(dst[2] == 127); // clamped to 1.0
|
||||||
|
assert(dst[3] == (int8_t)(-0.01f * 127.0f));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero vector
|
||||||
|
{
|
||||||
|
float src[] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
int8_t dst[4];
|
||||||
|
ivf_quantize_int8(src, dst, 4);
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
assert(dst[i] == 0);
|
assert(dst[i] == 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -882,6 +1090,103 @@ void test_rescore_quantized_byte_size() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_vec0_parse_vector_column_rescore() {
|
void test_vec0_parse_vector_column_rescore() {
|
||||||
|
// Negative zero
|
||||||
|
{
|
||||||
|
float src[] = {-0.0f};
|
||||||
|
int8_t dst[1];
|
||||||
|
ivf_quantize_int8(src, dst, 1);
|
||||||
|
assert(dst[0] == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single element
|
||||||
|
{
|
||||||
|
float src[] = {0.75f};
|
||||||
|
int8_t dst[1];
|
||||||
|
ivf_quantize_int8(src, dst, 1);
|
||||||
|
assert(dst[0] == (int8_t)(0.75f * 127.0f));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boundary: exactly 1.0 and -1.0
|
||||||
|
{
|
||||||
|
float src[] = {1.0f, -1.0f};
|
||||||
|
int8_t dst[2];
|
||||||
|
ivf_quantize_int8(src, dst, 2);
|
||||||
|
assert(dst[0] == 127);
|
||||||
|
assert(dst[1] == -127);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf(" All ivf_quantize_int8 tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ivf_quantize_binary() {
|
||||||
|
printf("Starting %s...\n", __func__);
|
||||||
|
|
||||||
|
// Basic sign-bit quantization: positive -> 1, negative/zero -> 0
|
||||||
|
{
|
||||||
|
float src[] = {1.0f, -1.0f, 0.5f, -0.5f, 0.0f, 0.1f, -0.1f, 2.0f};
|
||||||
|
uint8_t dst[1];
|
||||||
|
ivf_quantize_binary(src, dst, 8);
|
||||||
|
// bit 0: 1.0 > 0 -> 1 (LSB)
|
||||||
|
// bit 1: -1.0 -> 0
|
||||||
|
// bit 2: 0.5 > 0 -> 1
|
||||||
|
// bit 3: -0.5 -> 0
|
||||||
|
// bit 4: 0.0 -> 0 (not > 0)
|
||||||
|
// bit 5: 0.1 > 0 -> 1
|
||||||
|
// bit 6: -0.1 -> 0
|
||||||
|
// bit 7: 2.0 > 0 -> 1
|
||||||
|
// Expected: bits 0,2,5,7 = 0b10100101 = 0xA5
|
||||||
|
assert(dst[0] == 0xA5);
|
||||||
|
}
|
||||||
|
|
||||||
|
// All positive
|
||||||
|
{
|
||||||
|
float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
|
||||||
|
uint8_t dst[1];
|
||||||
|
ivf_quantize_binary(src, dst, 8);
|
||||||
|
assert(dst[0] == 0xFF);
|
||||||
|
}
|
||||||
|
|
||||||
|
// All negative
|
||||||
|
{
|
||||||
|
float src[] = {-1.0f, -2.0f, -3.0f, -4.0f, -5.0f, -6.0f, -7.0f, -8.0f};
|
||||||
|
uint8_t dst[1];
|
||||||
|
ivf_quantize_binary(src, dst, 8);
|
||||||
|
assert(dst[0] == 0x00);
|
||||||
|
}
|
||||||
|
|
||||||
|
// All zero (zero is NOT > 0, so all bits should be 0)
|
||||||
|
{
|
||||||
|
float src[] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
uint8_t dst[1];
|
||||||
|
ivf_quantize_binary(src, dst, 8);
|
||||||
|
assert(dst[0] == 0x00);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-byte: 16 dimensions -> 2 bytes
|
||||||
|
{
|
||||||
|
float src[16];
|
||||||
|
for (int i = 0; i < 16; i++) src[i] = (i % 2 == 0) ? 1.0f : -1.0f;
|
||||||
|
uint8_t dst[2];
|
||||||
|
ivf_quantize_binary(src, dst, 16);
|
||||||
|
// Even indices are positive: bits 0,2,4,6 in each byte
|
||||||
|
// byte 0: bits 0,2,4,6 = 0b01010101 = 0x55
|
||||||
|
// byte 1: same pattern = 0x55
|
||||||
|
assert(dst[0] == 0x55);
|
||||||
|
assert(dst[1] == 0x55);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single byte, only first bit set
|
||||||
|
{
|
||||||
|
float src[] = {0.1f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
|
||||||
|
uint8_t dst[1];
|
||||||
|
ivf_quantize_binary(src, dst, 8);
|
||||||
|
assert(dst[0] == 0x01);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf(" All ivf_quantize_binary tests passed.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_ivf_config_parsing() {
|
||||||
printf("Starting %s...\n", __func__);
|
printf("Starting %s...\n", __func__);
|
||||||
struct VectorColumnDefinition col;
|
struct VectorColumnDefinition col;
|
||||||
int rc;
|
int rc;
|
||||||
|
|
@ -955,6 +1260,116 @@ void test_vec0_parse_vector_column_rescore() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif /* SQLITE_VEC_ENABLE_RESCORE */
|
#endif /* SQLITE_VEC_ENABLE_RESCORE */
|
||||||
|
// Default IVF config
|
||||||
|
{
|
||||||
|
const char *s = "v float[4] indexed by ivf()";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.index_type == VEC0_INDEX_TYPE_IVF);
|
||||||
|
assert(col.ivf.nlist == 128); // default
|
||||||
|
assert(col.ivf.nprobe == 10); // default
|
||||||
|
assert(col.ivf.quantizer == 0); // VEC0_IVF_QUANTIZER_NONE
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom nlist and nprobe
|
||||||
|
{
|
||||||
|
const char *s = "v float[4] indexed by ivf(nlist=64, nprobe=8)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.nlist == 64);
|
||||||
|
assert(col.ivf.nprobe == 8);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// nlist=0 (deferred)
|
||||||
|
{
|
||||||
|
const char *s = "v float[4] indexed by ivf(nlist=0)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.nlist == 0);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantizer options
|
||||||
|
{
|
||||||
|
const char *s = "v float[8] indexed by ivf(quantizer=int8)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const char *s = "v float[8] indexed by ivf(quantizer=binary)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_BINARY);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// nprobe > nlist (explicit) should fail
|
||||||
|
{
|
||||||
|
const char *s = "v float[4] indexed by ivf(nlist=4, nprobe=10)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown key
|
||||||
|
{
|
||||||
|
const char *s = "v float[4] indexed by ivf(bogus=1)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// nlist > max (65536) should fail
|
||||||
|
{
|
||||||
|
const char *s = "v float[4] indexed by ivf(nlist=65537)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// nlist at max boundary (65536) should succeed
|
||||||
|
{
|
||||||
|
const char *s = "v float[4] indexed by ivf(nlist=65536)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.nlist == 65536);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// oversample > 1 without quantization should fail
|
||||||
|
{
|
||||||
|
const char *s = "v float[4] indexed by ivf(oversample=4)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// oversample with quantizer should succeed
|
||||||
|
{
|
||||||
|
const char *s = "v float[8] indexed by ivf(quantizer=int8, oversample=4)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.oversample == 4);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// All options combined
|
||||||
|
{
|
||||||
|
const char *s = "v float[8] indexed by ivf(nlist=32, nprobe=4, quantizer=int8, oversample=2)";
|
||||||
|
rc = vec0_parse_vector_column(s, (int)strlen(s), &col);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
assert(col.ivf.nlist == 32);
|
||||||
|
assert(col.ivf.nprobe == 4);
|
||||||
|
assert(col.ivf.quantizer == VEC0_IVF_QUANTIZER_INT8);
|
||||||
|
assert(col.ivf.oversample == 2);
|
||||||
|
sqlite3_free(col.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf(" All ivf_config_parsing tests passed.\n");
|
||||||
|
}
|
||||||
|
#endif /* SQLITE_VEC_ENABLE_IVF */
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
printf("Starting unit tests...\n");
|
printf("Starting unit tests...\n");
|
||||||
|
|
@ -982,6 +1397,10 @@ int main() {
|
||||||
test_rescore_quantize_float_to_int8();
|
test_rescore_quantize_float_to_int8();
|
||||||
test_rescore_quantized_byte_size();
|
test_rescore_quantized_byte_size();
|
||||||
test_vec0_parse_vector_column_rescore();
|
test_vec0_parse_vector_column_rescore();
|
||||||
|
#if SQLITE_VEC_ENABLE_IVF
|
||||||
|
test_ivf_quantize_int8();
|
||||||
|
test_ivf_quantize_binary();
|
||||||
|
test_ivf_config_parsing();
|
||||||
#endif
|
#endif
|
||||||
printf("All unit tests passed.\n");
|
printf("All unit tests passed.\n");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue