mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-26 01:06:27 +02:00
Compare commits
49 commits
v0.1.7-alp
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5778fecfeb | ||
|
|
4e2dfcb79d | ||
|
|
b95c05b3aa | ||
|
|
89f6203536 | ||
|
|
c5607100ab | ||
|
|
6e2c4c6bab | ||
|
|
b7fc459be4 | ||
|
|
01b4b2a965 | ||
|
|
c36a995f1e | ||
|
|
c4c23bd8ba | ||
|
|
5522e86cd2 | ||
|
|
f2c9fb8f08 | ||
|
|
d684178a12 | ||
|
|
d033bf5728 | ||
|
|
b00865429b | ||
|
|
2f4c2e4bdb | ||
|
|
7de925be70 | ||
|
|
4bee88384b | ||
|
|
5e4c557f93 | ||
|
|
82f4eb08bf | ||
|
|
9df59b4c03 | ||
|
|
07f56e3cbe | ||
|
|
3cfc2e0c1f | ||
|
|
85cf415397 | ||
|
|
85973b3814 | ||
|
|
8544081a67 | ||
|
|
a248ecd061 | ||
|
|
1e3bb3e5e3 | ||
|
|
fb81c011ff | ||
|
|
575371d751 | ||
|
|
e2c38f387c | ||
|
|
bb3ef78f75 | ||
|
|
3e26925ce0 | ||
|
|
3358e127f6 | ||
|
|
43982c144b | ||
|
|
45d1375602 | ||
|
|
69ccb2405a | ||
|
|
0de765f457 | ||
|
|
e9f598abfa | ||
|
|
6c3bf3669f | ||
|
|
69f7b658e9 | ||
|
|
ee9bd2ba4d | ||
|
|
ba0db0b6d6 | ||
|
|
bf2455f2ba | ||
|
|
dfd8dc5290 | ||
|
|
e7ae41b761 | ||
|
|
a8d81cb235 | ||
|
|
633eecf506 | ||
|
|
4138619e3f |
105 changed files with 21725 additions and 2248 deletions
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
|
|
@ -252,7 +252,7 @@ jobs:
|
|||
name: sqlite-vec-iossimulator-x86_64-extension
|
||||
path: dist/iossimulator-x86_64
|
||||
- run: make sqlite-vec.h
|
||||
- uses: asg017/setup-sqlite-dist@73e37b2ffb0b51e64a64eb035da38c958b9ff6c6
|
||||
- uses: asg017/setup-sqlite-dist@fadb0183a6ec70c3f1942de7d232b087ff2bacd1
|
||||
- run: sqlite-dist build --set-version $(cat VERSION)
|
||||
- run: |
|
||||
gh release upload ${{ github.ref_name }} \
|
||||
|
|
|
|||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -31,3 +31,6 @@ poetry.lock
|
|||
|
||||
memstat.c
|
||||
memstat.*
|
||||
|
||||
|
||||
.DS_Store
|
||||
35
Makefile
35
Makefile
|
|
@ -37,11 +37,18 @@ endif
|
|||
|
||||
ifndef OMIT_SIMD
|
||||
ifeq ($(shell uname -sm),Darwin x86_64)
|
||||
CFLAGS += -mavx -DSQLITE_VEC_ENABLE_AVX
|
||||
CFLAGS += -mavx -mavx2 -DSQLITE_VEC_ENABLE_AVX
|
||||
endif
|
||||
ifeq ($(shell uname -sm),Darwin arm64)
|
||||
CFLAGS += -mcpu=apple-m1 -DSQLITE_VEC_ENABLE_NEON
|
||||
endif
|
||||
# Linux hosts (excluding Android toolchains): turn on the AVX code path only
# when the build machine reports AVX2 support. The flags below include -mavx2,
# so checking for plain "avx" would emit AVX2 instructions on AVX-only CPUs.
# grep -w matches the whole word "avx2" (not avx512*), -m1 stops at the first
# matching line, and a missing /proc/cpuinfo simply yields an empty result.
ifeq ($(shell uname -s),Linux)
  ifeq ($(findstring android,$(CC)),)
    ifneq ($(shell grep -m1 -ow avx2 /proc/cpuinfo 2>/dev/null),)
      CFLAGS += -mavx -mavx2 -DSQLITE_VEC_ENABLE_AVX
    endif
  endif
endif
|
||||
endif
|
||||
|
||||
ifdef USE_BREW_SQLITE
|
||||
|
|
@ -155,6 +162,13 @@ clean:
|
|||
rm -rf dist
|
||||
|
||||
|
||||
# Amalgamated source file written to $(prefix)/sqlite-vec.c by
# scripts/amalgamate.py.
TARGET_AMALGAMATION=$(prefix)/sqlite-vec.c

amalgamation: $(TARGET_AMALGAMATION)

# $(prefix) is an order-only prerequisite: the directory has to exist, but its
# mtime changes whenever any file inside it changes and must not force a
# rebuild. The output goes to a temporary file first so that a failing script
# never leaves a truncated target that looks up to date.
$(TARGET_AMALGAMATION): sqlite-vec.c $(wildcard sqlite-vec-*.c) scripts/amalgamate.py | $(prefix)
	python3 scripts/amalgamate.py sqlite-vec.c > $@.tmp && mv -f $@.tmp $@
|
||||
|
||||
FORMAT_FILES=sqlite-vec.h sqlite-vec.c
|
||||
format: $(FORMAT_FILES)
|
||||
clang-format -i $(FORMAT_FILES)
|
||||
|
|
@ -174,7 +188,7 @@ evidence-of:
|
|||
test:
|
||||
sqlite3 :memory: '.read test.sql'
|
||||
|
||||
.PHONY: version loadable static test clean gh-release evidence-of install uninstall
|
||||
.PHONY: version loadable static test clean gh-release evidence-of install uninstall amalgamation
|
||||
|
||||
publish-release:
|
||||
./scripts/publish-release.sh
|
||||
|
|
@ -190,7 +204,22 @@ test-loadable-watch:
|
|||
watchexec --exts c,py,Makefile --clear -- make test-loadable
|
||||
|
||||
test-unit:
|
||||
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
|
||||
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_ENABLE_DISKANN=1 tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor $(CFLAGS) -o $(prefix)/test-unit && $(prefix)/test-unit
|
||||
|
||||
# Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking,
# profiling (has debug symbols), and scripting without .load_extension.
#   make cli
#   dist/sqlite3 :memory: "SELECT vec_version()"
#   dist/sqlite3 < script.sql
#
# "cli" names a command, not a file (the binary is $(prefix)/sqlite3), so it is
# declared phony; otherwise a stray file called "cli" would disable the target.
# $(prefix) is order-only: the directory only needs to exist.
.PHONY: cli
cli: sqlite-vec.h | $(prefix)
	$(CC) -O2 -g \
		-DSQLITE_CORE \
		-DSQLITE_EXTRA_INIT=core_init \
		-DSQLITE_THREADSAFE=0 \
		-Ivendor/ -I./ \
		$(CFLAGS) \
		vendor/sqlite3.c vendor/shell.c sqlite-vec.c examples/sqlite3-cli/core_init.c \
		-ldl -lm -o $(prefix)/sqlite3
|
||||
|
||||
fuzz-build:
|
||||
$(MAKE) -C tests/fuzz all
|
||||
|
|
|
|||
73
TODO.md
Normal file
73
TODO.md
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
# TODO: `ann` base branch + consolidated benchmarks
|
||||
|
||||
## 1. Create `ann` branch with shared code
|
||||
|
||||
### 1.1 Branch setup
|
||||
- [x] `git checkout -B ann origin/main`
|
||||
- [x] Cherry-pick `624f998` (vec0_distance_full shared distance dispatch)
|
||||
- [x] Cherry-pick stdint.h fix for test header
|
||||
- [ ] Pull NEON cosine optimization from ivf-yolo3 into shared code
|
||||
- Currently only in ivf branch but is general-purpose (benefits all distance calcs)
|
||||
- Lives in `distance_cosine_float()` — ~57 lines of ARM NEON vectorized cosine
|
||||
|
||||
### 1.2 Benchmark infrastructure (`benchmarks-ann/`)
|
||||
- [x] Seed data pipeline (`seed/Makefile`, `seed/build_base_db.py`)
|
||||
- [x] Ground truth generator (`ground_truth.py`)
|
||||
- [x] Results schema (`schema.sql`)
|
||||
- [x] Benchmark runner with `INDEX_REGISTRY` extension point (`bench.py`)
|
||||
- Baseline configs (float, int8-rescore, bit-rescore) implemented
|
||||
- Index branches register their types via `INDEX_REGISTRY` dict
|
||||
- [x] Makefile with baseline targets
|
||||
- [x] README
|
||||
|
||||
### 1.3 Rebase feature branches onto `ann`
|
||||
- [x] Rebase `diskann-yolo2` onto `ann` (1 commit: DiskANN implementation)
|
||||
- [x] Rebase `ivf-yolo3` onto `ann` (1 commit: IVF implementation)
|
||||
- [x] Rebase `annoy-yolo2` onto `ann` (2 commits: Annoy implementation + schema fix)
|
||||
- [x] Verify each branch has only its index-specific commits remaining
|
||||
- [ ] Force-push all 4 branches to origin
|
||||
|
||||
---
|
||||
|
||||
## 2. Per-branch: register index type in benchmarks
|
||||
|
||||
Each index branch should add to `benchmarks-ann/` when rebased onto `ann`:
|
||||
|
||||
### 2.1 Register in `bench.py`
|
||||
|
||||
Add an `INDEX_REGISTRY` entry. Each entry provides:
|
||||
- `defaults` — default param values
|
||||
- `create_table_sql(params)` — CREATE VIRTUAL TABLE with INDEXED BY clause
|
||||
- `insert_sql(params)` — custom insert SQL, or None for default
|
||||
- `post_insert_hook(conn, params)` — training/building step, returns time
|
||||
- `run_query(conn, params, query, k)` — custom query, or None for default MATCH
|
||||
- `describe(params)` — one-line description for report output
|
||||
|
||||
### 2.2 Add configs to `Makefile`
|
||||
|
||||
Append index-specific config variables and targets. Example pattern:
|
||||
|
||||
```makefile
|
||||
DISKANN_CONFIGS = \
|
||||
"diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
|
||||
...
|
||||
|
||||
ALL_CONFIGS += $(DISKANN_CONFIGS)
|
||||
|
||||
bench-diskann: seed
|
||||
$(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
|
||||
...
|
||||
```
|
||||
|
||||
### 2.3 Migrate existing benchmark results/docs
|
||||
|
||||
- Move useful results docs (RESULTS.md, etc.) into `benchmarks-ann/results/`
|
||||
- Delete redundant per-branch benchmark directories once consolidated infra is proven
|
||||
|
||||
---
|
||||
|
||||
## 3. Future improvements
|
||||
|
||||
- [ ] Reporting script (`report.py`) — query results.db, produce markdown comparison tables
|
||||
- [ ] Profiling targets in Makefile (lift from ivf-yolo3's Instruments/perf wrappers)
|
||||
- [ ] Pre-computed ground truth integration (use GT DB files instead of on-the-fly brute-force)
|
||||
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
|||
0.1.7-alpha.12
|
||||
0.1.10-alpha.3
|
||||
8
benchmarks-ann/.gitignore
vendored
Normal file
8
benchmarks-ann/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
*.parquet
|
||||
runs/
|
||||
|
||||
viewer/
|
||||
searcher/
|
||||
85
benchmarks-ann/Makefile
Normal file
85
benchmarks-ann/Makefile
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
# Benchmark driver for vec0 KNN configurations.
#
# PYTHON selects the interpreter/launcher used for every script, so the
# targets can be run through a wrapper without editing this file, e.g.
#   make bench-smoke PYTHON="uv run python"
PYTHON ?= python
BENCH = $(PYTHON) bench.py
BASE_DB = cohere1m/base.db
EXT = ../dist/vec0

# --- Baseline (brute-force) configs ---
BASELINES = \
  "brute-float:type=vec0-flat,variant=float" \
  "brute-int8:type=vec0-flat,variant=int8" \
  "brute-bit:type=vec0-flat,variant=bit"

# --- IVF configs ---
IVF_CONFIGS = \
  "ivf-n32-p8:type=ivf,nlist=32,nprobe=8" \
  "ivf-n128-p16:type=ivf,nlist=128,nprobe=16" \
  "ivf-n512-p32:type=ivf,nlist=512,nprobe=32"

# --- Rescore configs ---
RESCORE_CONFIGS = \
  "rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" \
  "rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \
  "rescore-int8-os8:type=rescore,quantizer=int8,oversample=8"

# --- DiskANN configs ---
DISKANN_CONFIGS = \
  "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
  "diskann-R72-binary:type=diskann,R=72,L=128,quantizer=binary" \
  "diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8" \
  "diskann-R72-L256:type=diskann,R=72,L=256,quantizer=binary"

ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) $(DISKANN_CONFIGS)

.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-diskann bench-10k bench-50k bench-100k bench-all \
  report clean

# --- Data preparation ---
# Delegates to the dataset's own Makefile; $(MAKE) keeps -j/-n propagation.
seed:
	$(MAKE) -C cohere1m

ground-truth: seed
	$(PYTHON) ground_truth.py --subset-size 10000
	$(PYTHON) ground_truth.py --subset-size 50000
	$(PYTHON) ground_truth.py --subset-size 100000

# --- Quick smoke test ---
bench-smoke: seed
	$(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \
		"brute-float:type=vec0-flat,variant=float" \
		"ivf-quick:type=ivf,nlist=16,nprobe=4" \
		"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"

bench-rescore: seed
	$(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs \
		$(RESCORE_CONFIGS)

# --- Standard sizes ---
bench-10k: seed
	$(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS)

bench-50k: seed
	$(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS)

bench-100k: seed
	$(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS)

bench-all: bench-10k bench-50k bench-100k

# --- IVF across sizes ---
bench-ivf: seed
	$(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS)
	$(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS)
	$(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS)

# --- DiskANN across sizes ---
bench-diskann: seed
	$(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS)
	$(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS)
	$(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS)

# --- Report ---
# Only prints a ready-to-run query; it does not read the results itself.
report:
	@echo "Use: sqlite3 runs/cohere1m/<size>/results.db 'SELECT run_id, config_name, status, recall FROM runs JOIN run_results USING(run_id)'"

# --- Cleanup ---
clean:
	rm -rf runs/
|
||||
111
benchmarks-ann/README.md
Normal file
111
benchmarks-ann/README.md
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
# KNN Benchmarks for sqlite-vec
|
||||
|
||||
Benchmarking infrastructure for vec0 KNN configurations. Includes brute-force
|
||||
baselines (float, int8, bit), rescore, IVF, and DiskANN index types.
|
||||
|
||||
## Datasets
|
||||
|
||||
Each dataset is a subdirectory containing a `Makefile` and `build_base_db.py`
|
||||
that produce a `base.db`. The benchmark runner auto-discovers any subdirectory
|
||||
with a `base.db` file.
|
||||
|
||||
```
|
||||
cohere1m/ # Cohere 768d cosine, 1M vectors
|
||||
Makefile # downloads parquets from Zilliz, builds base.db
|
||||
build_base_db.py
|
||||
base.db # (generated)
|
||||
|
||||
cohere10m/ # Cohere 768d cosine, 10M vectors (10 train shards)
|
||||
Makefile # make -j12 download to fetch all shards in parallel
|
||||
build_base_db.py
|
||||
base.db # (generated)
|
||||
```
|
||||
|
||||
Every `base.db` has the same schema:
|
||||
|
||||
| Table | Columns | Description |
|
||||
|-------|---------|-------------|
|
||||
| `train` | `id INTEGER PRIMARY KEY, vector BLOB` | Indexed vectors (f32 blobs) |
|
||||
| `query_vectors` | `id INTEGER PRIMARY KEY, vector BLOB` | Query vectors for KNN evaluation |
|
||||
| `neighbors` | `query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT` | Ground-truth nearest neighbors |
|
||||
|
||||
To add a new dataset, create a directory with a `Makefile` that builds `base.db`
|
||||
with the tables above. It will be available via `--dataset <dirname>` automatically.
|
||||
|
||||
### Building datasets
|
||||
|
||||
```bash
|
||||
# Cohere 1M
|
||||
cd cohere1m && make download && make && cd ..
|
||||
|
||||
# Cohere 10M (parallel download recommended — 10 train shards + test + neighbors)
|
||||
cd cohere10m && make -j12 download && make && cd ..
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Built `dist/vec0` extension (run `make loadable` from repo root)
|
||||
- Python 3.10+
|
||||
- `uv`
|
||||
|
||||
## Quick start
|
||||
|
||||
```bash
|
||||
# 1. Build a dataset
|
||||
cd cohere1m && make && cd ..
|
||||
|
||||
# 2. Quick smoke test (5k vectors)
|
||||
make bench-smoke
|
||||
|
||||
# 3. Full benchmark at 10k
|
||||
make bench-10k
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
uv run python bench.py --subset-size 10000 -k 10 -n 50 --dataset cohere1m \
|
||||
"brute-float:type=vec0-flat,variant=float" \
|
||||
"rescore-bit-os8:type=rescore,quantizer=bit,oversample=8"
|
||||
```
|
||||
|
||||
### Config format
|
||||
|
||||
`name:type=<index_type>,key=val,key=val`
|
||||
|
||||
| Index type | Keys |
|
||||
|-----------|------|
|
||||
| `vec0-flat` | `variant` (float/int8/bit), `oversample` |
|
||||
| `rescore` | `quantizer` (bit/int8), `oversample` |
|
||||
| `ivf` | `nlist`, `nprobe` |
|
||||
| `diskann` | `R`, `L`, `quantizer` (binary/int8), `buffer_threshold` |
|
||||
|
||||
### Make targets
|
||||
|
||||
| Target | Description |
|
||||
|--------|-------------|
|
||||
| `make seed` | Download and build default dataset |
|
||||
| `make bench-smoke` | Quick 5k test (3 configs) |
|
||||
| `make bench-10k` | All configs at 10k vectors |
|
||||
| `make bench-50k` | All configs at 50k vectors |
|
||||
| `make bench-100k` | All configs at 100k vectors |
|
||||
| `make bench-all` | 10k + 50k + 100k |
|
||||
| `make bench-ivf` | Baselines + IVF across 10k/50k/100k |
|
||||
| `make bench-diskann` | Baselines + DiskANN across 10k/50k/100k |
|
||||
|
||||
## Results DB
|
||||
|
||||
Each run writes to `runs/<dataset>/<subset_size>/results.db` (SQLite, WAL mode).
|
||||
Progress is written continuously — query from another terminal to monitor:
|
||||
|
||||
```bash
|
||||
sqlite3 runs/cohere1m/10000/results.db "SELECT run_id, config_name, status FROM runs"
|
||||
```
|
||||
|
||||
See `results_schema.sql` for the full schema (tables: `runs`, `run_results`,
|
||||
`insert_batches`, `queries`).
|
||||
|
||||
## Adding an index type
|
||||
|
||||
Add an entry to `INDEX_REGISTRY` in `bench.py` and append configs to
|
||||
`ALL_CONFIGS` in the `Makefile`. See existing entries for the pattern.
|
||||
3
benchmarks-ann/bench-delete/.gitignore
vendored
Normal file
3
benchmarks-ann/bench-delete/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
runs/
|
||||
*.db
|
||||
__pycache__/
|
||||
41
benchmarks-ann/bench-delete/Makefile
Normal file
41
benchmarks-ann/bench-delete/Makefile
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
# Delete-benchmark driver.
#
# PYTHON selects the interpreter/launcher, so the targets can be run through a
# wrapper without editing this file, e.g.
#   make smoke PYTHON="uv run python"
PYTHON ?= python
BENCH = $(PYTHON) bench_delete.py
EXT = ../../dist/vec0

# --- Configs to test ---
FLAT = "flat:type=vec0-flat,variant=float"
RESCORE_BIT = "rescore-bit:type=rescore,quantizer=bit,oversample=8"
RESCORE_INT8 = "rescore-int8:type=rescore,quantizer=int8,oversample=8"
DISKANN_R48 = "diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
DISKANN_R72 = "diskann-R72:type=diskann,R=72,L=128,quantizer=binary"

ALL_CONFIGS = $(FLAT) $(RESCORE_BIT) $(RESCORE_INT8) $(DISKANN_R48) $(DISKANN_R72)

# Comma-separated percentages passed verbatim to --delete-pct.
DELETE_PCTS = 5,10,25,50,75,90

.PHONY: smoke bench-10k bench-50k bench-all report clean

# Quick smoke test (small dataset, few queries)
smoke:
	$(BENCH) --subset-size 5000 --delete-pct 10,50 -k 10 -n 20 \
		--dataset cohere1m --ext $(EXT) \
		$(FLAT) $(DISKANN_R48)

# Standard benchmarks
bench-10k:
	$(BENCH) --subset-size 10000 --delete-pct $(DELETE_PCTS) -k 10 -n 50 \
		--dataset cohere1m --ext $(EXT) $(ALL_CONFIGS)

bench-50k:
	$(BENCH) --subset-size 50000 --delete-pct $(DELETE_PCTS) -k 10 -n 50 \
		--dataset cohere1m --ext $(EXT) $(ALL_CONFIGS)

bench-all: bench-10k bench-50k

# Query saved results (prints an example command; does not run it)
report:
	@echo "Query results:"
	@echo " sqlite3 runs/cohere1m/10000/delete_results.db \\"
	@echo " \"SELECT config_name, delete_pct, recall, query_mean_ms, vacuum_size_mb FROM delete_runs ORDER BY config_name, delete_pct\""

clean:
	rm -rf runs/
|
||||
69
benchmarks-ann/bench-delete/README.md
Normal file
69
benchmarks-ann/bench-delete/README.md
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# bench-delete: Recall degradation after random deletion
|
||||
|
||||
Measures how KNN recall changes after deleting a random percentage of rows
|
||||
from different index types (flat, rescore, DiskANN).
|
||||
|
||||
## Quick start
|
||||
|
||||
```bash
|
||||
# Ensure dataset exists
|
||||
make -C ../datasets/cohere1m
|
||||
|
||||
# Ensure extension is built
|
||||
make -C ../.. loadable
|
||||
|
||||
# Quick smoke test
|
||||
make smoke
|
||||
|
||||
# Full benchmark at 10k vectors
|
||||
make bench-10k
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
python bench_delete.py --subset-size 10000 --delete-pct 10,25,50,75 \
|
||||
"flat:type=vec0-flat,variant=float" \
|
||||
"diskann-R72:type=diskann,R=72,L=128,quantizer=binary" \
|
||||
"rescore-bit:type=rescore,quantizer=bit,oversample=8"
|
||||
```
|
||||
|
||||
## What it measures
|
||||
|
||||
For each config and delete percentage:
|
||||
|
||||
| Metric | Description |
|
||||
|--------|-------------|
|
||||
| **recall** | KNN recall@k after deletion (ground truth recomputed over surviving rows) |
|
||||
| **delta** | Recall change vs 0% baseline |
|
||||
| **query latency** | Mean/median query time after deletion |
|
||||
| **db_size_mb** | DB file size before VACUUM |
|
||||
| **vacuum_size_mb** | DB file size after VACUUM (space reclaimed) |
|
||||
| **delete_time_s** | Wall time for the DELETE operations |
|
||||
|
||||
## How it works
|
||||
|
||||
1. Build index with N vectors (one copy per config)
|
||||
2. Measure recall at k=10 (pre-delete baseline)
|
||||
3. For each delete %:
|
||||
- Copy the master DB
|
||||
- Delete a random selection of rows (deterministic seed)
|
||||
- Measure recall (ground truth recomputed over surviving rows only)
|
||||
- VACUUM and measure size savings
|
||||
4. Print comparison table
|
||||
|
||||
## Expected behavior
|
||||
|
||||
- **Flat index**: Recall should be 1.0 at all delete percentages (brute-force is always exact)
|
||||
- **Rescore**: Recall should stay close to baseline (quantized scan + rescore is robust)
|
||||
- **DiskANN**: Recall may degrade at high delete % due to graph fragmentation (dangling edges, broken connectivity)
|
||||
|
||||
## Results DB
|
||||
|
||||
Results are stored in `runs/<dataset>/<subset_size>/delete_results.db`:
|
||||
|
||||
```sql
|
||||
SELECT config_name, delete_pct, recall, vacuum_size_mb
|
||||
FROM delete_runs
|
||||
ORDER BY config_name, delete_pct;
|
||||
```
|
||||
593
benchmarks-ann/bench-delete/bench_delete.py
Normal file
593
benchmarks-ann/bench-delete/bench_delete.py
Normal file
|
|
@ -0,0 +1,593 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Benchmark: measure recall degradation after random row deletion.
|
||||
|
||||
Given a dataset and index config, this script:
|
||||
1. Builds the index (flat + ANN)
|
||||
2. Measures recall at k=10 (pre-delete baseline)
|
||||
3. Deletes a random % of rows
|
||||
4. Measures recall again (post-delete)
|
||||
5. Records DB size before/after deletion, recall delta, timings
|
||||
|
||||
Usage:
|
||||
python bench_delete.py --subset-size 10000 --delete-pct 25 \
|
||||
"diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
|
||||
|
||||
# Multiple delete percentages in one run:
|
||||
python bench_delete.py --subset-size 10000 --delete-pct 10,25,50,75 \
|
||||
"diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import sqlite3
|
||||
import statistics
|
||||
import struct
|
||||
import time
|
||||
|
||||
# Directory layout, resolved relative to this script:
#   _SCRIPT_DIR - directory containing this file
#   _BENCH_DIR  - its parent directory
#   _ROOT_DIR   - the parent of _BENCH_DIR
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_BENCH_DIR = os.path.join(_SCRIPT_DIR, "..")
_ROOT_DIR = os.path.join(_BENCH_DIR, "..")

# Default extension path: <_ROOT_DIR>/dist/vec0 (no file suffix is given here).
EXT_PATH = os.path.join(_ROOT_DIR, "dist", "vec0")
# Datasets are looked up under <_BENCH_DIR>/datasets/<name>/base.db.
DATASETS_DIR = os.path.join(_BENCH_DIR, "datasets")

# Known datasets: path of each base.db and the vector dimensionality used when
# declaring the vec0 embedding column.
DATASETS = {
    "cohere1m": {"base_db": os.path.join(DATASETS_DIR, "cohere1m", "base.db"), "dimensions": 768},
    "cohere10m": {"base_db": os.path.join(DATASETS_DIR, "cohere10m", "base.db"), "dimensions": 768},
    "nyt": {"base_db": os.path.join(DATASETS_DIR, "nyt", "base.db"), "dimensions": 256},
    "nyt-768": {"base_db": os.path.join(DATASETS_DIR, "nyt-768", "base.db"), "dimensions": 768},
    "nyt-1024": {"base_db": os.path.join(DATASETS_DIR, "nyt-1024", "base.db"), "dimensions": 1024},
    "nyt-384": {"base_db": os.path.join(DATASETS_DIR, "nyt-384", "base.db"), "dimensions": 384},
}

# Width of the id range covered by each batched INSERT in insert_loop().
INSERT_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Timing helpers
|
||||
# ============================================================================
|
||||
|
||||
def now_ns():
    """Return the current time as integer nanoseconds, read from time.time_ns()."""
    # NOTE(review): time.time_ns() is the wall clock, so durations computed by
    # subtracting two readings can be skewed by system clock adjustments.
    # time.perf_counter_ns() may be a better fit for interval timing - confirm
    # that no caller relies on this being an epoch timestamp before switching.
    return time.time_ns()
|
||||
|
||||
def ns_to_s(ns):
    """Convert a nanosecond count to seconds (true division, float result)."""
    nanoseconds_per_second = 10**9
    return ns / nanoseconds_per_second
|
||||
|
||||
def ns_to_ms(ns):
    """Convert a nanosecond count to milliseconds (true division, float result)."""
    nanoseconds_per_millisecond = 10**6
    return ns / nanoseconds_per_millisecond
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Index registry (subset of bench.py — only types relevant to deletion)
|
||||
# ============================================================================
|
||||
|
||||
def _vec0_flat_create(p):
|
||||
dims = p["dimensions"]
|
||||
variant = p.get("variant", "float")
|
||||
col = f"embedding float[{dims}]"
|
||||
if variant == "int8":
|
||||
col = f"embedding int8[{dims}]"
|
||||
elif variant == "bit":
|
||||
col = f"embedding bit[{dims}]"
|
||||
return f"CREATE VIRTUAL TABLE vec_items USING vec0(id INTEGER PRIMARY KEY, {col})"
|
||||
|
||||
def _rescore_create(p):
|
||||
dims = p["dimensions"]
|
||||
q = p.get("quantizer", "bit")
|
||||
os_val = p.get("oversample", 8)
|
||||
return (
|
||||
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||
f"id INTEGER PRIMARY KEY, "
|
||||
f"embedding float[{dims}] indexed by rescore(quantizer={q}, oversample={os_val}))"
|
||||
)
|
||||
|
||||
def _diskann_create(p):
|
||||
dims = p["dimensions"]
|
||||
R = p.get("R", 72)
|
||||
L = p.get("L", 128)
|
||||
q = p.get("quantizer", "binary")
|
||||
bt = p.get("buffer_threshold", 0)
|
||||
sl_insert = p.get("search_list_size_insert", 0)
|
||||
sl_search = p.get("search_list_size_search", 0)
|
||||
parts = [
|
||||
f"neighbor_quantizer={q}",
|
||||
f"n_neighbors={R}",
|
||||
f"buffer_threshold={bt}",
|
||||
]
|
||||
if sl_insert or sl_search:
|
||||
# Per-path overrides — don't also set search_list_size
|
||||
if sl_insert:
|
||||
parts.append(f"search_list_size_insert={sl_insert}")
|
||||
if sl_search:
|
||||
parts.append(f"search_list_size_search={sl_search}")
|
||||
else:
|
||||
parts.append(f"search_list_size={L}")
|
||||
opts = ", ".join(parts)
|
||||
return (
|
||||
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||
f"id INTEGER PRIMARY KEY, "
|
||||
f"embedding float[{dims}] indexed by diskann({opts}))"
|
||||
)
|
||||
|
||||
def _ivf_create(p):
|
||||
dims = p["dimensions"]
|
||||
nlist = p.get("nlist", 128)
|
||||
nprobe = p.get("nprobe", 16)
|
||||
q = p.get("quantizer", "none")
|
||||
os_val = p.get("oversample", 1)
|
||||
parts = [f"nlist={nlist}", f"nprobe={nprobe}"]
|
||||
if q != "none":
|
||||
parts.append(f"quantizer={q}")
|
||||
if os_val > 1:
|
||||
parts.append(f"oversample={os_val}")
|
||||
opts = ", ".join(parts)
|
||||
return (
|
||||
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||
f"id INTEGER PRIMARY KEY, "
|
||||
f"embedding float[{dims}] indexed by ivf({opts}))"
|
||||
)
|
||||
|
||||
|
||||
# Maps an index type name (the ``type=`` value of a config spec) to:
#   defaults          - parameter values merged under the spec's own values
#   create_table_sql  - callable(params) returning the CREATE VIRTUAL TABLE SQL
#   post_insert_hook  - callable(conn, params) run after the inserts, or None
INDEX_REGISTRY = {
    "vec0-flat": {
        "defaults": {"variant": "float"},
        "create_table_sql": _vec0_flat_create,
        "post_insert_hook": None,
    },
    "rescore": {
        "defaults": {"quantizer": "bit", "oversample": 8},
        "create_table_sql": _rescore_create,
        "post_insert_hook": None,
    },
    "ivf": {
        "defaults": {"nlist": 128, "nprobe": 16, "quantizer": "none",
                     "oversample": 1},
        "create_table_sql": _ivf_create,
        # The lambda looks up _ivf_train at call time, so it may be defined
        # further down the module.
        "post_insert_hook": lambda conn, params: _ivf_train(conn),
    },
    "diskann": {
        "defaults": {"R": 72, "L": 128, "quantizer": "binary",
                     "buffer_threshold": 0},
        "create_table_sql": _diskann_create,
        "post_insert_hook": None,
    },
}
|
||||
|
||||
|
||||
def _ivf_train(conn):
    """Issue the 'compute-centroids' command on vec_items and commit.

    Returns the elapsed time in seconds, covering both the command and the
    commit.
    """
    started_ns = now_ns()
    conn.execute("INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')")
    conn.commit()
    elapsed_ns = now_ns() - started_ns
    return ns_to_s(elapsed_ns)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config parsing (same format as bench.py)
|
||||
# ============================================================================
|
||||
|
||||
# Config keys whose values are converted to int while parsing.
INT_KEYS = {"R", "L", "oversample", "nlist", "nprobe", "buffer_threshold",
            "search_list_size_insert", "search_list_size_search"}


def parse_config(spec):
    """Parse a ``name:type=<index_type>,key=val,...`` benchmark config spec.

    Returns ``(name, params)``. ``params`` starts from the registry defaults
    for the index type, is overridden by the values given in the spec, and
    carries the type under the "index_type" key. Values of keys listed in
    INT_KEYS are converted to int; all other values stay strings.

    Empty segments (for example a trailing comma) are ignored.

    Raises:
        ValueError: if the spec has no ``:``, an item is not ``key=val``, an
            integer key has a non-integer value, or the index type is missing
            or not present in INDEX_REGISTRY.
    """
    if ":" not in spec:
        raise ValueError(f"Config must be 'name:key=val,...': {spec}")
    name, rest = spec.split(":", 1)

    params = {}
    for item in rest.split(","):
        if not item.strip():
            # Tolerate "a=1,,b=2" and trailing commas instead of failing on
            # tuple unpacking.
            continue
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"Config item must be 'key=val', got {item!r} in: {spec}")
        key = key.strip()
        value = value.strip()
        if key in INT_KEYS:
            try:
                value = int(value)
            except ValueError:
                raise ValueError(
                    f"Config key '{key}' needs an integer, got {value!r} in: {spec}"
                ) from None
        params[key] = value

    index_type = params.pop("type", None)
    if not index_type or index_type not in INDEX_REGISTRY:
        raise ValueError(f"Unknown index type: {index_type}")
    params["index_type"] = index_type

    # Copy the defaults so the registry entry itself is never mutated.
    merged = dict(INDEX_REGISTRY[index_type]["defaults"])
    merged.update(params)
    return name, merged
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DB helpers
|
||||
# ============================================================================
|
||||
|
||||
def create_bench_db(db_path, ext_path, base_db, page_size=4096):
    """Create a fresh benchmark database and return an open connection to it.

    Any existing file at ``db_path`` is deleted together with its ``-wal`` and
    ``-shm`` sidecar files, so a WAL left behind by an interrupted run cannot
    be applied to the new database. The connection is then configured with
    the requested page size and WAL journaling, the extension at ``ext_path``
    is loaded, and ``base_db`` is attached under the schema name ``base``.

    If any setup step fails, the connection is closed before the exception
    propagates.
    """
    for suffix in ("", "-wal", "-shm"):
        stale_path = db_path + suffix
        if os.path.exists(stale_path):
            os.remove(stale_path)

    conn = sqlite3.connect(db_path)
    try:
        conn.execute(f"PRAGMA page_size={page_size}")
        conn.execute("PRAGMA journal_mode=WAL")
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)
        # Bind the path rather than splicing it into the SQL text, so paths
        # containing a single quote still attach correctly.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))
    except Exception:
        conn.close()
        raise
    return conn
|
||||
|
||||
|
||||
def load_query_vectors(base_db, n):
    """Return up to ``n`` ``(id, vector)`` rows from ``query_vectors`` in ``base_db``.

    Rows are ordered by id so that repeated runs evaluate the same queries.
    The connection is closed even if the query raises.
    """
    conn = sqlite3.connect(base_db)
    try:
        return conn.execute(
            "SELECT id, vector FROM query_vectors ORDER BY id LIMIT ?", (n,)
        ).fetchall()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def insert_loop(conn, subset_size, label, start_from=0):
    """Copy rows with ids in [start_from, subset_size) from base.train into vec_items.

    The id range is walked in steps of INSERT_BATCH_SIZE, committing after
    each batch. A progress line tagged with ``label`` is printed whenever the
    running count is a multiple of 5000 and once the whole range is covered.
    """
    statement = (
        "INSERT INTO vec_items(id, embedding) "
        "SELECT id, vector FROM base.train "
        "WHERE id >= :lo AND id < :hi"
    )
    expected = subset_size - start_from
    done = 0
    batch_lo = start_from
    while batch_lo < subset_size:
        batch_hi = min(batch_lo + INSERT_BATCH_SIZE, subset_size)
        conn.execute(statement, {"lo": batch_lo, "hi": batch_hi})
        conn.commit()
        # Counts the width of the id range, not the rows actually inserted.
        done += batch_hi - batch_lo
        if done == expected or done % 5000 == 0:
            print(f" [{label}] inserted {done + start_from}/{subset_size}", flush=True)
        batch_lo += INSERT_BATCH_SIZE
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Recall measurement
|
||||
# ============================================================================
|
||||
|
||||
def measure_recall(conn, base_db, query_vectors, subset_size, k, alive_ids=None):
    """Measure KNN recall and query latency of vec_items against brute force.

    For every ``(id, vector)`` pair in ``query_vectors`` the KNN query is run
    against ``vec_items`` and timed. The ground truth is the ``k`` rows of
    ``base.train`` with ``id < subset_size`` that are nearest by
    ``vec_distance_l2``.

    If ``alive_ids`` is given, ground truth is restricted to those ids (the
    post-delete state). The candidate window is widened until ``k`` surviving
    neighbours are found or every candidate row has been examined, so the
    ground truth is not cut short at high delete percentages.

    ``base_db`` is not used; it is kept so existing callers keep working.

    Returns a dict with "recall" (mean, 4 decimals), "mean_ms" and "median_ms"
    (2 decimals). All three are 0.0 when ``query_vectors`` is empty.
    """
    # Build the membership set once so each lookup is O(1) whatever container
    # the caller passed.
    alive = None if alive_ids is None else set(alive_ids)

    knn_sql = (
        "SELECT id, distance FROM vec_items "
        "WHERE embedding MATCH :query AND k = :k"
    )
    gt_sql = (
        "SELECT id FROM ("
        "  SELECT id, vec_distance_l2(vector, :query) as dist "
        "  FROM base.train WHERE id < :n ORDER BY dist LIMIT :limit"
        ")"
    )

    # Initial ground-truth window. After deletion, size it from the surviving
    # fraction so a single scan is usually enough; the loop below still
    # guarantees correctness when it is not.
    if alive is None:
        first_limit = k
    else:
        first_limit = max(k * 5, (2 * k * subset_size) // max(len(alive), 1))

    recalls = []
    times_ms = []

    for _qid, query in query_vectors:
        t0 = now_ns()
        results = conn.execute(knn_sql, {"query": query, "k": k}).fetchall()
        t1 = now_ns()
        times_ms.append(ns_to_ms(t1 - t0))
        result_ids = {row[0] for row in results}

        limit = first_limit
        while True:
            gt_rows = conn.execute(
                gt_sql, {"query": query, "n": subset_size, "limit": limit}
            ).fetchall()
            if alive is None:
                gt = [row[0] for row in gt_rows]
                break
            gt = [row[0] for row in gt_rows if row[0] in alive][:k]
            # Stop once k survivors are found, or the scan returned fewer rows
            # than requested (no further candidates exist).
            if len(gt) >= k or len(gt_rows) < limit:
                break
            limit *= 4
        gt_ids = set(gt)

        if gt_ids:
            recalls.append(len(result_ids & gt_ids) / len(gt_ids))
        else:
            recalls.append(0.0)

    return {
        "recall": round(statistics.mean(recalls), 4) if recalls else 0.0,
        "mean_ms": round(statistics.mean(times_ms), 2) if times_ms else 0.0,
        "median_ms": round(statistics.median(times_ms), 2) if times_ms else 0.0,
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Delete benchmark core
|
||||
# ============================================================================
|
||||
|
||||
def _open_with_base(db_path, ext_path, base_db):
    """Open db_path, load the extension at ext_path, and attach base_db as `base`.

    The connection is closed again if any setup step fails, so a failed
    load or attach does not leak an open handle.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)
        # Bind the filename rather than interpolating it into the SQL text,
        # so a path containing a single quote cannot break the statement.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))
    except Exception:
        conn.close()
        raise
    return conn


def _copy_db_with_sidecars(src_path, dst_path):
    """Copy a SQLite file plus its -wal/-shm sidecars, dropping stale ones.

    A sidecar left next to dst_path by an earlier run must not survive
    beside the fresh copy when the source has no matching sidecar.
    """
    shutil.copy2(src_path, dst_path)
    for suffix in ("-wal", "-shm"):
        src = src_path + suffix
        dst = dst_path + suffix
        if os.path.exists(src):
            shutil.copy2(src, dst)
        elif os.path.exists(dst):
            os.remove(dst)


def run_delete_benchmark(name, params, base_db, ext_path, subset_size, dims,
                         delete_pcts, k, n_queries, out_dir, seed_val):
    """Build one index, then measure recall/size after deleting random rows.

    The index is built once into a master database. A baseline (0% deleted)
    measurement is taken on it, and for every percentage in delete_pcts the
    master is copied, that share of ids in range(subset_size) is deleted
    from `vec_items`, recall is re-measured against the surviving ids, and
    the file size is recorded before and after VACUUM.

    Args:
        name: config name; used in output file names and result rows.
        params: config dict; must contain "index_type". Its "dimensions"
            entry is overwritten with dims (the caller's dict is mutated).
        base_db: path of the dataset database attached as `base`.
        ext_path: path of the loadable extension.
        subset_size: number of vectors to insert.
        dims: vector dimensionality written into params.
        delete_pcts: iterable of integer percentages to delete.
        k: neighbour count passed to measure_recall.
        n_queries: number of query vectors to load.
        out_dir: directory for the master and per-percentage databases.
        seed_val: base seed; each percentage uses seed_val + pct.

    Returns:
        A list of result dicts: the baseline first, then one per
        percentage in ascending order.
    """
    params["dimensions"] = dims
    reg = INDEX_REGISTRY[params["index_type"]]
    create_sql = reg["create_table_sql"](params)

    results = []
    mib = 1024 * 1024

    # Build once, copy for each delete %
    print(f"\n{'='*60}")
    print(f"Config: {name} (type={params['index_type']})")
    print(f"{'='*60}")

    os.makedirs(out_dir, exist_ok=True)
    master_db_path = os.path.join(out_dir, f"{name}.{subset_size}.db")
    print(f" Building index ({subset_size} vectors)...")
    build_t0 = now_ns()
    conn = create_bench_db(master_db_path, ext_path, base_db)
    try:
        conn.execute(create_sql)
        insert_loop(conn, subset_size, name)
        hook = reg.get("post_insert_hook")
        if hook:
            print(" Training...")
            hook(conn, params)
    finally:
        conn.close()
    build_time_s = ns_to_s(now_ns() - build_t0)
    master_size = os.path.getsize(master_db_path)
    print(f" Built in {build_time_s:.1f}s ({master_size / (1024*1024):.1f} MB)")

    # Load query vectors once
    query_vectors = load_query_vectors(base_db, n_queries)

    # Measure pre-delete baseline on the master copy
    print(f"\n --- 0% deleted (baseline) ---")
    conn = _open_with_base(master_db_path, ext_path, base_db)
    try:
        baseline = measure_recall(conn, base_db, query_vectors, subset_size, k)
    finally:
        conn.close()
    print(f" recall={baseline['recall']:.4f} "
          f"query={baseline['mean_ms']:.2f}ms")

    results.append({
        "name": name,
        "index_type": params["index_type"],
        "subset_size": subset_size,
        "delete_pct": 0,
        "n_deleted": 0,
        "n_remaining": subset_size,
        "recall": baseline["recall"],
        "query_mean_ms": baseline["mean_ms"],
        "query_median_ms": baseline["median_ms"],
        "db_size_mb": round(master_size / mib, 2),
        "build_time_s": round(build_time_s, 1),
        "delete_time_s": 0.0,
        "vacuum_size_mb": round(master_size / mib, 2),
    })

    # All IDs in the dataset
    all_ids = list(range(subset_size))

    for pct in sorted(delete_pcts):
        n_delete = int(subset_size * pct / 100)
        print(f"\n --- {pct}% deleted ({n_delete} rows) ---")

        # Copy master DB (and WAL/SHM if present) and work on the copy
        copy_path = os.path.join(out_dir, f"{name}.{subset_size}.del{pct}.db")
        _copy_db_with_sidecars(master_db_path, copy_path)

        # Pick random IDs to delete (deterministic per pct)
        rng = random.Random(seed_val + pct)
        to_delete = set(rng.sample(all_ids, n_delete))
        alive_ids = set(all_ids) - to_delete

        conn = _open_with_base(copy_path, ext_path, base_db)
        try:
            # Delete in batches of up to 500 ids, committing after each batch
            delete_t0 = now_ns()
            doomed = list(to_delete)
            for start in range(0, len(doomed), 500):
                batch = doomed[start:start + 500]
                placeholders = ",".join("?" for _ in batch)
                conn.execute(
                    f"DELETE FROM vec_items WHERE id IN ({placeholders})",
                    batch,
                )
                conn.commit()
            delete_time_s = ns_to_s(now_ns() - delete_t0)

            remaining = conn.execute("SELECT count(*) FROM vec_items").fetchone()[0]
            pre_vacuum_size = os.path.getsize(copy_path)
            print(f" deleted {n_delete} rows in {delete_time_s:.2f}s "
                  f"({remaining} remaining)")

            # Measure post-delete recall
            post = measure_recall(conn, base_db, query_vectors, subset_size, k,
                                  alive_ids=alive_ids)
            print(f" recall={post['recall']:.4f} "
                  f"(delta={post['recall'] - baseline['recall']:+.4f}) "
                  f"query={post['mean_ms']:.2f}ms")
        finally:
            # Close fully so VACUUM below runs on a connection without `base`
            conn.close()

        # VACUUM and measure size savings
        vconn = sqlite3.connect(copy_path)
        try:
            vconn.execute("VACUUM")
        finally:
            vconn.close()
        post_vacuum_size = os.path.getsize(copy_path)
        saved_mb = (pre_vacuum_size - post_vacuum_size) / mib
        print(f" size: {pre_vacuum_size/(1024*1024):.1f} MB -> "
              f"{post_vacuum_size/(1024*1024):.1f} MB after VACUUM "
              f"(saved {saved_mb:.1f} MB)")

        results.append({
            "name": name,
            "index_type": params["index_type"],
            "subset_size": subset_size,
            "delete_pct": pct,
            "n_deleted": n_delete,
            "n_remaining": remaining,
            "recall": post["recall"],
            "query_mean_ms": post["mean_ms"],
            "query_median_ms": post["median_ms"],
            "db_size_mb": round(pre_vacuum_size / mib, 2),
            "build_time_s": round(build_time_s, 1),
            "delete_time_s": round(delete_time_s, 2),
            "vacuum_size_mb": round(post_vacuum_size / mib, 2),
        })

    return results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Results DB
|
||||
# ============================================================================
|
||||
|
||||
# Schema of the results database: one row per (config, delete percentage) run.
RESULTS_SCHEMA = """\
CREATE TABLE IF NOT EXISTS delete_runs (
    run_id INTEGER PRIMARY KEY,
    config_name TEXT NOT NULL,
    index_type TEXT NOT NULL,
    params TEXT,
    dataset TEXT NOT NULL,
    subset_size INTEGER NOT NULL,
    delete_pct INTEGER NOT NULL,
    n_deleted INTEGER NOT NULL,
    n_remaining INTEGER NOT NULL,
    k INTEGER NOT NULL,
    n_queries INTEGER NOT NULL,
    seed INTEGER NOT NULL,
    recall REAL,
    query_mean_ms REAL,
    query_median_ms REAL,
    db_size_mb REAL,
    vacuum_size_mb REAL,
    build_time_s REAL,
    delete_time_s REAL,
    created_at TEXT DEFAULT (datetime('now'))
);
"""


def save_results(results, out_dir, dataset, subset_size, params_json, k, n_queries, seed_val):
    """Append benchmark result rows to `delete_results.db` inside out_dir.

    Args:
        results: iterable of result dicts; each must provide the keys read
            below ("name", "index_type", "subset_size", "delete_pct", ...).
        out_dir: directory holding (or receiving) delete_results.db.
        dataset: dataset name stored with every row.
        subset_size: accepted for interface compatibility; the stored value
            is taken from each result dict's "subset_size".
        params_json: serialized config parameters stored with every row.
        k, n_queries, seed_val: run settings stored with every row.

    Returns:
        The path of the results database.
    """
    db_path = os.path.join(out_dir, "delete_results.db")
    db = sqlite3.connect(db_path)
    try:
        db.execute("PRAGMA journal_mode=WAL")
        db.executescript(RESULTS_SCHEMA)
        for r in results:
            db.execute(
                "INSERT INTO delete_runs "
                "(config_name, index_type, params, dataset, subset_size, "
                " delete_pct, n_deleted, n_remaining, k, n_queries, seed, "
                " recall, query_mean_ms, query_median_ms, "
                " db_size_mb, vacuum_size_mb, build_time_s, delete_time_s) "
                "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
                (
                    r["name"], r["index_type"], params_json, dataset, r["subset_size"],
                    r["delete_pct"], r["n_deleted"], r["n_remaining"], k, n_queries, seed_val,
                    r["recall"], r["query_mean_ms"], r["query_median_ms"],
                    r["db_size_mb"], r["vacuum_size_mb"], r["build_time_s"], r["delete_time_s"],
                ),
            )
        db.commit()
    finally:
        # Release the handle even when an insert fails (e.g. a missing key).
        db.close()
    return db_path
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Reporting
|
||||
# ============================================================================
|
||||
|
||||
def print_report(all_results):
    """Print a fixed-width table of results, grouped by config name.

    Rows of one config are printed together, followed by a blank line. The
    recall delta is taken against the first row of each group (the 0%
    baseline is expected to come first) and shown as "-" for rows whose
    delete percentage is zero.
    """
    header = (
        f"\n{'name':>22} {'del%':>5} {'deleted':>8} {'remain':>8} "
        f"{'recall':>7} {'delta':>7} {'qry(ms)':>8} "
        f"{'size(MB)':>9} {'vacuumed':>9} {'del(s)':>7}"
    )
    print(header)
    print("-" * 110)

    # Bucket rows per config, keeping first-seen order of both configs and rows
    by_config = {}
    for entry in all_results:
        bucket = by_config.get(entry["name"])
        if bucket is None:
            bucket = by_config[entry["name"]] = []
        bucket.append(entry)

    for group in by_config.values():
        reference = group[0]["recall"]
        for entry in group:
            change = entry["recall"] - reference
            if entry["delete_pct"] > 0:
                delta_str = f"{change:+.4f}"
            else:
                delta_str = "-"
            line = (
                f"{entry['name']:>22} {entry['delete_pct']:>4}% "
                f"{entry['n_deleted']:>8} {entry['n_remaining']:>8} "
                f"{entry['recall']:>7.4f} {delta_str:>7} {entry['query_mean_ms']:>8.2f} "
                f"{entry['db_size_mb']:>9.1f} {entry['vacuum_size_mb']:>9.1f} "
                f"{entry['delete_time_s']:>7.2f}"
            )
            print(line)
        print()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main
|
||||
# ============================================================================
|
||||
|
||||
def main():
    """Parse CLI arguments, run the delete benchmark per config, and report.

    Each config spec is benchmarked in turn and its rows are written to the
    results database right after that config finishes, so the stored
    `params` always belong to the rows they are saved with.

    Returns:
        0 on success, 1 when the dataset is missing or --delete-pct is
        malformed or out of range.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark recall degradation after random row deletion",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("configs", nargs="+",
                        help="config specs (name:type=X,key=val,...)")
    parser.add_argument("--subset-size", type=int, default=10000,
                        help="number of vectors to build (default: 10000)")
    parser.add_argument("--delete-pct", type=str, default="10,25,50",
                        help="comma-separated delete percentages (default: 10,25,50)")
    parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
    parser.add_argument("-n", type=int, default=50,
                        help="number of queries (default 50)")
    parser.add_argument("--dataset", default="cohere1m",
                        choices=list(DATASETS.keys()))
    parser.add_argument("--ext", default=EXT_PATH)
    parser.add_argument("-o", "--out-dir",
                        default=os.path.join(_SCRIPT_DIR, "runs"))
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed for delete selection (default: 42)")
    args = parser.parse_args()

    ds = DATASETS[args.dataset]
    base_db = ds["base_db"]
    dims = ds["dimensions"]
    if not os.path.exists(base_db):
        print(f"Error: dataset not found at {base_db}")
        print(f"Run: make -C {os.path.dirname(base_db)}")
        return 1

    # Report a malformed list as a normal CLI error instead of a traceback.
    try:
        delete_pcts = [int(x.strip()) for x in args.delete_pct.split(",")]
    except ValueError:
        print("Error: --delete-pct must be comma-separated integers, "
              f"got {args.delete_pct!r}")
        return 1
    for p in delete_pcts:
        if not 0 < p < 100:
            print(f"Error: delete percentage must be 1-99, got {p}")
            return 1
    # A repeated percentage would redo the same copy and store duplicate rows.
    delete_pcts = sorted(set(delete_pcts))

    out_dir = os.path.join(args.out_dir, args.dataset, str(args.subset_size))
    os.makedirs(out_dir, exist_ok=True)

    all_results = []
    for spec in args.configs:
        name, params = parse_config(spec)
        params_json = json.dumps(params)
        results = run_delete_benchmark(
            name, params, base_db, args.ext, args.subset_size, dims,
            delete_pcts, args.k, args.n, out_dir, args.seed,
        )
        all_results.extend(results)

        # Persist per config: `results` and `params_json` are per-iteration.
        save_results(results, out_dir, args.dataset, args.subset_size,
                     params_json, args.k, args.n, args.seed)

    print_report(all_results)

    results_path = os.path.join(out_dir, "delete_results.db")
    print(f"\nResults saved to: {results_path}")
    print(f"Query: sqlite3 {results_path} "
          f"\"SELECT config_name, delete_pct, recall, vacuum_size_mb "
          f"FROM delete_runs ORDER BY config_name, delete_pct\"")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
124
benchmarks-ann/bench-delete/test_smoke.py
Normal file
124
benchmarks-ann/bench-delete/test_smoke.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Quick self-contained smoke test using a synthetic dataset.
|
||||
Creates a tiny base.db in a temp dir, runs the delete benchmark, verifies output.
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
import sqlite3
|
||||
import struct
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
# Paths are resolved relative to this file so the test can run from any cwd.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ROOT_DIR = os.path.join(_SCRIPT_DIR, os.pardir, os.pardir)
EXT_PATH = os.path.join(_ROOT_DIR, "dist", "vec0")

# Size of the synthetic dataset and of the KNN queries run against it.
DIMS, N_TRAIN, N_QUERIES, K_NEIGHBORS = 8, 200, 10, 5
|
||||
|
||||
|
||||
def _f32(vals):
|
||||
return struct.pack(f"{len(vals)}f", *vals)
|
||||
|
||||
|
||||
def make_synthetic_base_db(path):
    """Write a small base.db at path holding random train and query vectors.

    Creates the `train` table with N_TRAIN rows and the `query_vectors`
    table with N_QUERIES rows. Every vector has DIMS Gaussian components
    drawn from a generator seeded with 123, so the file is reproducible.
    """
    gen = random.Random(123)

    def next_blob():
        # One vector's worth of draws, packed for storage
        return _f32([gen.gauss(0, 1) for _ in range(DIMS)])

    con = sqlite3.connect(path)
    con.execute("CREATE TABLE train(id INTEGER PRIMARY KEY, vector BLOB)")
    con.execute("CREATE TABLE query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")

    # Train rows are generated before query rows, preserving the draw order
    for table, count in (("train", N_TRAIN), ("query_vectors", N_QUERIES)):
        for row_id in range(count):
            con.execute(f"INSERT INTO {table} VALUES (?, ?)", (row_id, next_blob()))

    con.commit()
    con.close()
|
||||
|
||||
|
||||
def _run_config(bench_delete, spec, base_db, out_dir):
    """Run the delete benchmark for one config spec and print its report."""
    name, params = bench_delete.parse_config(spec)
    params["dimensions"] = DIMS
    results = bench_delete.run_delete_benchmark(
        name, params, base_db, EXT_PATH,
        subset_size=N_TRAIN, dims=DIMS,
        delete_pcts=[25, 50], k=K_NEIGHBORS, n_queries=N_QUERIES,
        out_dir=out_dir, seed_val=42,
    )
    bench_delete.print_report(results)
    return results


def main():
    """Smoke-test the delete benchmark on a synthetic dataset.

    Builds a tiny base.db in a temporary directory, then runs the flat,
    DiskANN and rescore configs through bench_delete. The flat index must
    report recall 1.0 at every delete percentage and the DiskANN baseline
    must be non-zero; the rescore run only has to complete.

    EXT_PATH is passed without a suffix; no existence check is done here
    because the extension loader is left to resolve the platform suffix.

    Returns:
        0 when every check passes (failures raise AssertionError).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        base_db = os.path.join(tmpdir, "base.db")
        make_synthetic_base_db(base_db)

        # Patch DATASETS to use our synthetic DB
        import bench_delete
        bench_delete.DATASETS["synthetic"] = {
            "base_db": base_db,
            "dimensions": DIMS,
        }

        out_dir = os.path.join(tmpdir, "runs")

        # Test flat index
        print("=== Testing flat index ===")
        results = _run_config(
            bench_delete, "flat:type=vec0-flat,variant=float", base_db, out_dir)

        # Flat recall should be 1.0 at all delete %
        for r in results:
            assert r["recall"] == 1.0, \
                f"Flat recall should be 1.0, got {r['recall']} at {r['delete_pct']}%"
        print("\n PASS: flat recall is 1.0 at all delete percentages\n")

        # Test DiskANN
        print("=== Testing DiskANN ===")
        results2 = _run_config(
            bench_delete, "diskann:type=diskann,R=8,L=32,quantizer=binary",
            base_db, out_dir)

        # DiskANN baseline (0%) should have decent recall
        baseline = results2[0]
        assert baseline["recall"] > 0.0, \
            "DiskANN baseline recall is zero"
        print(f" PASS: DiskANN baseline recall={baseline['recall']}")

        # Test rescore
        print("\n=== Testing rescore ===")
        results3 = _run_config(
            bench_delete, "rescore:type=rescore,quantizer=bit,oversample=4",
            base_db, out_dir)
        print(f" PASS: rescore baseline recall={results3[0]['recall']}")

        print("\n ALL SMOKE TESTS PASSED")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
1350
benchmarks-ann/bench.py
Normal file
1350
benchmarks-ann/bench.py
Normal file
File diff suppressed because it is too large
Load diff
27
benchmarks-ann/datasets/cohere10m/Makefile
Normal file
27
benchmarks-ann/datasets/cohere10m/Makefile
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
BASE_URL = https://assets.zilliz.com/benchmark/cohere_large_10m

# Simple (:=) assignment so printf runs once at parse time instead of on
# every expansion of TRAIN_PARQUETS / PARQUETS.
TRAIN_PARQUETS := $(shell printf 'train-%02d-of-10.parquet ' 0 1 2 3 4 5 6 7 8 9)
OTHER_PARQUETS = test.parquet neighbors.parquet
PARQUETS = $(TRAIN_PARQUETS) $(OTHER_PARQUETS)

.PHONY: all download clean

# Remove a half-written target when its recipe fails, so an interrupted
# download or database build never looks up to date.
.DELETE_ON_ERROR:

all: base.db

# Use: make -j12 download
download: $(PARQUETS)

# -f makes curl exit non-zero on an HTTP error instead of saving the
# server's error page under the target name.
train-%-of-10.parquet:
	curl -fL -o $@ $(BASE_URL)/$@

test.parquet:
	curl -fL -o $@ $(BASE_URL)/test.parquet

neighbors.parquet:
	curl -fL -o $@ $(BASE_URL)/neighbors.parquet

base.db: $(PARQUETS) build_base_db.py
	uv run --with pandas --with pyarrow python build_base_db.py

clean:
	rm -f base.db
|
||||
134
benchmarks-ann/datasets/cohere10m/build_base_db.py
Normal file
134
benchmarks-ann/datasets/cohere10m/build_base_db.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Build base.db from downloaded parquet files (10M dataset, 10 train shards).
|
||||
|
||||
Reads train-00-of-10.parquet .. train-09-of-10.parquet, test.parquet,
|
||||
neighbors.parquet and creates a SQLite database with tables:
|
||||
train, query_vectors, neighbors.
|
||||
|
||||
Usage:
|
||||
uv run --with pandas --with pyarrow python build_base_db.py
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pandas as pd
|
||||
|
||||
TRAIN_SHARDS = 10
|
||||
|
||||
|
||||
def float_list_to_blob(floats):
    """Serialize a sequence of numbers as consecutive little-endian float32."""
    count = len(floats)
    layout = "<" + str(count) + "f"
    return struct.pack(layout, *floats)
|
||||
|
||||
|
||||
def main():
    """Build base.db next to this script from the downloaded parquet files.

    Reads test.parquet into `query_vectors`, neighbors.parquet into
    `neighbors` (one row per query/rank pair), and the TRAIN_SHARDS train
    shards into `train`, committing every 10000 rows. An existing base.db
    is deleted first. Exits with status 1 if any input file is missing.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    db_path = os.path.join(script_dir, "base.db")

    train_paths = [
        os.path.join(script_dir, f"train-{i:02d}-of-{TRAIN_SHARDS}.parquet")
        for i in range(TRAIN_SHARDS)
    ]
    test_path = os.path.join(script_dir, "test.parquet")
    neighbors_path = os.path.join(script_dir, "neighbors.parquet")

    for p in train_paths + [test_path, neighbors_path]:
        if not os.path.exists(p):
            print(f"ERROR: {p} not found. Run 'make download' first.")
            sys.exit(1)

    if os.path.exists(db_path):
        os.remove(db_path)

    conn = sqlite3.connect(db_path)
    try:
        # page_size has to be set before entering WAL mode: once the
        # database is in WAL mode the page size can no longer be changed.
        conn.execute("PRAGMA page_size=4096")
        conn.execute("PRAGMA journal_mode=WAL")

        # --- query_vectors (from test.parquet) ---
        print("Loading test.parquet (query vectors)...")
        t0 = time.perf_counter()
        df_test = pd.read_parquet(test_path)
        conn.execute(
            "CREATE TABLE query_vectors (id INTEGER PRIMARY KEY, vector BLOB)"
        )
        # Walk the two columns directly; iterrows() builds a Series per row.
        rows = [
            (int(vec_id), float_list_to_blob(emb))
            for vec_id, emb in zip(df_test["id"], df_test["emb"])
        ]
        conn.executemany("INSERT INTO query_vectors (id, vector) VALUES (?, ?)", rows)
        conn.commit()
        print(f" {len(rows)} query vectors in {time.perf_counter() - t0:.1f}s")

        # --- neighbors (from neighbors.parquet) ---
        print("Loading neighbors.parquet...")
        t0 = time.perf_counter()
        df_neighbors = pd.read_parquet(neighbors_path)
        conn.execute(
            "CREATE TABLE neighbors ("
            " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
            " UNIQUE(query_vector_id, rank))"
        )
        rows = []
        for qid, nids in zip(df_neighbors["id"], df_neighbors["neighbors_id"]):
            qid = int(qid)
            # The id list may arrive as a JSON string rather than a sequence
            if isinstance(nids, str):
                nids = json.loads(nids)
            rows.extend((qid, rank, str(int(nid))) for rank, nid in enumerate(nids))
        conn.executemany(
            "INSERT INTO neighbors (query_vector_id, rank, neighbors_id) VALUES (?, ?, ?)",
            rows,
        )
        conn.commit()
        print(f" {len(rows)} neighbor rows in {time.perf_counter() - t0:.1f}s")

        # --- train (from 10 shard parquets) ---
        print(f"Loading {TRAIN_SHARDS} train shards (10M vectors, this will take a while)...")
        conn.execute(
            "CREATE TABLE train (id INTEGER PRIMARY KEY, vector BLOB)"
        )

        global_t0 = time.perf_counter()
        total_inserted = 0
        batch_size = 10000

        for shard_idx, train_path in enumerate(train_paths):
            print(f" Shard {shard_idx + 1}/{TRAIN_SHARDS}: {os.path.basename(train_path)}")
            t0 = time.perf_counter()
            df = pd.read_parquet(train_path)
            shard_len = len(df)

            for start in range(0, shard_len, batch_size):
                chunk = df.iloc[start : start + batch_size]
                rows = [
                    (int(vec_id), float_list_to_blob(emb))
                    for vec_id, emb in zip(chunk["id"], chunk["emb"])
                ]
                conn.executemany("INSERT INTO train (id, vector) VALUES (?, ?)", rows)
                conn.commit()

                total_inserted += len(rows)
                # Progress line roughly every 100000 inserted rows
                if total_inserted % 100000 < batch_size:
                    elapsed = time.perf_counter() - global_t0
                    rate = total_inserted / elapsed if elapsed > 0 else 0
                    print(
                        f" {total_inserted:>10} {elapsed:.0f}s {rate:.0f} rows/s",
                        flush=True,
                    )

            shard_elapsed = time.perf_counter() - t0
            print(f" shard done: {shard_len} rows in {shard_elapsed:.1f}s")

        elapsed = time.perf_counter() - global_t0
        print(f" {total_inserted} train vectors in {elapsed:.1f}s")
    finally:
        conn.close()

    size_mb = os.path.getsize(db_path) / (1024 * 1024)
    print(f"\nDone: {db_path} ({size_mb:.0f} MB)")


if __name__ == "__main__":
    main()
|
||||
2
benchmarks-ann/datasets/cohere1m/.gitignore
vendored
Normal file
2
benchmarks-ann/datasets/cohere1m/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
*.parquet
|
||||
base.db
|
||||
24
benchmarks-ann/datasets/cohere1m/Makefile
Normal file
24
benchmarks-ann/datasets/cohere1m/Makefile
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
BASE_URL = https://assets.zilliz.com/benchmark/cohere_medium_1m

PARQUETS = train.parquet test.parquet neighbors.parquet

# base.db is a real file, so it is deliberately not listed here: a phony
# base.db would be rebuilt from scratch on every `make`.
.PHONY: all download clean

# Remove a half-written target when its recipe fails, so an interrupted
# download or database build never looks up to date.
.DELETE_ON_ERROR:

all: base.db

download: $(PARQUETS)

# -f makes curl exit non-zero on an HTTP error instead of saving the
# server's error page under the target name.
train.parquet:
	curl -fL -o $@ $(BASE_URL)/train.parquet

test.parquet:
	curl -fL -o $@ $(BASE_URL)/test.parquet

neighbors.parquet:
	curl -fL -o $@ $(BASE_URL)/neighbors.parquet

base.db: $(PARQUETS) build_base_db.py
	uv run --with pandas --with pyarrow python build_base_db.py

clean:
	rm -f base.db
|
||||
121
benchmarks-ann/datasets/cohere1m/build_base_db.py
Normal file
121
benchmarks-ann/datasets/cohere1m/build_base_db.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Build base.db from downloaded parquet files.
|
||||
|
||||
Reads train.parquet, test.parquet, neighbors.parquet and creates a SQLite
|
||||
database with tables: train, query_vectors, neighbors.
|
||||
|
||||
Usage:
|
||||
uv run --with pandas --with pyarrow python build_base_db.py
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def float_list_to_blob(floats):
    """Return the values encoded back to back as little-endian 32-bit floats."""
    return struct.pack("<%df" % len(floats), *floats)
|
||||
|
||||
|
||||
def main():
    """Build base.db next to this script from the downloaded parquet files.

    Reads test.parquet into `query_vectors`, neighbors.parquet into
    `neighbors` (one row per query/rank pair), and train.parquet into
    `train`, committing every 10000 rows. An existing base.db is deleted
    first. Exits with status 1 if any input file is missing.
    """
    seed_dir = os.path.dirname(os.path.abspath(__file__))
    db_path = os.path.join(seed_dir, "base.db")

    train_path = os.path.join(seed_dir, "train.parquet")
    test_path = os.path.join(seed_dir, "test.parquet")
    neighbors_path = os.path.join(seed_dir, "neighbors.parquet")

    for p in (train_path, test_path, neighbors_path):
        if not os.path.exists(p):
            print(f"ERROR: {p} not found. Run 'make download' first.")
            sys.exit(1)

    if os.path.exists(db_path):
        os.remove(db_path)

    conn = sqlite3.connect(db_path)
    try:
        # page_size has to be set before entering WAL mode: once the
        # database is in WAL mode the page size can no longer be changed.
        conn.execute("PRAGMA page_size=4096")
        conn.execute("PRAGMA journal_mode=WAL")

        # --- query_vectors (from test.parquet) ---
        print("Loading test.parquet (query vectors)...")
        t0 = time.perf_counter()
        df_test = pd.read_parquet(test_path)
        conn.execute(
            "CREATE TABLE query_vectors (id INTEGER PRIMARY KEY, vector BLOB)"
        )
        # Walk the two columns directly; iterrows() builds a Series per row.
        rows = [
            (int(vec_id), float_list_to_blob(emb))
            for vec_id, emb in zip(df_test["id"], df_test["emb"])
        ]
        conn.executemany("INSERT INTO query_vectors (id, vector) VALUES (?, ?)", rows)
        conn.commit()
        print(f" {len(rows)} query vectors in {time.perf_counter() - t0:.1f}s")

        # --- neighbors (from neighbors.parquet) ---
        print("Loading neighbors.parquet...")
        t0 = time.perf_counter()
        df_neighbors = pd.read_parquet(neighbors_path)
        conn.execute(
            "CREATE TABLE neighbors ("
            " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
            " UNIQUE(query_vector_id, rank))"
        )
        rows = []
        for qid, nids in zip(df_neighbors["id"], df_neighbors["neighbors_id"]):
            qid = int(qid)
            # neighbors_id may be a numpy array or JSON string
            if isinstance(nids, str):
                nids = json.loads(nids)
            rows.extend((qid, rank, str(int(nid))) for rank, nid in enumerate(nids))
        conn.executemany(
            "INSERT INTO neighbors (query_vector_id, rank, neighbors_id) VALUES (?, ?, ?)",
            rows,
        )
        conn.commit()
        print(f" {len(rows)} neighbor rows in {time.perf_counter() - t0:.1f}s")

        # --- train (from train.parquet) ---
        print("Loading train.parquet (1M vectors, this takes a few minutes)...")
        t0 = time.perf_counter()
        conn.execute(
            "CREATE TABLE train (id INTEGER PRIMARY KEY, vector BLOB)"
        )

        batch_size = 10000
        df_iter = pd.read_parquet(train_path)
        total = len(df_iter)

        for start in range(0, total, batch_size):
            chunk = df_iter.iloc[start : start + batch_size]
            rows = [
                (int(vec_id), float_list_to_blob(emb))
                for vec_id, emb in zip(chunk["id"], chunk["emb"])
            ]
            conn.executemany("INSERT INTO train (id, vector) VALUES (?, ?)", rows)
            conn.commit()

            # Progress line after every batch, with a simple linear ETA
            done = min(start + batch_size, total)
            elapsed = time.perf_counter() - t0
            rate = done / elapsed if elapsed > 0 else 0
            eta = (total - done) / rate if rate > 0 else 0
            print(
                f" {done:>8}/{total} {elapsed:.0f}s {rate:.0f} rows/s eta {eta:.0f}s",
                flush=True,
            )

        elapsed = time.perf_counter() - t0
        print(f" {total} train vectors in {elapsed:.1f}s")
    finally:
        conn.close()

    size_mb = os.path.getsize(db_path) / (1024 * 1024)
    print(f"\nDone: {db_path} ({size_mb:.0f} MB)")


if __name__ == "__main__":
    main()
|
||||
30
benchmarks-ann/datasets/nyt-1024/Makefile
Normal file
30
benchmarks-ann/datasets/nyt-1024/Makefile
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
MODEL ?= mixedbread-ai/mxbai-embed-large-v1
K ?= 100
BATCH_SIZE ?= 256
DATA_DIR ?= ../nyt/data

# Remove a half-written target when its recipe fails, so a partially built
# database never looks up to date.
.DELETE_ON_ERROR:

all: base.db

# Reuse data from ../nyt
$(DATA_DIR):
	$(MAKE) -C ../nyt data

# Order-only prerequisite: the data directory must exist, but its mtime
# (which changes whenever a file inside it is added or removed) must not
# trigger a rebuild of contents.db and everything derived from it.
contents.db: | $(DATA_DIR)
	uv run ../nyt-768/build-contents.py --data-dir $(DATA_DIR) -o $@

base.db: contents.db queries.txt
	uv run build-base.py \
		--contents-db contents.db \
		--model $(MODEL) \
		--queries-file queries.txt \
		--batch-size $(BATCH_SIZE) \
		--k $(K) \
		-o $@

queries.txt:
	cp ../nyt/queries.txt $@

clean:
	rm -f base.db contents.db

.PHONY: all clean
|
||||
163
benchmarks-ann/datasets/nyt-1024/build-base.py
Normal file
163
benchmarks-ann/datasets/nyt-1024/build-base.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "sentence-transformers",
|
||||
# "torch<=2.7",
|
||||
# "tqdm",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import sqlite3
|
||||
from array import array
|
||||
from itertools import batched
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
    """Build base.db with train vectors, query vectors and brute-force KNN neighbors.

    Steps:
      1. Read ``(id, headline)`` rows from the ``contents`` table of
         ``--contents-db`` (optionally capped by ``--limit``).
      2. Embed the headlines in batches with a ``SentenceTransformer`` model
         (``normalize_embeddings=True``) and insert them into ``train`` as
         ``array("f")`` blobs.
      3. Embed every non-empty line of ``--queries-file`` into
         ``query_vectors``; query ids start at 1 and follow file order.
      4. Unless ``--skip-neighbors`` is set, rank all train vectors for each
         query with sqlite-vec's ``vec_distance_cosine`` and store the top
         ``--k`` train ids in ``neighbors`` (``rank`` starts at 0).

    Exits through ``parser.error`` when ``--contents-db`` is missing or
    ``--batch-size`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors",
    )
    parser.add_argument(
        "--contents-db", "-c", default=None,
        help="Path to contents.db (source of headlines and IDs)",
    )
    parser.add_argument(
        "--model", "-m", default="mixedbread-ai/mxbai-embed-large-v1",
        help="HuggingFace model ID (default: mixedbread-ai/mxbai-embed-large-v1)",
    )
    parser.add_argument(
        "--queries-file", "-q", default="queries.txt",
        help="Path to the queries file (default: queries.txt)",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Path to the output base.db",
    )
    parser.add_argument(
        "--batch-size", "-b", type=int, default=256,
        help="Batch size for embedding (default: 256)",
    )
    parser.add_argument(
        "--k", "-k", type=int, default=100,
        help="Number of nearest neighbors (default: 100)",
    )
    parser.add_argument(
        "--limit", "-l", type=int, default=0,
        help="Limit number of headlines to embed (0 = all)",
    )
    parser.add_argument(
        "--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0",
        help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)",
    )
    parser.add_argument(
        "--skip-neighbors", action="store_true",
        help="Skip the brute-force KNN neighbor computation",
    )
    args = parser.parse_args()

    # --contents-db defaults to None, which sqlite3.connect() would reject with
    # an opaque TypeError only after the model has been loaded; fail early.
    if args.contents_db is None:
        parser.error("--contents-db is required")
    # itertools.batched() refuses a batch size below one.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    import os
    vec_path = os.path.expanduser(args.vec_path)

    print(f"Loading model {args.model}...")
    model = SentenceTransformer(args.model)

    # Read headlines from contents.db
    src = sqlite3.connect(args.contents_db)
    # args.limit is parsed with type=int, so interpolating it is safe.
    limit_clause = f" LIMIT {args.limit}" if args.limit > 0 else ""
    headlines = src.execute(
        f"SELECT id, headline FROM contents ORDER BY id{limit_clause}"
    ).fetchall()
    src.close()
    print(f"Loaded {len(headlines)} headlines from {args.contents_db}")

    # Read queries (blank lines are ignored)
    with open(args.queries_file) as f:
        queries = [line.strip() for line in f if line.strip()]
    print(f"Loaded {len(queries)} queries from {args.queries_file}")

    # Create output database
    db = sqlite3.connect(args.output)
    db.enable_load_extension(True)
    db.load_extension(vec_path)
    db.enable_load_extension(False)

    db.execute("CREATE TABLE IF NOT EXISTS train(id INTEGER PRIMARY KEY, vector BLOB)")
    db.execute("CREATE TABLE IF NOT EXISTS query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
    db.execute(
        "CREATE TABLE IF NOT EXISTS neighbors("
        "  query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
        "  UNIQUE(query_vector_id, rank))"
    )

    # Step 1: Embed headlines -> train table
    print("Embedding headlines...")
    for batch in tqdm(
        batched(headlines, args.batch_size),
        total=(len(headlines) + args.batch_size - 1) // args.batch_size,
    ):
        ids = [r[0] for r in batch]
        texts = [r[1] for r in batch]
        embeddings = model.encode(texts, normalize_embeddings=True)

        params = [
            (int(rid), array("f", emb.tolist()).tobytes())
            for rid, emb in zip(ids, embeddings)
        ]
        db.executemany("INSERT INTO train VALUES (?, ?)", params)
        db.commit()

    # Release the headline list before the remaining steps.
    del headlines
    n = db.execute("SELECT count(*) FROM train").fetchone()[0]
    print(f"Embedded {n} headlines")

    # Step 2: Embed queries -> query_vectors table (ids start at 1)
    print("Embedding queries...")
    query_embeddings = model.encode(queries, normalize_embeddings=True)
    query_params = []
    for i, emb in enumerate(query_embeddings, 1):
        blob = array("f", emb.tolist()).tobytes()
        query_params.append((i, blob))
    db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params)
    db.commit()
    print(f"Embedded {len(queries)} queries")

    if args.skip_neighbors:
        db.close()
        print(f"Done (skipped neighbors). Wrote {args.output}")
        return

    # Step 3: Brute-force KNN via sqlite-vec -> neighbors table
    n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0]
    print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...")
    for query_id, query_blob in tqdm(
        db.execute("SELECT id, vector FROM query_vectors").fetchall()
    ):
        results = db.execute(
            """
            SELECT
              train.id,
              vec_distance_cosine(train.vector, ?) AS distance
            FROM train
            WHERE distance IS NOT NULL
            ORDER BY distance ASC
            LIMIT ?
            """,
            (query_blob, args.k),
        ).fetchall()

        # rank is the 0-based position in the distance-sorted result.
        params = [
            (query_id, rank, str(rid))
            for rank, (rid, _dist) in enumerate(results)
        ]
        db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params)

    db.commit()
    db.close()
    print(f"Done. Wrote {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
100
benchmarks-ann/datasets/nyt-1024/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-1024/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
latest news on climate change policy
|
||||
presidential election results and analysis
|
||||
stock market crash causes
|
||||
coronavirus vaccine development updates
|
||||
artificial intelligence breakthrough in healthcare
|
||||
supreme court ruling on abortion rights
|
||||
tech companies layoff announcements
|
||||
earthquake damages in California
|
||||
cybersecurity breach at major corporation
|
||||
space exploration mission to Mars
|
||||
immigration reform legislation debate
|
||||
renewable energy investment trends
|
||||
healthcare costs rising across America
|
||||
protests against police brutality
|
||||
wildfires destroy homes in the West
|
||||
Olympic games highlights and records
|
||||
celebrity scandal rocks Hollywood
|
||||
breakthrough cancer treatment discovered
|
||||
housing market bubble concerns
|
||||
federal reserve interest rate decision
|
||||
school shooting tragedy response
|
||||
diplomatic tensions between superpowers
|
||||
drone strike kills terrorist leader
|
||||
social media platform faces regulation
|
||||
archaeological discovery reveals ancient civilization
|
||||
unemployment rate hits record low
|
||||
autonomous vehicles testing expansion
|
||||
streaming service launches original content
|
||||
opioid crisis intervention programs
|
||||
trade war tariffs impact economy
|
||||
infrastructure bill passes Congress
|
||||
data privacy concerns grow
|
||||
minimum wage increase proposal
|
||||
college admissions scandal exposed
|
||||
NFL player protest during anthem
|
||||
cryptocurrency regulation debate
|
||||
pandemic lockdown restrictions eased
|
||||
mass shooting gun control debate
|
||||
tax reform legislation impact
|
||||
ransomware attack cripples pipeline
|
||||
climate activists stage demonstration
|
||||
sports team wins championship
|
||||
banking system collapse fears
|
||||
pharmaceutical company fraud charges
|
||||
genetic engineering ethical concerns
|
||||
border wall funding controversy
|
||||
impeachment proceedings begin
|
||||
nuclear weapons treaty violation
|
||||
artificial meat alternative launch
|
||||
student loan debt forgiveness
|
||||
venture capital funding decline
|
||||
facial recognition ban proposed
|
||||
election interference investigation
|
||||
pandemic preparedness failures
|
||||
police reform measures announced
|
||||
wildfire prevention strategies
|
||||
ocean pollution crisis worsens
|
||||
manufacturing jobs returning
|
||||
pension fund shortfall concerns
|
||||
antitrust investigation launched
|
||||
voting rights protection act
|
||||
mental health awareness campaign
|
||||
homeless population increasing
|
||||
space debris collision risk
|
||||
drug cartel violence escalates
|
||||
renewable energy jobs growth
|
||||
infrastructure deterioration report
|
||||
vaccine mandate legal challenge
|
||||
cryptocurrency market volatility
|
||||
autonomous drone delivery service
|
||||
deep fake technology dangers
|
||||
Arctic ice melting accelerates
|
||||
income inequality gap widens
|
||||
election fraud claims disputed
|
||||
corporate merger blocked
|
||||
medical breakthrough extends life
|
||||
transportation strike disrupts city
|
||||
racial justice protests spread
|
||||
carbon emissions reduction goals
|
||||
financial crisis warning signs
|
||||
cyberbullying prevention efforts
|
||||
asteroid near miss with Earth
|
||||
gene therapy approval granted
|
||||
labor union organizing drive
|
||||
surveillance technology expansion
|
||||
education funding cuts proposed
|
||||
disaster relief efforts underway
|
||||
housing affordability crisis
|
||||
clean water access shortage
|
||||
artificial intelligence job displacement
|
||||
trade agreement negotiations
|
||||
prison reform initiative launched
|
||||
species extinction accelerates
|
||||
political corruption scandal
|
||||
terrorism threat level raised
|
||||
food safety contamination outbreak
|
||||
ai model release
|
||||
affordability interest rates
|
||||
peanut allergies in newbons
|
||||
breaking bad walter white
|
||||
29
benchmarks-ann/datasets/nyt-384/Makefile
Normal file
29
benchmarks-ann/datasets/nyt-384/Makefile
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Builds base.db for the nyt-384 dataset from the shared NYT headline CSVs.
# Knobs can be overridden on the command line, e.g. `make K=50 BATCH_SIZE=256`.
MODEL ?= mixedbread-ai/mxbai-embed-xsmall-v1
K ?= 100
BATCH_SIZE ?= 512
DATA_DIR ?= ../nyt/data

# Delete a half-written database when its recipe fails, so a broken file is
# never mistaken for an up-to-date target on the next run.
.DELETE_ON_ERROR:

all: base.db

# The raw CSVs are fetched by the sibling ../nyt Makefile.
$(DATA_DIR):
	$(MAKE) -C ../nyt data

# $(DATA_DIR) is order-only: it must exist, but a directory's mtime changes
# whenever its contents change and must not force a rebuild.
# The stale output is removed first because build-contents.py creates the
# `contents` table unconditionally and would fail on an existing file.
contents.db: | $(DATA_DIR)
	rm -f $@
	uv run ../nyt-768/build-contents.py --data-dir $(DATA_DIR) -o $@

# The stale output is removed first because build-base.py inserts rows with
# fixed primary keys and would fail on an already populated database.
base.db: contents.db queries.txt
	rm -f $@
	uv run ../nyt-1024/build-base.py \
		--contents-db contents.db \
		--model $(MODEL) \
		--queries-file queries.txt \
		--batch-size $(BATCH_SIZE) \
		--k $(K) \
		-o $@

# Fallback when queries.txt is missing: copy the shared query list.
queries.txt:
	cp ../nyt/queries.txt $@

clean:
	rm -f base.db contents.db

.PHONY: all clean
|
||||
100
benchmarks-ann/datasets/nyt-384/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-384/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
latest news on climate change policy
|
||||
presidential election results and analysis
|
||||
stock market crash causes
|
||||
coronavirus vaccine development updates
|
||||
artificial intelligence breakthrough in healthcare
|
||||
supreme court ruling on abortion rights
|
||||
tech companies layoff announcements
|
||||
earthquake damages in California
|
||||
cybersecurity breach at major corporation
|
||||
space exploration mission to Mars
|
||||
immigration reform legislation debate
|
||||
renewable energy investment trends
|
||||
healthcare costs rising across America
|
||||
protests against police brutality
|
||||
wildfires destroy homes in the West
|
||||
Olympic games highlights and records
|
||||
celebrity scandal rocks Hollywood
|
||||
breakthrough cancer treatment discovered
|
||||
housing market bubble concerns
|
||||
federal reserve interest rate decision
|
||||
school shooting tragedy response
|
||||
diplomatic tensions between superpowers
|
||||
drone strike kills terrorist leader
|
||||
social media platform faces regulation
|
||||
archaeological discovery reveals ancient civilization
|
||||
unemployment rate hits record low
|
||||
autonomous vehicles testing expansion
|
||||
streaming service launches original content
|
||||
opioid crisis intervention programs
|
||||
trade war tariffs impact economy
|
||||
infrastructure bill passes Congress
|
||||
data privacy concerns grow
|
||||
minimum wage increase proposal
|
||||
college admissions scandal exposed
|
||||
NFL player protest during anthem
|
||||
cryptocurrency regulation debate
|
||||
pandemic lockdown restrictions eased
|
||||
mass shooting gun control debate
|
||||
tax reform legislation impact
|
||||
ransomware attack cripples pipeline
|
||||
climate activists stage demonstration
|
||||
sports team wins championship
|
||||
banking system collapse fears
|
||||
pharmaceutical company fraud charges
|
||||
genetic engineering ethical concerns
|
||||
border wall funding controversy
|
||||
impeachment proceedings begin
|
||||
nuclear weapons treaty violation
|
||||
artificial meat alternative launch
|
||||
student loan debt forgiveness
|
||||
venture capital funding decline
|
||||
facial recognition ban proposed
|
||||
election interference investigation
|
||||
pandemic preparedness failures
|
||||
police reform measures announced
|
||||
wildfire prevention strategies
|
||||
ocean pollution crisis worsens
|
||||
manufacturing jobs returning
|
||||
pension fund shortfall concerns
|
||||
antitrust investigation launched
|
||||
voting rights protection act
|
||||
mental health awareness campaign
|
||||
homeless population increasing
|
||||
space debris collision risk
|
||||
drug cartel violence escalates
|
||||
renewable energy jobs growth
|
||||
infrastructure deterioration report
|
||||
vaccine mandate legal challenge
|
||||
cryptocurrency market volatility
|
||||
autonomous drone delivery service
|
||||
deep fake technology dangers
|
||||
Arctic ice melting accelerates
|
||||
income inequality gap widens
|
||||
election fraud claims disputed
|
||||
corporate merger blocked
|
||||
medical breakthrough extends life
|
||||
transportation strike disrupts city
|
||||
racial justice protests spread
|
||||
carbon emissions reduction goals
|
||||
financial crisis warning signs
|
||||
cyberbullying prevention efforts
|
||||
asteroid near miss with Earth
|
||||
gene therapy approval granted
|
||||
labor union organizing drive
|
||||
surveillance technology expansion
|
||||
education funding cuts proposed
|
||||
disaster relief efforts underway
|
||||
housing affordability crisis
|
||||
clean water access shortage
|
||||
artificial intelligence job displacement
|
||||
trade agreement negotiations
|
||||
prison reform initiative launched
|
||||
species extinction accelerates
|
||||
political corruption scandal
|
||||
terrorism threat level raised
|
||||
food safety contamination outbreak
|
||||
ai model release
|
||||
affordability interest rates
|
||||
peanut allergies in newbons
|
||||
breaking bad walter white
|
||||
37
benchmarks-ann/datasets/nyt-768/Makefile
Normal file
37
benchmarks-ann/datasets/nyt-768/Makefile
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# Builds base.db for the nyt-768 dataset using a locally distilled model.
# Knobs can be overridden on the command line, e.g. `make K=50 BATCH_SIZE=256`.
# NOTE(review): distill-model.py always saves to `bge-base-en-v1.5-768`, so
# overriding MODEL only makes sense when pointing at an existing model.
MODEL ?= bge-base-en-v1.5-768
K ?= 100
BATCH_SIZE ?= 512
DATA_DIR ?= ../nyt/data

# Delete a half-written database when its recipe fails, so a broken file is
# never mistaken for an up-to-date target on the next run.
.DELETE_ON_ERROR:

all: base.db

# Reuse data from ../nyt
$(DATA_DIR):
	$(MAKE) -C ../nyt data

# Distill model (separate step, may take a while)
$(MODEL):
	uv run distill-model.py

# $(DATA_DIR) is order-only: it must exist, but a directory's mtime changes
# whenever its contents change and must not force a rebuild.
# The stale output is removed first because build-contents.py creates the
# `contents` table unconditionally and would fail on an existing file.
contents.db: | $(DATA_DIR)
	rm -f $@
	uv run build-contents.py --data-dir $(DATA_DIR) -o $@

# $(MODEL) is a directory, hence order-only as well.
# The stale output is removed first because build-base.py creates its tables
# unconditionally and would fail on an existing database.
base.db: contents.db queries.txt | $(MODEL)
	rm -f $@
	uv run ../nyt/build-base.py \
		--contents-db contents.db \
		--model $(MODEL) \
		--queries-file queries.txt \
		--batch-size $(BATCH_SIZE) \
		--k $(K) \
		-o $@

# Fallback when queries.txt is missing: copy the shared query list.
queries.txt:
	cp ../nyt/queries.txt $@

clean:
	rm -f base.db contents.db

# Also removes the distilled model directory.
clean-all: clean
	rm -rf $(MODEL)

.PHONY: all clean clean-all
|
||||
64
benchmarks-ann/datasets/nyt-768/build-contents.py
Normal file
64
benchmarks-ann/datasets/nyt-768/build-contents.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "duckdb",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import sqlite3
|
||||
import duckdb
|
||||
|
||||
|
||||
def main():
    """Load NYT headline CSVs into a SQLite ``contents`` table.

    Reads every ``new_york_times_stories_*.csv`` under ``--data-dir`` with
    DuckDB, drops NULL/blank headlines, de-duplicates by headline (keeping the
    latest ``pub_date`` of each), numbers the survivors from 1 in
    newest-first order and keeps the first ``--limit`` of them. The result is
    written to a fresh ``contents(id, headline)`` table in ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="Load NYT headline CSVs into a SQLite contents database (most recent 1M, deduplicated)",
    )
    parser.add_argument(
        "--data-dir", "-d", default="../nyt/data",
        help="Directory containing NYT CSV files (default: ../nyt/data)",
    )
    parser.add_argument(
        "--limit", "-l", type=int, default=1_000_000,
        help="Maximum number of headlines to keep (default: 1000000)",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Path to the output SQLite database",
    )
    args = parser.parse_args()

    glob_pattern = f"{args.data_dir}/new_york_times_stories_*.csv"
    # The pattern is spliced into a SQL string literal below, so double any
    # single quote in the path to keep the statement well-formed.
    quoted_glob = glob_pattern.replace("'", "''")

    con = duckdb.connect()
    # The window order uses `headline` (unique after GROUP BY) as a tie-break
    # and the outer query orders by the assigned id. Previously the window and
    # the outer ORDER BY could resolve pub_date ties differently, so LIMIT
    # could cut a non-contiguous, run-dependent set of ids.
    # args.limit is parsed with type=int, so interpolating it is safe.
    rows = con.execute(
        f"""
        WITH deduped AS (
            SELECT
                headline,
                max(pub_date) AS pub_date
            FROM read_csv('{quoted_glob}', auto_detect=true, union_by_name=true)
            WHERE headline IS NOT NULL AND trim(headline) != ''
            GROUP BY headline
        )
        SELECT
            row_number() OVER (ORDER BY pub_date DESC, headline) AS id,
            headline
        FROM deduped
        ORDER BY id
        LIMIT {args.limit}
        """
    ).fetchall()
    con.close()

    db = sqlite3.connect(args.output)
    # No IF NOT EXISTS: the script expects to create a fresh database.
    db.execute("CREATE TABLE contents(id INTEGER PRIMARY KEY, headline TEXT)")
    db.executemany("INSERT INTO contents VALUES (?, ?)", rows)
    db.commit()
    db.close()

    print(f"Wrote {len(rows)} headlines to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
13
benchmarks-ann/datasets/nyt-768/distill-model.py
Normal file
13
benchmarks-ann/datasets/nyt-768/distill-model.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "model2vec[distill]",
|
||||
# "torch<=2.7",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
from model2vec.distill import distill
|
||||
|
||||
# Distill BAAI/bge-base-en-v1.5 with model2vec (pca_dims=768) and save the
# result to a local directory, which the Makefile passes on as MODEL.
OUTPUT_DIR = "bge-base-en-v1.5-768"

model = distill(
    model_name="BAAI/bge-base-en-v1.5",
    pca_dims=768,
)
model.save_pretrained(OUTPUT_DIR)
print(f"Saved distilled model to {OUTPUT_DIR}/")
|
||||
100
benchmarks-ann/datasets/nyt-768/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-768/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
latest news on climate change policy
|
||||
presidential election results and analysis
|
||||
stock market crash causes
|
||||
coronavirus vaccine development updates
|
||||
artificial intelligence breakthrough in healthcare
|
||||
supreme court ruling on abortion rights
|
||||
tech companies layoff announcements
|
||||
earthquake damages in California
|
||||
cybersecurity breach at major corporation
|
||||
space exploration mission to Mars
|
||||
immigration reform legislation debate
|
||||
renewable energy investment trends
|
||||
healthcare costs rising across America
|
||||
protests against police brutality
|
||||
wildfires destroy homes in the West
|
||||
Olympic games highlights and records
|
||||
celebrity scandal rocks Hollywood
|
||||
breakthrough cancer treatment discovered
|
||||
housing market bubble concerns
|
||||
federal reserve interest rate decision
|
||||
school shooting tragedy response
|
||||
diplomatic tensions between superpowers
|
||||
drone strike kills terrorist leader
|
||||
social media platform faces regulation
|
||||
archaeological discovery reveals ancient civilization
|
||||
unemployment rate hits record low
|
||||
autonomous vehicles testing expansion
|
||||
streaming service launches original content
|
||||
opioid crisis intervention programs
|
||||
trade war tariffs impact economy
|
||||
infrastructure bill passes Congress
|
||||
data privacy concerns grow
|
||||
minimum wage increase proposal
|
||||
college admissions scandal exposed
|
||||
NFL player protest during anthem
|
||||
cryptocurrency regulation debate
|
||||
pandemic lockdown restrictions eased
|
||||
mass shooting gun control debate
|
||||
tax reform legislation impact
|
||||
ransomware attack cripples pipeline
|
||||
climate activists stage demonstration
|
||||
sports team wins championship
|
||||
banking system collapse fears
|
||||
pharmaceutical company fraud charges
|
||||
genetic engineering ethical concerns
|
||||
border wall funding controversy
|
||||
impeachment proceedings begin
|
||||
nuclear weapons treaty violation
|
||||
artificial meat alternative launch
|
||||
student loan debt forgiveness
|
||||
venture capital funding decline
|
||||
facial recognition ban proposed
|
||||
election interference investigation
|
||||
pandemic preparedness failures
|
||||
police reform measures announced
|
||||
wildfire prevention strategies
|
||||
ocean pollution crisis worsens
|
||||
manufacturing jobs returning
|
||||
pension fund shortfall concerns
|
||||
antitrust investigation launched
|
||||
voting rights protection act
|
||||
mental health awareness campaign
|
||||
homeless population increasing
|
||||
space debris collision risk
|
||||
drug cartel violence escalates
|
||||
renewable energy jobs growth
|
||||
infrastructure deterioration report
|
||||
vaccine mandate legal challenge
|
||||
cryptocurrency market volatility
|
||||
autonomous drone delivery service
|
||||
deep fake technology dangers
|
||||
Arctic ice melting accelerates
|
||||
income inequality gap widens
|
||||
election fraud claims disputed
|
||||
corporate merger blocked
|
||||
medical breakthrough extends life
|
||||
transportation strike disrupts city
|
||||
racial justice protests spread
|
||||
carbon emissions reduction goals
|
||||
financial crisis warning signs
|
||||
cyberbullying prevention efforts
|
||||
asteroid near miss with Earth
|
||||
gene therapy approval granted
|
||||
labor union organizing drive
|
||||
surveillance technology expansion
|
||||
education funding cuts proposed
|
||||
disaster relief efforts underway
|
||||
housing affordability crisis
|
||||
clean water access shortage
|
||||
artificial intelligence job displacement
|
||||
trade agreement negotiations
|
||||
prison reform initiative launched
|
||||
species extinction accelerates
|
||||
political corruption scandal
|
||||
terrorism threat level raised
|
||||
food safety contamination outbreak
|
||||
ai model release
|
||||
affordability interest rates
|
||||
peanut allergies in newbons
|
||||
breaking bad walter white
|
||||
1
benchmarks-ann/datasets/nyt/.gitignore
vendored
Normal file
1
benchmarks-ann/datasets/nyt/.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
data/
|
||||
30
benchmarks-ann/datasets/nyt/Makefile
Normal file
30
benchmarks-ann/datasets/nyt/Makefile
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# Builds contents.db and base.db for the base NYT headlines dataset.
# Knobs can be overridden on the command line, e.g. `make K=50 BATCH_SIZE=256`.
MODEL ?= minishlab/potion-base-8M
K ?= 100
BATCH_SIZE ?= 512
DATA_DIR ?= data

# Delete a half-written database when its recipe fails, so a broken file is
# never mistaken for an up-to-date target on the next run.
.DELETE_ON_ERROR:

all: base.db contents.db

# Download NYT headlines CSVs from Kaggle (requires `kaggle` CLI + API token)
$(DATA_DIR):
	kaggle datasets download -d johnbandy/new-york-times-headlines -p $(DATA_DIR) --unzip

# $(DATA_DIR) is order-only: it must exist, but a directory's mtime changes
# whenever its contents change and must not force a rebuild.
# The stale output is removed first because build-contents.py creates the
# `contents` table unconditionally and would fail on an existing file.
contents.db: | $(DATA_DIR)
	rm -f $@
	uv run build-contents.py --data-dir $(DATA_DIR) -o $@

# The stale output is removed first because build-base.py creates its tables
# unconditionally and would fail on an existing database.
base.db: contents.db queries.txt
	rm -f $@
	uv run build-base.py \
		--contents-db contents.db \
		--model $(MODEL) \
		--queries-file queries.txt \
		--batch-size $(BATCH_SIZE) \
		--k $(K) \
		-o $@

clean:
	rm -f base.db contents.db

# Also removes the downloaded CSVs.
clean-all: clean
	rm -rf $(DATA_DIR)

.PHONY: all clean clean-all
|
||||
165
benchmarks-ann/datasets/nyt/build-base.py
Normal file
165
benchmarks-ann/datasets/nyt/build-base.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "model2vec",
|
||||
# "torch<=2.7",
|
||||
# "tqdm",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import sqlite3
|
||||
from array import array
|
||||
from itertools import batched
|
||||
|
||||
from model2vec import StaticModel
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
    """Build base.db with train vectors, query vectors and brute-force KNN neighbors.

    Two modes:

    * Default: load a model2vec ``StaticModel``, embed every headline of the
      ``contents`` table in ``--contents-db`` into ``train``, embed every
      non-empty line of ``--queries-file`` into ``query_vectors`` (ids start
      at 1, in file order), then compute the neighbors.
    * ``--rebuild-neighbors``: open the existing ``--output`` database, drop
      and recreate ``neighbors`` and recompute it from the stored vectors; no
      model is loaded and ``--contents-db`` is not needed.

    Neighbors are the top ``--k`` train ids per query by sqlite-vec's
    ``vec_distance_cosine``; ``rank`` starts at 0.

    Exits through ``parser.error`` when, in the default mode,
    ``--contents-db`` is missing or ``--batch-size`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors",
    )
    parser.add_argument(
        "--contents-db", "-c", default=None,
        help="Path to contents.db (source of headlines and IDs)",
    )
    parser.add_argument(
        "--model", "-m", default="minishlab/potion-base-8M",
        help="HuggingFace model ID or local path (default: minishlab/potion-base-8M)",
    )
    parser.add_argument(
        "--queries-file", "-q", default="queries.txt",
        help="Path to the queries file (default: queries.txt)",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Path to the output base.db",
    )
    parser.add_argument(
        "--batch-size", "-b", type=int, default=512,
        help="Batch size for embedding (default: 512)",
    )
    parser.add_argument(
        "--k", "-k", type=int, default=100,
        help="Number of nearest neighbors (default: 100)",
    )
    parser.add_argument(
        "--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0",
        help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)",
    )
    parser.add_argument(
        "--rebuild-neighbors", action="store_true",
        help="Only rebuild the neighbors table (skip embedding steps)",
    )
    args = parser.parse_args()

    import os
    vec_path = os.path.expanduser(args.vec_path)

    if args.rebuild_neighbors:
        # Skip embedding, just open existing DB and rebuild neighbors
        db = sqlite3.connect(args.output)
        db.enable_load_extension(True)
        db.load_extension(vec_path)
        db.enable_load_extension(False)
        db.execute("DROP TABLE IF EXISTS neighbors")
        db.execute(
            "CREATE TABLE neighbors("
            "  query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
            "  UNIQUE(query_vector_id, rank))"
        )
        print(f"Rebuilding neighbors in {args.output}...")
    else:
        # --contents-db defaults to None (it is unused when rebuilding), but
        # sqlite3.connect(None) would raise an opaque TypeError only after the
        # model has been loaded; fail early with a usage error instead.
        if args.contents_db is None:
            parser.error("--contents-db is required unless --rebuild-neighbors is given")
        # itertools.batched() refuses a batch size below one.
        if args.batch_size < 1:
            parser.error("--batch-size must be a positive integer")

        print(f"Loading model {args.model}...")
        model = StaticModel.from_pretrained(args.model)

        # Read headlines from contents.db
        src = sqlite3.connect(args.contents_db)
        headlines = src.execute("SELECT id, headline FROM contents ORDER BY id").fetchall()
        src.close()
        print(f"Loaded {len(headlines)} headlines from {args.contents_db}")

        # Read queries (blank lines are ignored)
        with open(args.queries_file) as f:
            queries = [line.strip() for line in f if line.strip()]
        print(f"Loaded {len(queries)} queries from {args.queries_file}")

        # Create output database
        db = sqlite3.connect(args.output)
        db.enable_load_extension(True)
        db.load_extension(vec_path)
        db.enable_load_extension(False)

        # No IF NOT EXISTS: this mode expects to create a fresh database.
        db.execute("CREATE TABLE train(id INTEGER PRIMARY KEY, vector BLOB)")
        db.execute("CREATE TABLE query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
        db.execute(
            "CREATE TABLE neighbors("
            "  query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
            "  UNIQUE(query_vector_id, rank))"
        )

        # Step 1: Embed headlines -> train table
        print("Embedding headlines...")
        for batch in tqdm(
            batched(headlines, args.batch_size),
            total=(len(headlines) + args.batch_size - 1) // args.batch_size,
        ):
            ids = [r[0] for r in batch]
            texts = [r[1] for r in batch]
            embeddings = model.encode(texts)

            params = [
                (int(rid), array("f", emb.tolist()).tobytes())
                for rid, emb in zip(ids, embeddings)
            ]
            db.executemany("INSERT INTO train VALUES (?, ?)", params)
            db.commit()

        # Release the headline list before the remaining steps.
        del headlines
        n = db.execute("SELECT count(*) FROM train").fetchone()[0]
        print(f"Embedded {n} headlines")

        # Step 2: Embed queries -> query_vectors table (ids start at 1)
        print("Embedding queries...")
        query_embeddings = model.encode(queries)
        query_params = []
        for i, emb in enumerate(query_embeddings, 1):
            blob = array("f", emb.tolist()).tobytes()
            query_params.append((i, blob))
        db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params)
        db.commit()
        print(f"Embedded {len(queries)} queries")

    # Step 3: Brute-force KNN via sqlite-vec -> neighbors table
    n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0]
    print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...")
    for query_id, query_blob in tqdm(
        db.execute("SELECT id, vector FROM query_vectors").fetchall()
    ):
        results = db.execute(
            """
            SELECT
              train.id,
              vec_distance_cosine(train.vector, ?) AS distance
            FROM train
            WHERE distance IS NOT NULL
            ORDER BY distance ASC
            LIMIT ?
            """,
            (query_blob, args.k),
        ).fetchall()

        # rank is the 0-based position in the distance-sorted result.
        params = [
            (query_id, rank, str(rid))
            for rank, (rid, _dist) in enumerate(results)
        ]
        db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params)

    db.commit()
    db.close()
    print(f"Done. Wrote {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
52
benchmarks-ann/datasets/nyt/build-contents.py
Normal file
52
benchmarks-ann/datasets/nyt/build-contents.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.12"
|
||||
# dependencies = [
|
||||
# "duckdb",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sqlite3
|
||||
import duckdb
|
||||
|
||||
|
||||
def main():
    """Copy NYT headlines from CSV files into a SQLite ``contents`` table.

    DuckDB reads every ``new_york_times_stories_*.csv`` under ``--data-dir``,
    skips NULL and empty headlines and numbers the remaining rows with
    ``row_number() OVER ()``. The rows are then inserted into a newly created
    ``contents(id, headline)`` table in ``--output``.
    """
    cli = argparse.ArgumentParser(
        description="Load NYT headline CSVs into a SQLite contents database via DuckDB",
    )
    cli.add_argument(
        "--data-dir", "-d", default="data",
        help="Directory containing NYT CSV files (default: data)",
    )
    cli.add_argument(
        "--output", "-o", required=True,
        help="Path to the output SQLite database",
    )
    args = cli.parse_args()

    csv_glob = os.path.join(args.data_dir, "new_york_times_stories_*.csv")
    select_headlines = f"""
        SELECT
            row_number() OVER () AS id,
            headline
        FROM read_csv('{csv_glob}', auto_detect=true, union_by_name=true)
        WHERE headline IS NOT NULL AND headline != ''
        """

    # Pull all rows into memory through an in-process DuckDB connection.
    duck = duckdb.connect()
    records = duck.execute(select_headlines).fetchall()
    duck.close()

    # Write them to a fresh SQLite database (the table must not exist yet).
    out = sqlite3.connect(args.output)
    out.execute("CREATE TABLE contents(id INTEGER PRIMARY KEY, headline TEXT)")
    out.executemany("INSERT INTO contents VALUES (?, ?)", records)
    out.commit()
    out.close()

    print(f"Wrote {len(records)} headlines to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
100
benchmarks-ann/datasets/nyt/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
latest news on climate change policy
|
||||
presidential election results and analysis
|
||||
stock market crash causes
|
||||
coronavirus vaccine development updates
|
||||
artificial intelligence breakthrough in healthcare
|
||||
supreme court ruling on abortion rights
|
||||
tech companies layoff announcements
|
||||
earthquake damages in California
|
||||
cybersecurity breach at major corporation
|
||||
space exploration mission to Mars
|
||||
immigration reform legislation debate
|
||||
renewable energy investment trends
|
||||
healthcare costs rising across America
|
||||
protests against police brutality
|
||||
wildfires destroy homes in the West
|
||||
Olympic games highlights and records
|
||||
celebrity scandal rocks Hollywood
|
||||
breakthrough cancer treatment discovered
|
||||
housing market bubble concerns
|
||||
federal reserve interest rate decision
|
||||
school shooting tragedy response
|
||||
diplomatic tensions between superpowers
|
||||
drone strike kills terrorist leader
|
||||
social media platform faces regulation
|
||||
archaeological discovery reveals ancient civilization
|
||||
unemployment rate hits record low
|
||||
autonomous vehicles testing expansion
|
||||
streaming service launches original content
|
||||
opioid crisis intervention programs
|
||||
trade war tariffs impact economy
|
||||
infrastructure bill passes Congress
|
||||
data privacy concerns grow
|
||||
minimum wage increase proposal
|
||||
college admissions scandal exposed
|
||||
NFL player protest during anthem
|
||||
cryptocurrency regulation debate
|
||||
pandemic lockdown restrictions eased
|
||||
mass shooting gun control debate
|
||||
tax reform legislation impact
|
||||
ransomware attack cripples pipeline
|
||||
climate activists stage demonstration
|
||||
sports team wins championship
|
||||
banking system collapse fears
|
||||
pharmaceutical company fraud charges
|
||||
genetic engineering ethical concerns
|
||||
border wall funding controversy
|
||||
impeachment proceedings begin
|
||||
nuclear weapons treaty violation
|
||||
artificial meat alternative launch
|
||||
student loan debt forgiveness
|
||||
venture capital funding decline
|
||||
facial recognition ban proposed
|
||||
election interference investigation
|
||||
pandemic preparedness failures
|
||||
police reform measures announced
|
||||
wildfire prevention strategies
|
||||
ocean pollution crisis worsens
|
||||
manufacturing jobs returning
|
||||
pension fund shortfall concerns
|
||||
antitrust investigation launched
|
||||
voting rights protection act
|
||||
mental health awareness campaign
|
||||
homeless population increasing
|
||||
space debris collision risk
|
||||
drug cartel violence escalates
|
||||
renewable energy jobs growth
|
||||
infrastructure deterioration report
|
||||
vaccine mandate legal challenge
|
||||
cryptocurrency market volatility
|
||||
autonomous drone delivery service
|
||||
deep fake technology dangers
|
||||
Arctic ice melting accelerates
|
||||
income inequality gap widens
|
||||
election fraud claims disputed
|
||||
corporate merger blocked
|
||||
medical breakthrough extends life
|
||||
transportation strike disrupts city
|
||||
racial justice protests spread
|
||||
carbon emissions reduction goals
|
||||
financial crisis warning signs
|
||||
cyberbullying prevention efforts
|
||||
asteroid near miss with Earth
|
||||
gene therapy approval granted
|
||||
labor union organizing drive
|
||||
surveillance technology expansion
|
||||
education funding cuts proposed
|
||||
disaster relief efforts underway
|
||||
housing affordability crisis
|
||||
clean water access shortage
|
||||
artificial intelligence job displacement
|
||||
trade agreement negotiations
|
||||
prison reform initiative launched
|
||||
species extinction accelerates
|
||||
political corruption scandal
|
||||
terrorism threat level raised
|
||||
food safety contamination outbreak
|
||||
ai model release
|
||||
affordability interest rates
|
||||
peanut allergies in newbons
|
||||
breaking bad walter white
|
||||
101
benchmarks-ann/faiss_kmeans.py
Normal file
101
benchmarks-ann/faiss_kmeans.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Compute k-means centroids using FAISS and save to a centroids DB.
|
||||
|
||||
Reads the first N vectors from a base.db, runs FAISS k-means, and writes
|
||||
the centroids to an output SQLite DB as float32 blobs.
|
||||
|
||||
Usage:
|
||||
python faiss_kmeans.py --base-db datasets/cohere10m/base.db --ntrain 100000 \
|
||||
--nclusters 8192 -o centroids.db
|
||||
|
||||
Output schema:
|
||||
CREATE TABLE centroids (
|
||||
centroid_id INTEGER PRIMARY KEY,
|
||||
centroid BLOB NOT NULL -- float32[D]
|
||||
);
|
||||
CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT);
|
||||
-- keys written: ntrain, nclusters, dimensions, niter, elapsed_s, seed
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sqlite3
|
||||
import struct
|
||||
import time
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
    """Train FAISS k-means on vectors from a base DB and write the centroids to SQLite.

    Reads up to ``--ntrain`` float32 blobs from the ``train`` table (ordered by
    ``id``), L2-normalizes them, runs ``faiss.Kmeans`` and writes one row per
    centroid plus a small ``meta`` key/value table to ``--output``. Any existing
    output file is replaced.

    Exits with an error message (``SystemExit``) when the base DB is missing,
    the ``train`` table yields no rows, or fewer vectors than clusters were
    loaded.
    """
    parser = argparse.ArgumentParser(description="FAISS k-means centroid computation")
    parser.add_argument("--base-db", required=True, help="path to base.db with train table")
    parser.add_argument("--ntrain", type=int, required=True, help="number of vectors to train on")
    parser.add_argument("--nclusters", type=int, required=True, help="number of clusters (nlist)")
    parser.add_argument("--niter", type=int, default=20, help="k-means iterations (default 20)")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("-o", "--output", required=True, help="output centroids DB path")
    args = parser.parse_args()

    # sqlite3.connect() would silently create an empty database for a wrong
    # path and only fail later with "no such table"; fail early instead.
    if not os.path.exists(args.base_db):
        raise SystemExit(f"Error: base DB not found at {args.base_db}")

    # Load vectors
    print(f"Loading {args.ntrain} vectors from {args.base_db}...")
    conn = sqlite3.connect(args.base_db)
    try:
        rows = conn.execute(
            "SELECT vector FROM train ORDER BY id LIMIT ?", (args.ntrain,)
        ).fetchall()
    finally:
        conn.close()

    if not rows:
        raise SystemExit(f"Error: no vectors loaded from {args.base_db} (train table empty?)")
    if len(rows) < args.nclusters:
        raise SystemExit(
            f"Error: loaded {len(rows)} vectors but {args.nclusters} clusters were requested"
        )

    # Parse float32 blobs to numpy; the dimension is taken from the first blob.
    D = len(rows[0][0]) // 4  # 4 bytes per float32
    print(f" Dimensions: {D}, loaded {len(rows)} vectors")

    vectors = np.zeros((len(rows), D), dtype=np.float32)
    for i, (blob,) in enumerate(rows):
        # Raises ValueError if a blob does not hold exactly D float32 values.
        vectors[i] = np.frombuffer(blob, dtype=np.float32)

    # Normalize for cosine distance (FAISS k-means on L2 of unit vectors ≈ cosine)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1  # leave all-zero vectors untouched instead of dividing by 0
    vectors /= norms

    # Run FAISS k-means
    print(f"Running k-means: {args.nclusters} clusters, {args.niter} iterations...")
    t0 = time.perf_counter()
    kmeans = faiss.Kmeans(
        D, args.nclusters,
        niter=args.niter,
        seed=args.seed,
        verbose=True,
        gpu=False,
    )
    kmeans.train(vectors)
    elapsed = time.perf_counter() - t0
    print(f" Done in {elapsed:.1f}s")

    centroids = kmeans.centroids  # (nclusters, D) float32

    # Write output DB (replacing any previous file)
    if os.path.exists(args.output):
        os.remove(args.output)
    out = sqlite3.connect(args.output)
    try:
        out.execute("CREATE TABLE centroids (centroid_id INTEGER PRIMARY KEY, centroid BLOB NOT NULL)")
        out.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)")

        out.executemany(
            "INSERT INTO centroids (centroid_id, centroid) VALUES (?, ?)",
            ((i, centroids[i].tobytes()) for i in range(args.nclusters)),
        )

        # 'ntrain' records the number of vectors actually trained on, which is
        # smaller than --ntrain when the table holds fewer rows.
        out.executemany(
            "INSERT INTO meta VALUES (?, ?)",
            [
                ("ntrain", str(len(rows))),
                ("nclusters", str(args.nclusters)),
                ("dimensions", str(D)),
                ("niter", str(args.niter)),
                ("elapsed_s", str(round(elapsed, 3))),
                ("seed", str(args.seed)),
            ],
        )
        out.commit()
    finally:
        out.close()

    print(f"Wrote {args.nclusters} centroids to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
168
benchmarks-ann/ground_truth.py
Normal file
168
benchmarks-ann/ground_truth.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Compute per-subset ground truth for ANN benchmarks.
|
||||
|
||||
For subset sizes < 1M, builds a temporary vec0 float table with the first N
|
||||
vectors and runs brute-force KNN to get correct ground truth per subset.
|
||||
|
||||
For 1M (the full dataset), converts the existing `neighbors` table.
|
||||
|
||||
Output: ground_truth.{subset_size}.db with table:
|
||||
ground_truth(query_vector_id, rank, neighbor_id, distance)
|
||||
|
||||
Usage:
|
||||
python ground_truth.py --subset-size 50000
|
||||
python ground_truth.py --subset-size 1000000
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
|
||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
EXT_PATH = os.path.join(_SCRIPT_DIR, "..", "dist", "vec0")
|
||||
BASE_DB = os.path.join(_SCRIPT_DIR, "seed", "base.db")
|
||||
FULL_DATASET_SIZE = 1_000_000
|
||||
|
||||
|
||||
def gen_ground_truth_subset(base_db, ext_path, subset_size, n_queries, k, out_path,
                            dimensions=768):
    """Build ground truth by brute-force KNN over the first `subset_size` vectors.

    Creates ``out_path`` (replacing any existing file), copies the rows of
    ``base.train`` with ``id < subset_size`` into a temporary vec0 table using
    cosine distance, runs one KNN query per row of ``base.query_vectors``
    (first ``n_queries`` by id) and stores the results in ``ground_truth``.

    Args:
        base_db: path of the SQLite database holding ``train`` and ``query_vectors``.
        ext_path: path of the vec0 loadable extension.
        subset_size: exclusive upper bound on ``train.id``.
            NOTE(review): this equals "first N vectors" only if ids are
            0-based and contiguous — confirm against the seed data.
        n_queries: number of query vectors to evaluate.
        k: number of neighbors requested per query.
        out_path: output database path.
        dimensions: vector width of the temporary vec0 column (default 768,
            the value that was previously hard-coded).
    """
    if os.path.exists(out_path):
        os.remove(out_path)

    conn = sqlite3.connect(out_path)
    try:
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)

        conn.execute(
            "CREATE TABLE ground_truth ("
            " query_vector_id INTEGER NOT NULL,"
            " rank INTEGER NOT NULL,"
            " neighbor_id INTEGER NOT NULL,"
            " distance REAL NOT NULL,"
            " PRIMARY KEY (query_vector_id, rank)"
            ")"
        )

        # Bind the path instead of formatting it into the SQL text so that a
        # quote character in the path cannot break (or alter) the statement.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))

        print(f" Building temp vec0 table with {subset_size} vectors...")
        conn.execute(
            "CREATE VIRTUAL TABLE tmp_vec USING vec0("
            " id integer primary key,"
            f" embedding float[{int(dimensions)}] distance_metric=cosine"
            ")"
        )

        t0 = time.perf_counter()
        conn.execute(
            "INSERT INTO tmp_vec(id, embedding) "
            "SELECT id, vector FROM base.train WHERE id < :n",
            {"n": subset_size},
        )
        conn.commit()
        build_time = time.perf_counter() - t0
        print(f" Temp table built in {build_time:.1f}s")

        query_vectors = conn.execute(
            "SELECT id, vector FROM base.query_vectors ORDER BY id LIMIT :n",
            {"n": n_queries},
        ).fetchall()

        print(f" Running brute-force KNN for {len(query_vectors)} queries, k={k}...")
        t0 = time.perf_counter()

        for i, (qid, qvec) in enumerate(query_vectors):
            results = conn.execute(
                "SELECT id, distance FROM tmp_vec "
                "WHERE embedding MATCH :query AND k = :k",
                {"query": qvec, "k": k},
            ).fetchall()

            # rank is the 0-based position in the order the KNN query returned rows.
            conn.executemany(
                "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id, distance) "
                "VALUES (?, ?, ?, ?)",
                [(qid, rank, nid, dist) for rank, (nid, dist) in enumerate(results)],
            )

            if (i + 1) % 10 == 0 or i == 0:
                elapsed = time.perf_counter() - t0
                eta = (elapsed / (i + 1)) * (len(query_vectors) - i - 1)
                print(
                    f" {i+1}/{len(query_vectors)} queries "
                    f"elapsed={elapsed:.1f}s eta={eta:.1f}s",
                    flush=True,
                )

        # Commit before DROP/DETACH: DETACH cannot run inside an open transaction.
        conn.commit()
        conn.execute("DROP TABLE tmp_vec")
        conn.execute("DETACH DATABASE base")
        conn.commit()

        elapsed = time.perf_counter() - t0
        total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0]
    finally:
        # Close the connection even when the extension load or a query fails.
        conn.close()
    print(f" Ground truth: {total_rows} rows in {elapsed:.1f}s -> {out_path}")
|
||||
|
||||
|
||||
def gen_ground_truth_full(base_db, n_queries, k, out_path):
    """Convert the existing neighbors table for the full 1M dataset.

    Creates ``out_path`` (replacing any existing file) and copies the rows of
    ``base.neighbors`` with ``query_vector_id < n_queries`` and ``rank < k``
    into ``ground_truth``. ``neighbors_id`` is cast to INTEGER; ``distance``
    is left NULL because it is not copied from the source table.

    Args:
        base_db: path of the SQLite database holding the ``neighbors`` table.
        n_queries: exclusive upper bound on ``query_vector_id``.
        k: exclusive upper bound on ``rank``.
        out_path: output database path.
    """
    if os.path.exists(out_path):
        os.remove(out_path)

    conn = sqlite3.connect(out_path)
    try:
        # Bind the path instead of formatting it into the SQL text so that a
        # quote character in the path cannot break (or alter) the statement.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))

        conn.execute(
            "CREATE TABLE ground_truth ("
            " query_vector_id INTEGER NOT NULL,"
            " rank INTEGER NOT NULL,"
            " neighbor_id INTEGER NOT NULL,"
            " distance REAL,"
            " PRIMARY KEY (query_vector_id, rank)"
            ")"
        )

        conn.execute(
            "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id) "
            "SELECT query_vector_id, rank, CAST(neighbors_id AS INTEGER) "
            "FROM base.neighbors "
            "WHERE query_vector_id < :n AND rank < :k",
            {"n": n_queries, "k": k},
        )
        # Commit first: DETACH cannot run inside an open transaction.
        conn.commit()

        total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0]
        conn.execute("DETACH DATABASE base")
    finally:
        # Close the connection even when a statement fails.
        conn.close()
    print(f" Ground truth (full): {total_rows} rows -> {out_path}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: write ground_truth.{subset_size}.db into the output dir.

    Subsets smaller than FULL_DATASET_SIZE are computed by brute-force KNN;
    anything else is converted from the existing neighbors table.
    """
    cli = argparse.ArgumentParser(description="Generate per-subset ground truth")
    option_specs = (
        (("--subset-size",), dict(type=int, required=True, help="number of vectors in subset")),
        (("-n",), dict(type=int, default=100, help="number of query vectors")),
        (("-k",), dict(type=int, default=100, help="max k for ground truth")),
        (("--base-db",), dict(default=BASE_DB)),
        (("--ext",), dict(default=EXT_PATH)),
        (
            ("-o", "--out-dir"),
            dict(
                default=os.path.join(_SCRIPT_DIR, "seed"),
                help="output directory for ground_truth.{N}.db",
            ),
        ),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    opts = cli.parse_args()

    os.makedirs(opts.out_dir, exist_ok=True)
    target = os.path.join(opts.out_dir, f"ground_truth.{opts.subset_size}.db")

    if opts.subset_size < FULL_DATASET_SIZE:
        # Partial dataset: neighbors must be recomputed for this subset.
        gen_ground_truth_subset(
            opts.base_db, opts.ext, opts.subset_size, opts.n, opts.k, target
        )
        return

    # Full dataset: reuse the neighbors table stored in the base DB.
    gen_ground_truth_full(opts.base_db, opts.n, opts.k, target)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
440
benchmarks-ann/profile.py
Normal file
440
benchmarks-ann/profile.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
#!/usr/bin/env python3
|
||||
"""CPU profiling for sqlite-vec KNN configurations using macOS `sample` tool.
|
||||
|
||||
Builds dist/sqlite3 (with -g3), generates a SQL workload (inserts + repeated
|
||||
KNN queries) for each config, profiles the sqlite3 process with `sample`, and
|
||||
prints the top-N hottest functions by self (exclusive) CPU samples.
|
||||
|
||||
Usage:
|
||||
cd benchmarks-ann
|
||||
uv run profile.py --subset-size 50000 -n 50 \\
|
||||
"baseline-int8:type=baseline,variant=int8,oversample=8" \\
|
||||
"rescore-int8:type=rescore,quantizer=int8,oversample=8"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.join(_SCRIPT_DIR, "..")
|
||||
|
||||
sys.path.insert(0, _SCRIPT_DIR)
|
||||
from bench import (
|
||||
BASE_DB,
|
||||
DEFAULT_INSERT_SQL,
|
||||
INDEX_REGISTRY,
|
||||
INSERT_BATCH_SIZE,
|
||||
parse_config,
|
||||
)
|
||||
|
||||
SQLITE3_PATH = os.path.join(_PROJECT_ROOT, "dist", "sqlite3")
|
||||
EXT_PATH = os.path.join(_PROJECT_ROOT, "dist", "vec0")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SQL generation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _query_sql_for_config(params, query_id, k):
|
||||
"""Return a SQL query string for a single KNN query by query_vector id."""
|
||||
index_type = params["index_type"]
|
||||
qvec = f"(SELECT vector FROM base.query_vectors WHERE id = {query_id})"
|
||||
|
||||
if index_type == "baseline":
|
||||
variant = params.get("variant", "float")
|
||||
oversample = params.get("oversample", 8)
|
||||
oversample_k = k * oversample
|
||||
|
||||
if variant == "int8":
|
||||
return (
|
||||
f"WITH coarse AS ("
|
||||
f" SELECT id, embedding FROM vec_items"
|
||||
f" WHERE embedding_int8 MATCH vec_quantize_int8({qvec}, 'unit')"
|
||||
f" LIMIT {oversample_k}"
|
||||
f") "
|
||||
f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance "
|
||||
f"FROM coarse ORDER BY 2 LIMIT {k};"
|
||||
)
|
||||
elif variant == "bit":
|
||||
return (
|
||||
f"WITH coarse AS ("
|
||||
f" SELECT id, embedding FROM vec_items"
|
||||
f" WHERE embedding_bq MATCH vec_quantize_binary({qvec})"
|
||||
f" LIMIT {oversample_k}"
|
||||
f") "
|
||||
f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance "
|
||||
f"FROM coarse ORDER BY 2 LIMIT {k};"
|
||||
)
|
||||
|
||||
# Default MATCH query (baseline-float, rescore, and others)
|
||||
return (
|
||||
f"SELECT id, distance FROM vec_items"
|
||||
f" WHERE embedding MATCH {qvec} AND k = {k};"
|
||||
)
|
||||
|
||||
|
||||
def generate_sql(db_path, params, subset_size, n_queries, k, repeats, base_db=None):
    """Generate a complete SQL workload: load ext, create table, insert, query.

    The returned text is meant to be fed to the sqlite3 shell on stdin. It
    loads the extension, attaches the base DB as ``base``, creates the table
    for ``params["index_type"]``, inserts rows ``[0, subset_size)`` in batches
    of INSERT_BATCH_SIZE, then issues the query set ``repeats`` times.

    Args:
        db_path: accepted for interface compatibility; not used when
            generating the script.
        params: parsed config dict; ``params["index_type"]`` selects the
            INDEX_REGISTRY entry.
        subset_size: number of rows to insert.
        n_queries: number of distinct query ids (0 .. n_queries-1).
        k: KNN k passed to every query.
        repeats: how many times the whole query set is repeated.
        base_db: path of the base DB to attach; defaults to the module-level
            BASE_DB when omitted.

    Returns:
        The SQL script as a single newline-joined string.
    """
    attach_path = os.path.abspath(BASE_DB if base_db is None else base_db)
    # The path is embedded in a SQL string literal, so double any single quote.
    attach_literal = attach_path.replace("'", "''")

    lines = []
    lines.append(".bail on")
    lines.append(f".load {EXT_PATH}")
    lines.append(f"ATTACH DATABASE '{attach_literal}' AS base;")
    lines.append("PRAGMA page_size=8192;")

    # Create table
    reg = INDEX_REGISTRY[params["index_type"]]
    lines.append(reg["create_table_sql"](params) + ";")

    # Inserts: use the index type's own statement when it provides one.
    sql_fn = reg.get("insert_sql")
    insert_sql = sql_fn(params) if sql_fn else None
    if insert_sql is None:
        insert_sql = DEFAULT_INSERT_SQL
    for lo in range(0, subset_size, INSERT_BATCH_SIZE):
        hi = min(lo + INSERT_BATCH_SIZE, subset_size)
        # The shell cannot bind parameters, so the :lo/:hi placeholders are
        # substituted textually.
        stmt = insert_sql.replace(":lo", str(lo)).replace(":hi", str(hi))
        lines.append(stmt + ";")
        if hi % 10000 == 0 or hi == subset_size:
            lines.append("-- progress: inserted %d/%d" % (hi, subset_size))

    # Queries (repeated)
    lines.append("-- BEGIN QUERIES")
    for _rep in range(repeats):
        for qid in range(n_queries):
            lines.append(_query_sql_for_config(params, qid, k))

    return "\n".join(lines)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Profiling with macOS `sample`
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def run_profile(sqlite3_path, db_path, sql_file, sample_output, duration=120):
    """Run sqlite3 under macOS `sample` profiler.

    Starts sqlite3 directly with stdin from the SQL file, then immediately
    attaches `sample` to its PID with -mayDie (tolerates process exit).
    The workload must be long enough for sample to attach and capture useful data.

    Args:
        sqlite3_path: path of the sqlite3 CLI binary to run.
        db_path: database file passed to sqlite3.
        sql_file: path of the SQL script fed to sqlite3 on stdin.
        sample_output: file `sample` writes its report to.
        duration: maximum sampling time in seconds passed to `sample`.

    Returns:
        True when sqlite3 exited with status 0, False otherwise.
    """
    # The context manager closes the script file even if starting a process fails.
    with open(sql_file, "r") as sql_fd:
        proc = subprocess.Popen(
            [sqlite3_path, db_path],
            stdin=sql_fd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
        )

        pid = proc.pid
        print(f" sqlite3 PID: {pid}")

        # Attach sample immediately (1ms interval, -mayDie tolerates process exit).
        # Its stderr is never inspected, so discard it: an undrained PIPE could
        # fill up and block `sample`, which would make wait() below hang.
        try:
            sample_proc = subprocess.Popen(
                ["sample", str(pid), str(duration), "1", "-mayDie", "-file", sample_output],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except OSError:
            # Do not leave the workload running if the profiler cannot start.
            proc.kill()
            proc.wait()
            raise

        # Wait for sqlite3 to finish
        _, stderr = proc.communicate()

    rc = proc.returncode
    if rc != 0:
        print(f" sqlite3 failed (rc={rc}):", file=sys.stderr)
        print(f" {stderr.decode(errors='replace').strip()}", file=sys.stderr)
        sample_proc.kill()
        sample_proc.wait()  # reap the killed sampler
        return False

    # Wait for sample to finish
    sample_proc.wait()
    return True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Parse `sample` output
|
||||
# ============================================================================
|
||||
|
||||
# Tree-drawing characters used by macOS `sample` to represent hierarchy.
|
||||
# We replace them with spaces so indentation depth reflects tree depth.
|
||||
_TREE_CHARS_RE = re.compile(r"[+!:|]")
|
||||
|
||||
# After tree chars are replaced with spaces, each call-graph line looks like:
|
||||
# " 800 rescore_knn (in vec0.dylib) + 3808,3640,... [0x1a,0x2b,...] file.c:123"
|
||||
# We extract just (indent, count, symbol, module) — everything after "(in ...)"
|
||||
# is decoration we don't need.
|
||||
_LEADING_RE = re.compile(r"^(\s+)(\d+)\s+(.+)")
|
||||
|
||||
|
||||
def _extract_symbol_and_module(rest):
|
||||
"""Given the text after 'count ', extract (symbol, module).
|
||||
|
||||
Handles patterns like:
|
||||
'rescore_knn (in vec0.dylib) + 3808,3640,... [0x...]'
|
||||
'pread (in libsystem_kernel.dylib) + 8 [0x...]'
|
||||
'??? (in <unknown binary>) [0x...]'
|
||||
'start (in dyld) + 2840 [0x198650274]'
|
||||
'Thread_26759239 DispatchQueue_1: ...'
|
||||
"""
|
||||
# Try to find "(in ...)" to split symbol from module
|
||||
m = re.match(r"^(.+?)\s+\(in\s+(.+?)\)", rest)
|
||||
if m:
|
||||
return m.group(1).strip(), m.group(2).strip()
|
||||
# No module — return whole thing as symbol, strip trailing junk
|
||||
sym = re.sub(r"\s+\[0x[0-9a-f].*", "", rest).strip()
|
||||
return sym, ""
|
||||
|
||||
|
||||
def _parse_call_graph_lines(text):
    """Turn the call-graph section into a list of (depth, count, symbol, module).

    Depth is the width of the leading whitespace after the tree-drawing
    characters have been replaced by spaces. Lines that do not look like
    "<indent><count> <rest>" are skipped.
    """
    parsed = []
    for raw in text.split("\n"):
        # Blank out tree glyphs so that indentation width still encodes depth.
        hit = _LEADING_RE.match(_TREE_CHARS_RE.sub(" ", raw))
        if hit is None:
            continue
        indent, samples, tail = hit.groups()
        parsed.append((len(indent), int(samples)) + _extract_symbol_and_module(tail))
    return parsed
|
||||
|
||||
|
||||
def parse_sample_output(filepath):
    """Read a `sample` report and compute exclusive (self) sample counts per function.

    Only the text between "Call graph:" and "Total number in stack" (or end of
    file) is considered. Returns a dict mapping "symbol (module)" (or just the
    symbol when no module is known) to its self-sample count; functions whose
    self count is not positive are omitted. Returns {} when the section is
    missing or yields no entries.
    """
    with open(filepath, "r") as handle:
        report = handle.read()

    # Locate the call-graph section.
    begin = report.find("Call graph:")
    if begin == -1:
        print(" Warning: no 'Call graph:' section found in sample output")
        return {}

    finish = report.find("\nTotal number in stack", begin)
    section = report[begin:] if finish == -1 else report[begin:finish]

    nodes = _parse_call_graph_lines(section)
    if not nodes:
        print(" Warning: no call graph entries parsed")
        return {}

    # self = count - sum(direct_children_counts)
    # Direct children are the following deeper entries that share the depth of
    # the first deeper entry; the scan stops at the next entry that is not deeper.
    exclusive = {}
    node_total = len(nodes)
    for idx, (depth, count, sym, mod) in enumerate(nodes):
        direct_depth = None
        inherited = 0
        pos = idx + 1
        while pos < node_total and nodes[pos][0] > depth:
            later_depth, later_count = nodes[pos][0], nodes[pos][1]
            if direct_depth is None:
                direct_depth = later_depth
            if later_depth == direct_depth:
                inherited += later_count
            pos += 1

        own = count - inherited
        if own <= 0:
            continue
        label = f"{sym} ({mod})" if mod else sym
        exclusive[label] = exclusive.get(label, 0) + own

    return exclusive
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Display
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def print_profile(title, self_samples, top_n=20):
    """Print the `top_n` functions by self-sample count, with their share of the total.

    Prints a "(no samples)" header and nothing else when the counts sum to zero.
    """
    grand_total = sum(self_samples.values())
    if grand_total == 0:
        print(f"\n=== {title} (no samples) ===")
        return

    # Highest count first; equal counts keep their original dict order.
    ranking = sorted(self_samples.items(), key=lambda pair: pair[1], reverse=True)

    print(f"\n=== {title} (top {top_n}, {grand_total} total self-samples) ===")
    for sym, count in ranking[:top_n]:
        pct = 100.0 * count / grand_total
        print(f" {pct:5.1f}% {count:>6} {sym}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _print_comparison(all_profiles, top):
    """Print a side-by-side table of self-sample percentages across configs.

    `all_profiles` is a list of (name, desc, {symbol: self_count}). A symbol is
    listed when it is in the top `top` of at least one config; rows are ordered
    by the highest percentage the symbol reaches in any config and at most
    ``top * 2`` rows are printed.
    """
    print("\n" + "=" * 80)
    print("COMPARISON")
    print("=" * 80)

    # Collect all symbols that appear in top-N of any config
    all_syms = set()
    for _name, _desc, prof in all_profiles:
        ranked = sorted(prof.items(), key=lambda x: -x[1])
        for sym, _count in ranked[:top]:
            all_syms.add(sym)

    # Per-config totals do not depend on the symbol, so compute them once.
    totals = [sum(prof.values()) for _name, _desc, prof in all_profiles]

    # Build comparison table. Iterate the symbols in sorted order so that rows
    # with equal percentages come out in the same order on every run.
    rows = []
    for sym in sorted(all_syms):
        row = [sym]
        for (_name, _desc, prof), total in zip(all_profiles, totals):
            count = prof.get(sym, 0)
            pct = 100.0 * count / total if total > 0 else 0.0
            row.append((pct, count))
        max_pct = max(r[0] for r in row[1:])
        rows.append((max_pct, row))

    rows.sort(key=lambda x: -x[0])

    # Header
    header = f"{'function':>40}"
    for name, _desc, _prof in all_profiles:
        header += f" {name:>14}"
    print(header)
    print("-" * len(header))

    for _sort_key, row in rows[: top * 2]:
        sym = row[0]
        display_sym = sym if len(sym) <= 40 else sym[:37] + "..."
        line = f"{display_sym:>40}"
        for pct, count in row[1:]:
            if count > 0:
                line += f" {pct:>13.1f}%"
            else:
                line += f" {'-':>14}"
        print(line)


def main():
    """Command-line entry point: profile each config and print the hottest functions.

    For every config spec this generates a SQL workload, runs it through the
    sqlite3 CLI under `sample`, parses the report and prints the top
    functions; with more than one config a comparison table follows. Work
    files live in a fresh temp directory that is removed at the end (also when
    an error occurs) unless --keep-temp is given.
    """
    parser = argparse.ArgumentParser(
        description="CPU profiling for sqlite-vec KNN configurations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "configs", nargs="+", help="config specs (name:type=X,key=val,...)"
    )
    parser.add_argument("--subset-size", type=int, required=True)
    parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
    parser.add_argument(
        "-n", type=int, default=50, help="number of distinct queries (default 50)"
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=10,
        help="repeat query set N times for more samples (default 10)",
    )
    parser.add_argument(
        "--top", type=int, default=20, help="show top N functions (default 20)"
    )
    # NOTE(review): only the existence check below reads --base-db; the call to
    # generate_sql() does not pass it on — confirm whether the workload should
    # attach this path as well.
    parser.add_argument("--base-db", default=BASE_DB)
    parser.add_argument("--sqlite3", default=SQLITE3_PATH)
    parser.add_argument(
        "--keep-temp",
        action="store_true",
        help="keep temp directory with DBs, SQL, and sample output",
    )
    args = parser.parse_args()

    # Check prerequisites
    if not os.path.exists(args.base_db):
        print(f"Error: base DB not found at {args.base_db}", file=sys.stderr)
        print("Run 'make seed' in benchmarks-ann/ first.", file=sys.stderr)
        sys.exit(1)

    if not shutil.which("sample"):
        print("Error: macOS 'sample' tool not found.", file=sys.stderr)
        sys.exit(1)

    # Build CLI
    print("Building dist/sqlite3...")
    result = subprocess.run(
        ["make", "cli"], cwd=_PROJECT_ROOT, capture_output=True, text=True
    )
    if result.returncode != 0:
        print(f"Error: make cli failed:\n{result.stderr}", file=sys.stderr)
        sys.exit(1)
    print(" done.")

    if not os.path.exists(args.sqlite3):
        print(f"Error: sqlite3 not found at {args.sqlite3}", file=sys.stderr)
        sys.exit(1)

    configs = [parse_config(c) for c in args.configs]

    tmpdir = tempfile.mkdtemp(prefix="sqlite-vec-profile-")
    print(f"Working directory: {tmpdir}")

    try:
        all_profiles = []

        for i, (name, params) in enumerate(configs, 1):
            reg = INDEX_REGISTRY[params["index_type"]]
            desc = reg["describe"](params)
            print(f"\n[{i}/{len(configs)}] {name} ({desc})")

            # Generate SQL workload
            db_path = os.path.join(tmpdir, f"{name}.db")
            sql_text = generate_sql(
                db_path, params, args.subset_size, args.n, args.k, args.repeats
            )
            sql_file = os.path.join(tmpdir, f"{name}.sql")
            with open(sql_file, "w") as f:
                f.write(sql_text)

            total_queries = args.n * args.repeats
            print(
                f" SQL workload: {args.subset_size} inserts + "
                f"{total_queries} queries ({args.n} x {args.repeats} repeats)"
            )

            # Profile
            sample_file = os.path.join(tmpdir, f"{name}.sample.txt")
            print(" Profiling...")
            ok = run_profile(args.sqlite3, db_path, sql_file, sample_file)
            if not ok:
                print(f" FAILED — skipping {name}")
                # Keep a placeholder so the comparison still has one column per config.
                all_profiles.append((name, desc, {}))
                continue

            if not os.path.exists(sample_file):
                print(" Warning: sample output not created")
                all_profiles.append((name, desc, {}))
                continue

            # Parse
            self_samples = parse_sample_output(sample_file)
            all_profiles.append((name, desc, self_samples))

            # Show individual profile
            print_profile(f"{name} ({desc})", self_samples, args.top)

        # Side-by-side comparison if multiple configs
        if len(all_profiles) > 1:
            _print_comparison(all_profiles, args.top)
    finally:
        # Runs on errors too, so a failed run does not leave the temp dir behind.
        if args.keep_temp:
            print(f"\nTemp files kept at: {tmpdir}")
        else:
            shutil.rmtree(tmpdir, ignore_errors=True)
            print("\nTemp files cleaned up. Use --keep-temp to preserve.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
76
benchmarks-ann/results_schema.sql
Normal file
76
benchmarks-ann/results_schema.sql
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
-- Comprehensive results schema for vec0 KNN benchmark runs.
|
||||
-- Created in WAL mode: PRAGMA journal_mode=WAL
|
||||
|
||||
CREATE TABLE IF NOT EXISTS runs (
|
||||
run_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
config_name TEXT NOT NULL,
|
||||
index_type TEXT NOT NULL,
|
||||
params TEXT NOT NULL, -- JSON: {"R":48,"L":128,"quantizer":"binary"}
|
||||
dataset TEXT NOT NULL, -- "cohere1m"
|
||||
subset_size INTEGER NOT NULL,
|
||||
k INTEGER NOT NULL,
|
||||
n_queries INTEGER NOT NULL,
|
||||
phase TEXT NOT NULL DEFAULT 'both',
|
||||
-- 'build', 'query', or 'both'
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
-- pending → inserting → training → querying → done | built | error
|
||||
created_at_ns INTEGER NOT NULL -- time.time_ns()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS run_results (
|
||||
run_id INTEGER PRIMARY KEY REFERENCES runs(run_id),
|
||||
insert_started_ns INTEGER,
|
||||
insert_ended_ns INTEGER,
|
||||
insert_duration_ns INTEGER,
|
||||
train_started_ns INTEGER, -- NULL if no training
|
||||
train_ended_ns INTEGER,
|
||||
train_duration_ns INTEGER,
|
||||
build_duration_ns INTEGER, -- insert + train
|
||||
db_file_size_bytes INTEGER,
|
||||
db_file_path TEXT,
|
||||
create_sql TEXT, -- CREATE VIRTUAL TABLE ...
|
||||
insert_sql TEXT, -- INSERT INTO vec_items ...
|
||||
train_sql TEXT, -- NULL if no training step
|
||||
query_sql TEXT, -- SELECT ... WHERE embedding MATCH ...
|
||||
k INTEGER, -- denormalized from runs for easy filtering
|
||||
query_mean_ms REAL, -- denormalized aggregates
|
||||
query_median_ms REAL,
|
||||
query_p99_ms REAL,
|
||||
query_total_ms REAL,
|
||||
qps REAL,
|
||||
recall REAL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS insert_batches (
|
||||
batch_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id INTEGER NOT NULL REFERENCES runs(run_id),
|
||||
batch_lo INTEGER NOT NULL, -- start index (inclusive)
|
||||
batch_hi INTEGER NOT NULL, -- end index (exclusive)
|
||||
rows_in_batch INTEGER NOT NULL,
|
||||
started_ns INTEGER NOT NULL,
|
||||
ended_ns INTEGER NOT NULL,
|
||||
duration_ns INTEGER NOT NULL,
|
||||
cumulative_rows INTEGER NOT NULL, -- total rows inserted so far
|
||||
rate_rows_per_s REAL NOT NULL -- cumulative rate
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS queries (
|
||||
query_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id INTEGER NOT NULL REFERENCES runs(run_id),
|
||||
k INTEGER NOT NULL,
|
||||
query_vector_id INTEGER NOT NULL,
|
||||
started_ns INTEGER NOT NULL,
|
||||
ended_ns INTEGER NOT NULL,
|
||||
duration_ms REAL NOT NULL,
|
||||
result_ids TEXT NOT NULL, -- JSON array
|
||||
result_distances TEXT NOT NULL, -- JSON array
|
||||
ground_truth_ids TEXT NOT NULL, -- JSON array
|
||||
recall REAL NOT NULL,
|
||||
UNIQUE(run_id, k, query_vector_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_runs_config ON runs(config_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_runs_type ON runs(index_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_runs_status ON runs(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_batches_run ON insert_batches(run_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_queries_run ON queries(run_id);
|
||||
60
benchmarks-ann/schema.sql
Normal file
60
benchmarks-ann/schema.sql
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
-- Canonical results schema for vec0 KNN benchmark comparisons.
|
||||
-- The index_type column is a free-form TEXT field. Baseline configs use
|
||||
-- "baseline"; index-specific branches add their own types (registered
|
||||
-- via INDEX_REGISTRY in bench.py).
|
||||
|
||||
CREATE TABLE IF NOT EXISTS runs (
|
||||
run_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
config_name TEXT NOT NULL,
|
||||
index_type TEXT NOT NULL,
|
||||
subset_size INTEGER NOT NULL,
|
||||
phase TEXT NOT NULL DEFAULT 'both', -- 'build', 'query', or 'both'
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
k INTEGER,
|
||||
n INTEGER,
|
||||
db_path TEXT,
|
||||
insert_time_s REAL,
|
||||
train_time_s REAL,
|
||||
total_build_time_s REAL,
|
||||
rows INTEGER,
|
||||
file_size_mb REAL,
|
||||
mean_ms REAL,
|
||||
median_ms REAL,
|
||||
p99_ms REAL,
|
||||
total_query_ms REAL,
|
||||
qps REAL,
|
||||
recall REAL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
finished_at TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS build_results (
|
||||
config_name TEXT NOT NULL,
|
||||
index_type TEXT NOT NULL,
|
||||
subset_size INTEGER NOT NULL,
|
||||
db_path TEXT NOT NULL,
|
||||
insert_time_s REAL NOT NULL,
|
||||
train_time_s REAL, -- NULL when no training/build step is needed
|
||||
total_time_s REAL NOT NULL,
|
||||
rows INTEGER NOT NULL,
|
||||
file_size_mb REAL NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
PRIMARY KEY (config_name, subset_size)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS bench_results (
|
||||
config_name TEXT NOT NULL,
|
||||
index_type TEXT NOT NULL,
|
||||
subset_size INTEGER NOT NULL,
|
||||
k INTEGER NOT NULL,
|
||||
n INTEGER NOT NULL,
|
||||
mean_ms REAL NOT NULL,
|
||||
median_ms REAL NOT NULL,
|
||||
p99_ms REAL NOT NULL,
|
||||
total_ms REAL NOT NULL,
|
||||
qps REAL NOT NULL,
|
||||
recall REAL NOT NULL,
|
||||
db_path TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
PRIMARY KEY (config_name, subset_size, k)
|
||||
);
|
||||
|
|
@ -248,59 +248,6 @@ def bench_libsql(base, query, page_size, k) -> BenchResult:
|
|||
return BenchResult(f"libsql ({page_size})", build_time, times)
|
||||
|
||||
|
||||
def register_np(db, array, name):
|
||||
ptr = array.__array_interface__["data"][0]
|
||||
nvectors, dimensions = array.__array_interface__["shape"]
|
||||
element_type = array.__array_interface__["typestr"]
|
||||
|
||||
assert element_type == "<f4"
|
||||
|
||||
name_escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]
|
||||
|
||||
db.execute(
|
||||
"insert into temp.vec_static_blobs(name, data) select ?, vec_static_blob_from_raw(?, ?, ?, ?)",
|
||||
[name, ptr, element_type, dimensions, nvectors],
|
||||
)
|
||||
|
||||
db.execute(
|
||||
f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
|
||||
)
|
||||
|
||||
def bench_sqlite_vec_static(base, query, k) -> BenchResult:
|
||||
print(f"sqlite-vec static...")
|
||||
|
||||
db = sqlite3.connect(":memory:")
|
||||
db.enable_load_extension(True)
|
||||
db.load_extension("../../dist/vec0")
|
||||
|
||||
|
||||
|
||||
t = time.time()
|
||||
register_np(db, base, "base")
|
||||
build_time = time.time() - t
|
||||
|
||||
times = []
|
||||
results = []
|
||||
for (
|
||||
idx,
|
||||
q,
|
||||
) in enumerate(query):
|
||||
t0 = time.time()
|
||||
result = db.execute(
|
||||
"""
|
||||
select
|
||||
rowid
|
||||
from base
|
||||
where vector match ?
|
||||
and k = ?
|
||||
order by distance
|
||||
""",
|
||||
[q.tobytes(), k],
|
||||
).fetchall()
|
||||
assert len(result) == k
|
||||
times.append(time.time() - t0)
|
||||
return BenchResult(f"sqlite-vec static", build_time, times)
|
||||
|
||||
def bench_faiss(base, query, k) -> BenchResult:
|
||||
import faiss
|
||||
dimensions = base.shape[1]
|
||||
|
|
@ -438,8 +385,6 @@ def suite(name, base, query, k, benchmarks):
|
|||
for b in benchmarks:
|
||||
if b == "faiss":
|
||||
results.append(bench_faiss(base, query, k=k))
|
||||
elif b == "vec-static":
|
||||
results.append(bench_sqlite_vec_static(base, query, k=k))
|
||||
elif b.startswith("vec-scalar"):
|
||||
_, page_size = b.split('.')
|
||||
results.append(bench_sqlite_vec_scalar(base, query, page_size, k=k))
|
||||
|
|
@ -541,7 +486,7 @@ def parse_args():
|
|||
help="Number of queries to use. Defaults all",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-x", help="type of runs to make", default="faiss,vec-scalar.4096,vec-static,vec-vec0.4096.16,usearch,duckdb,hnswlib,numpy"
|
||||
"-x", help="type of runs to make", default="faiss,vec-scalar.4096,vec-vec0.4096.16,usearch,duckdb,hnswlib,numpy"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -8,10 +8,3 @@ create virtual table vec_items using vec0(
|
|||
embedding float[1536]
|
||||
);
|
||||
|
||||
-- 65s (limit 1e5), ~615MB on disk
|
||||
insert into vec_items
|
||||
select
|
||||
rowid,
|
||||
vector
|
||||
from vec_npy_each(vec_npy_file('examples/dbpedia-openai/data/vectors.npy'))
|
||||
limit 1e5;
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ def connect(path):
|
|||
db = sqlite3.connect(path)
|
||||
db.enable_load_extension(True)
|
||||
db.load_extension("../dist/vec0")
|
||||
db.execute("select load_extension('../dist/vec0', 'sqlite3_vec_fs_read_init')")
|
||||
db.enable_load_extension(False)
|
||||
return db
|
||||
|
||||
|
|
@ -18,8 +17,6 @@ page_sizes = [ # 4096, 8192,
|
|||
chunk_sizes = [128, 256, 1024, 2048]
|
||||
types = ["f32", "int8", "bit"]
|
||||
|
||||
SRC = "../examples/dbpedia-openai/data/vectors.npy"
|
||||
|
||||
for page_size in page_sizes:
|
||||
for chunk_size in chunk_sizes:
|
||||
for t in types:
|
||||
|
|
@ -42,15 +39,8 @@ for page_size in page_sizes:
|
|||
func = "vec_quantize_i8(vector, 'unit')"
|
||||
if t == "bit":
|
||||
func = "vec_quantize_binary(vector)"
|
||||
db.execute(
|
||||
f"""
|
||||
insert into vec_items
|
||||
select rowid, {func}
|
||||
from vec_npy_each(vec_npy_file(?))
|
||||
limit 100000
|
||||
""",
|
||||
[SRC],
|
||||
)
|
||||
# TODO: replace with non-npy data loading
|
||||
pass
|
||||
elapsed = time.time() - t0
|
||||
print(elapsed)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ index ed2aaec..4cc0b0e 100755
|
|||
-Wl,--initial-memory=327680 \
|
||||
-D_HAVE_SQLITE_CONFIG_H \
|
||||
-DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \
|
||||
+ -DSQLITE_VEC_OMIT_FS=1 \
|
||||
$(awk '{print "-Wl,--export="$0}' exports.txt)
|
||||
|
||||
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from typing import List
|
||||
from struct import pack
|
||||
from sqlite3 import Connection
|
||||
|
||||
|
||||
def serialize_float32(vector: List[float]) -> bytes:
|
||||
|
|
@ -13,33 +12,3 @@ def serialize_int8(vector: List[int]) -> bytes:
|
|||
return pack("%sb" % len(vector), *vector)
|
||||
|
||||
|
||||
try:
|
||||
import numpy.typing as npt
|
||||
|
||||
def register_numpy(db: Connection, name: str, array: npt.NDArray):
|
||||
"""ayoo"""
|
||||
|
||||
ptr = array.__array_interface__["data"][0]
|
||||
nvectors, dimensions = array.__array_interface__["shape"]
|
||||
element_type = array.__array_interface__["typestr"]
|
||||
|
||||
assert element_type == "<f4"
|
||||
|
||||
name_escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]
|
||||
|
||||
db.execute(
|
||||
"""
|
||||
insert into temp.vec_static_blobs(name, data)
|
||||
select ?, vec_static_blob_from_raw(?, ?, ?, ?)
|
||||
""",
|
||||
[name, ptr, element_type, dimensions, nvectors],
|
||||
)
|
||||
|
||||
db.execute(
|
||||
f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
|
||||
def register_numpy(db: Connection, name: str, array):
|
||||
raise Exception("numpy package is required for register_numpy")
|
||||
|
|
|
|||
119
scripts/amalgamate.py
Normal file
119
scripts/amalgamate.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Amalgamate sqlite-vec into a single distributable .c file.
|
||||
|
||||
Reads the dev sqlite-vec.c and inlines any #include "sqlite-vec-*.c" files,
|
||||
stripping LSP-support blocks and per-file include guards.
|
||||
|
||||
Usage:
|
||||
python3 scripts/amalgamate.py sqlite-vec.c > dist/sqlite-vec.c
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def strip_lsp_block(content):
    """Drop every editor/LSP-support stanza from ``content``.

    The stanza is three consecutive lines of the form:
        #ifndef SQLITE_VEC_H
        #include "sqlite-vec.c"  // ...
        #endif

    Returns the text with all such stanzas removed; text without one is
    returned unchanged.
    """
    lsp_stanza = (
        r'^\s*#ifndef\s+SQLITE_VEC_H\s*\n'
        r'\s*#include\s+"sqlite-vec\.c"[^\n]*\n'
        r'\s*#endif[^\n]*\n'
    )
    return re.sub(lsp_stanza, '', content, flags=re.MULTILINE)
|
||||
|
||||
|
||||
def strip_include_guard(content, guard_macro):
    """Remove the include guard built on ``guard_macro`` from ``content``.

    Two things are removed:
      * the first ``#ifndef GUARD`` / ``#define GUARD`` line pair, and
      * the last ``#endif`` line in the text (taken to close the guard).

    The result always ends with exactly one trailing newline.
    """
    macro = re.escape(guard_macro)
    opening = (
        r'^\s*#ifndef\s+' + macro + r'\s*\n'
        + r'\s*#define\s+' + macro + r'\s*\n'
    )
    body = re.sub(opening, '', content, count=1, flags=re.MULTILINE)

    # Walk backwards so only the final #endif is dropped; nested
    # #if/#endif pairs earlier in the file are left alone.
    kept = body.rstrip('\n').split('\n')
    for idx in reversed(range(len(kept))):
        if re.match(r'^\s*#endif', kept[idx]):
            del kept[idx]
            break

    return '\n'.join(kept) + '\n'
|
||||
|
||||
|
||||
def detect_include_guard(content):
    """Return the include-guard macro at the top of ``content``, or None.

    Only guards named ``SQLITE_VEC_<something>_C`` (e.g. SQLITE_VEC_IVF_C)
    are recognised. The ``#ifndef`` / ``#define`` pair must be the first
    thing in the text, optionally preceded by whitespace and one block
    comment.
    """
    guard_re = re.compile(
        r'\s*(?:/\*[\s\S]*?\*/\s*)?'  # optional block comment
        r'#ifndef\s+(SQLITE_VEC_\w+_C)\s*\n'
        r'#define\s+\1'
    )
    found = guard_re.match(content)
    if found is None:
        return None
    return found.group(1)
|
||||
|
||||
|
||||
def inline_include(match, base_dir):
    """Replace an #include "sqlite-vec-*.c" with the file's contents.

    ``match`` is the regex match for the #include line (group 1 is the file
    name); ``base_dir`` is the directory the name is resolved against.
    If the file does not exist, a warning goes to stderr and the original
    #include text is returned untouched.
    """
    filename = match.group(1)
    filepath = os.path.join(base_dir, filename)

    if not os.path.exists(filepath):
        print(f"Warning: {filepath} not found, leaving #include in place", file=sys.stderr)
        return match.group(0)

    with open(filepath, 'r') as src:
        text = src.read()

    # Drop the editor-only stanza first, then the per-file include guard.
    text = strip_lsp_block(text)
    guard_macro = detect_include_guard(text)
    if guard_macro:
        text = strip_include_guard(text, guard_macro)

    # Wrap the inlined text in banner comments naming the source file.
    rule = '/' * 78
    banner_top = f'\n{rule}\n// Begin inlined: {filename}\n{rule}\n\n'
    banner_bottom = f'\n{rule}\n// End inlined: {filename}\n{rule}\n'
    return banner_top + text.strip('\n') + banner_bottom
|
||||
|
||||
|
||||
def amalgamate(input_path):
    """Return the text of ``input_path`` with sqlite-vec-*.c includes inlined.

    Only lines of the form ``#include "sqlite-vec-<name>.c"`` that start at
    column 0 are expanded; included files are looked up in the directory
    that contains ``input_path``.
    """
    source_dir = os.path.dirname(os.path.abspath(input_path))

    with open(input_path, 'r') as src:
        text = src.read()

    def expand(found):
        return inline_include(found, source_dir)

    # Replace #include "sqlite-vec-*.c" with inlined contents
    return re.sub(
        r'^#include\s+"(sqlite-vec-[^"]+\.c)"\s*$',
        expand,
        text,
        flags=re.MULTILINE,
    )
|
||||
|
||||
|
||||
def main():
    """CLI entry point: amalgamate the file named by argv[1] to stdout.

    Exits with status 1 and a usage message on stderr unless exactly one
    argument is supplied.
    """
    args = sys.argv
    if len(args) != 2:
        print(f"Usage: {args[0]} <input-file>", file=sys.stderr)
        sys.exit(1)

    sys.stdout.write(amalgamate(args[1]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
#!/bin/bash
|
||||
mkdir -p vendor
|
||||
curl -o sqlite-amalgamation.zip https://www.sqlite.org/2024/sqlite-amalgamation-3450300.zip
|
||||
unzip -d
|
||||
unzip sqlite-amalgamation.zip
|
||||
mv sqlite-amalgamation-3450300/* vendor/
|
||||
rmdir sqlite-amalgamation-3450300
|
||||
|
|
|
|||
|
|
@ -568,65 +568,6 @@ select 'todo';
|
|||
-- 'todo'
|
||||
|
||||
|
||||
```
|
||||
|
||||
## NumPy Utilities {#numpy}
|
||||
|
||||
Functions to read data from or work with [NumPy arrays](https://numpy.org/doc/stable/reference/generated/numpy.array.html).
|
||||
|
||||
### `vec_npy_each(vector)` {#vec_npy_each}
|
||||
|
||||
xxx
|
||||
|
||||
|
||||
```sql
|
||||
-- db.execute('select quote(?)', [to_npy(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))]).fetchone()
|
||||
select
|
||||
rowid,
|
||||
vector,
|
||||
vec_type(vector),
|
||||
vec_to_json(vector)
|
||||
from vec_npy_each(
|
||||
X'934E554D5059010076007B276465736372273A20273C6634272C2027666F727472616E5F6F72646572273A2046616C73652C20277368617065273A2028332C2031292C207D202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020200A0000803F0000004000004040'
|
||||
)
|
||||
/*
|
||||
┌───────┬─────────────┬──────────────────┬─────────────────────┐
|
||||
│ rowid │ vector │ vec_type(vector) │ vec_to_json(vector) │
|
||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
||||
│ 0 │ X'0000803F' │ 'float32' │ '[1.000000]' │
|
||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
||||
│ 1 │ X'00000040' │ 'float32' │ '[2.000000]' │
|
||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
||||
│ 2 │ X'00004040' │ 'float32' │ '[3.000000]' │
|
||||
└───────┴─────────────┴──────────────────┴─────────────────────┘
|
||||
|
||||
*/
|
||||
|
||||
|
||||
-- db.execute('select quote(?)', [to_npy(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))]).fetchone()
|
||||
select
|
||||
rowid,
|
||||
vector,
|
||||
vec_type(vector),
|
||||
vec_to_json(vector)
|
||||
from vec_npy_each(
|
||||
X'934E554D5059010076007B276465736372273A20273C6634272C2027666F727472616E5F6F72646572273A2046616C73652C20277368617065273A2028332C2031292C207D202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020200A0000803F0000004000004040'
|
||||
)
|
||||
/*
|
||||
┌───────┬─────────────┬──────────────────┬─────────────────────┐
|
||||
│ rowid │ vector │ vec_type(vector) │ vec_to_json(vector) │
|
||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
||||
│ 0 │ X'0000803F' │ 'float32' │ '[1.000000]' │
|
||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
||||
│ 1 │ X'00000040' │ 'float32' │ '[2.000000]' │
|
||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
||||
│ 2 │ X'00004040' │ 'float32' │ '[3.000000]' │
|
||||
└───────┴─────────────┴──────────────────┴─────────────────────┘
|
||||
|
||||
*/
|
||||
|
||||
|
||||
|
||||
```
|
||||
|
||||
## Meta {#meta}
|
||||
|
|
|
|||
|
|
@ -59,5 +59,4 @@ The current compile-time flags are:
|
|||
|
||||
- `SQLITE_VEC_ENABLE_AVX`, enables AVX CPU instructions for some vector search operations
|
||||
- `SQLITE_VEC_ENABLE_NEON`, enables NEON CPU instructions for some vector search operations
|
||||
- `SQLITE_VEC_OMIT_FS`, removes some obscure SQL functions and features that use the filesystem, meant for some WASM builds where there's no available filesystem
|
||||
- `SQLITE_VEC_STATIC`, meant for statically linking `sqlite-vec`
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "sqlite-vec"
|
||||
license = "MIT OR Apache"
|
||||
license = "MIT OR Apache-2.0"
|
||||
homepage = "https://alexgarcia.xyz/sqlite-vec"
|
||||
repo = "https://github.com/asg017/sqlite-vec"
|
||||
description = "A vector search SQLite extension."
|
||||
|
|
|
|||
1889
sqlite-vec-diskann.c
Normal file
1889
sqlite-vec-diskann.c
Normal file
File diff suppressed because it is too large
Load diff
214
sqlite-vec-ivf-kmeans.c
Normal file
214
sqlite-vec-ivf-kmeans.c
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
/**
|
||||
* sqlite-vec-ivf-kmeans.c — Pure k-means clustering algorithm.
|
||||
*
|
||||
* No SQLite dependency. Operates on float arrays in memory.
|
||||
* #include'd into sqlite-vec.c after struct definitions.
|
||||
*/
|
||||
|
||||
#ifndef SQLITE_VEC_IVF_KMEANS_C
|
||||
#define SQLITE_VEC_IVF_KMEANS_C
|
||||
|
||||
// When opened standalone in an editor, pull in types so the LSP is happy.
|
||||
// When #include'd from sqlite-vec.c, SQLITE_VEC_H is already defined.
|
||||
#ifndef SQLITE_VEC_H
|
||||
#include "sqlite-vec.c" // IWYU pragma: keep
|
||||
#endif
|
||||
|
||||
#include <float.h>
|
||||
#include <string.h>
|
||||
|
||||
#define VEC0_IVF_KMEANS_MAX_ITER 25
|
||||
#define VEC0_IVF_KMEANS_DEFAULT_SEED 0
|
||||
|
||||
// xorshift32 pseudo-random generator (shift triple 13 / 17 / 5).
// Advances *state in place and returns the new state value.
static uint32_t ivf_xorshift32(uint32_t *state) {
  uint32_t s = *state;
  s = s ^ (s << 13);
  s = s ^ (s >> 17);
  s = s ^ (s << 5);
  return (*state = s);
}
|
||||
|
||||
// Squared Euclidean (L2) distance between the first D elements of a and b.
// No square root is taken; D <= 0 yields 0.
static float ivf_l2_dist(const float *a, const float *b, int D) {
  float acc = 0.0f;
  int idx = 0;
  while (idx < D) {
    const float delta = a[idx] - b[idx];
    acc += delta * delta;
    idx++;
  }
  return acc;
}
|
||||
|
||||
// Find the nearest centroid for a single vector. Returns the centroid index.
//
//   vec        D floats
//   centroids  k rows of D floats each, row-major
//
// Distance is squared L2 (ivf_l2_dist). Ties keep the lowest index because
// the comparison is strict. Returns 0 when k <= 0 or when no distance is
// below FLT_MAX.
static int ivf_nearest_centroid(const float *vec, const float *centroids,
                                int D, int k) {
  float min_dist = FLT_MAX;
  int best = 0;
  for (int c = 0; c < k; c++) {
    // Offset computed in size_t so c * D cannot overflow int.
    float dist = ivf_l2_dist(vec, &centroids[(size_t)c * (size_t)D], D);
    if (dist < min_dist) {
      min_dist = dist;
      best = c;
    }
  }
  return best;
}
|
||||
|
||||
/**
|
||||
* K-means++ initialization.
|
||||
* Picks k initial centroids from the data with probability proportional
|
||||
* to squared distance from nearest existing centroid.
|
||||
*/
|
||||
static int ivf_kmeans_init_plusplus(const float *vectors, int N, int D,
|
||||
int k, uint32_t seed, float *centroids) {
|
||||
if (N <= 0 || k <= 0 || D <= 0)
|
||||
return -1;
|
||||
if (seed == 0)
|
||||
seed = 42;
|
||||
|
||||
// Pick first centroid randomly
|
||||
int first = ivf_xorshift32(&seed) % N;
|
||||
memcpy(centroids, &vectors[first * D], D * sizeof(float));
|
||||
|
||||
if (k == 1)
|
||||
return 0;
|
||||
|
||||
// Allocate distance array
|
||||
float *dists = sqlite3_malloc64((i64)N * sizeof(float));
|
||||
if (!dists)
|
||||
return -1;
|
||||
|
||||
for (int c = 1; c < k; c++) {
|
||||
// Compute D(x) = distance to nearest existing centroid
|
||||
double total = 0.0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
float d = ivf_l2_dist(&vectors[i * D], ¢roids[(c - 1) * D], D);
|
||||
if (c == 1 || d < dists[i]) {
|
||||
dists[i] = d;
|
||||
}
|
||||
total += dists[i];
|
||||
}
|
||||
|
||||
// Weighted random selection
|
||||
if (total <= 0.0) {
|
||||
// All distances zero — pick randomly
|
||||
int pick = ivf_xorshift32(&seed) % N;
|
||||
memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float));
|
||||
} else {
|
||||
double threshold = ((double)ivf_xorshift32(&seed) / (double)0xFFFFFFFF) * total;
|
||||
double cumulative = 0.0;
|
||||
int pick = N - 1;
|
||||
for (int i = 0; i < N; i++) {
|
||||
cumulative += dists[i];
|
||||
if (cumulative >= threshold) {
|
||||
pick = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_free(dists);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Lloyd's k-means algorithm.
|
||||
*
|
||||
* @param vectors N*D float array (row-major)
|
||||
* @param N number of vectors
|
||||
* @param D dimensionality
|
||||
* @param k number of clusters
|
||||
* @param max_iter maximum iterations
|
||||
* @param seed PRNG seed for initialization
|
||||
* @param out_centroids output: k*D float array (caller-allocated)
|
||||
* @return 0 on success, -1 on error
|
||||
*/
|
||||
static int ivf_kmeans(const float *vectors, int N, int D, int k,
|
||||
int max_iter, uint32_t seed, float *out_centroids) {
|
||||
if (N <= 0 || D <= 0 || k <= 0)
|
||||
return -1;
|
||||
|
||||
// Clamp k to N
|
||||
if (k > N)
|
||||
k = N;
|
||||
|
||||
// Allocate working memory
|
||||
int *assignments = sqlite3_malloc64((i64)N * sizeof(int));
|
||||
float *new_centroids = sqlite3_malloc64((i64)k * D * sizeof(float));
|
||||
int *counts = sqlite3_malloc64((i64)k * sizeof(int));
|
||||
|
||||
if (!assignments || !new_centroids || !counts) {
|
||||
sqlite3_free(assignments);
|
||||
sqlite3_free(new_centroids);
|
||||
sqlite3_free(counts);
|
||||
return -1;
|
||||
}
|
||||
|
||||
memset(assignments, -1, N * sizeof(int));
|
||||
|
||||
// Initialize centroids via k-means++
|
||||
if (ivf_kmeans_init_plusplus(vectors, N, D, k, seed, out_centroids) != 0) {
|
||||
sqlite3_free(assignments);
|
||||
sqlite3_free(new_centroids);
|
||||
sqlite3_free(counts);
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int iter = 0; iter < max_iter; iter++) {
|
||||
// Assignment step
|
||||
int changed = 0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
int nearest = ivf_nearest_centroid(&vectors[i * D], out_centroids, D, k);
|
||||
if (nearest != assignments[i]) {
|
||||
assignments[i] = nearest;
|
||||
changed++;
|
||||
}
|
||||
}
|
||||
if (changed == 0)
|
||||
break;
|
||||
|
||||
// Update step
|
||||
memset(new_centroids, 0, (size_t)k * D * sizeof(float));
|
||||
memset(counts, 0, k * sizeof(int));
|
||||
|
||||
for (int i = 0; i < N; i++) {
|
||||
int c = assignments[i];
|
||||
counts[c]++;
|
||||
for (int d = 0; d < D; d++) {
|
||||
new_centroids[c * D + d] += vectors[i * D + d];
|
||||
}
|
||||
}
|
||||
|
||||
for (int c = 0; c < k; c++) {
|
||||
if (counts[c] == 0) {
|
||||
// Empty cluster: reassign to farthest point from its nearest centroid
|
||||
float max_dist = -1.0f;
|
||||
int farthest = 0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
float d = ivf_l2_dist(&vectors[i * D],
|
||||
&out_centroids[assignments[i] * D], D);
|
||||
if (d > max_dist) {
|
||||
max_dist = d;
|
||||
farthest = i;
|
||||
}
|
||||
}
|
||||
memcpy(&out_centroids[c * D], &vectors[farthest * D],
|
||||
D * sizeof(float));
|
||||
} else {
|
||||
for (int d = 0; d < D; d++) {
|
||||
out_centroids[c * D + d] = new_centroids[c * D + d] / counts[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_free(assignments);
|
||||
sqlite3_free(new_centroids);
|
||||
sqlite3_free(counts);
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif /* SQLITE_VEC_IVF_KMEANS_C */
|
||||
1445
sqlite-vec-ivf.c
Normal file
1445
sqlite-vec-ivf.c
Normal file
File diff suppressed because it is too large
Load diff
687
sqlite-vec-rescore.c
Normal file
687
sqlite-vec-rescore.c
Normal file
|
|
@ -0,0 +1,687 @@
|
|||
/**
|
||||
* sqlite-vec-rescore.c — Rescore index logic for sqlite-vec.
|
||||
*
|
||||
* This file is #included into sqlite-vec.c after the vec0_vtab definition.
|
||||
* All functions receive a vec0_vtab *p and access p->vector_columns[i].rescore.
|
||||
*
|
||||
* Shadow tables per rescore-enabled vector column:
|
||||
* _rescore_chunks{NN} — quantized vectors in chunk layout (for coarse scan)
|
||||
* _rescore_vectors{NN} — float vectors keyed by rowid (for fast rescore lookup)
|
||||
*/
|
||||
|
||||
// ============================================================================
|
||||
// Shadow table lifecycle
|
||||
// ============================================================================
|
||||
|
||||
/**
 * Create the two rescore shadow tables for every vector column whose
 * index type is VEC0_INDEX_TYPE_RESCORE:
 *   {table}_rescore_chunks{NN}  and  {table}_rescore_vectors{NN}
 *
 * Returns SQLITE_OK, SQLITE_NOMEM if a SQL string cannot be allocated, or
 * SQLITE_ERROR (with *pzErr set to a message) if a CREATE TABLE fails.
 */
static int rescore_create_tables(vec0_vtab *p, sqlite3 *db, char **pzErr) {
  for (int col = 0; col < p->numVectorColumns; col++) {
    if (p->vector_columns[col].index_type == VEC0_INDEX_TYPE_RESCORE) {
      sqlite3_stmt *stmt;
      char *sql;
      int rc;
      int failed;

      // Quantized chunk table (same layout as _vector_chunks)
      sql = sqlite3_mprintf("CREATE TABLE \"%w\".\"%w_rescore_chunks%02d\""
                            "(rowid PRIMARY KEY, vectors BLOB NOT NULL)",
                            p->schemaName, p->tableName, col);
      if (sql == NULL)
        return SQLITE_NOMEM;
      rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0);
      sqlite3_free(sql);
      // Only step the statement when it was prepared successfully.
      failed = (rc != SQLITE_OK);
      if (!failed)
        failed = (sqlite3_step(stmt) != SQLITE_DONE);
      if (failed) {
        *pzErr = sqlite3_mprintf(
            "Could not create '_rescore_chunks%02d' shadow table: %s", col,
            sqlite3_errmsg(db));
        sqlite3_finalize(stmt);
        return SQLITE_ERROR;
      }
      sqlite3_finalize(stmt);

      // Float vector table (rowid-keyed for fast random access)
      sql = sqlite3_mprintf("CREATE TABLE \"%w\".\"%w_rescore_vectors%02d\""
                            "(rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL)",
                            p->schemaName, p->tableName, col);
      if (sql == NULL)
        return SQLITE_NOMEM;
      rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0);
      sqlite3_free(sql);
      failed = (rc != SQLITE_OK);
      if (!failed)
        failed = (sqlite3_step(stmt) != SQLITE_DONE);
      if (failed) {
        *pzErr = sqlite3_mprintf(
            "Could not create '_rescore_vectors%02d' shadow table: %s", col,
            sqlite3_errmsg(db));
        sqlite3_finalize(stmt);
        return SQLITE_ERROR;
      }
      sqlite3_finalize(stmt);
    }
  }
  return SQLITE_OK;
}
|
||||
|
||||
/**
 * Drop the rescore shadow tables of every vector column that has a
 * shadow-table name recorded (chunks table first, then vectors table).
 *
 * Returns SQLITE_OK, SQLITE_NOMEM if a SQL string cannot be allocated, or
 * SQLITE_ERROR if a DROP statement cannot be prepared or run.
 */
static int rescore_drop_tables(vec0_vtab *p) {
  for (int col = 0; col < p->numVectorColumns; col++) {
    const char *shadow[2];
    shadow[0] = p->shadowRescoreChunksNames[col];
    shadow[1] = p->shadowRescoreVectorsNames[col];

    for (int t = 0; t < 2; t++) {
      sqlite3_stmt *stmt;
      char *sql;
      int rc;
      int failed;

      if (!shadow[t])
        continue;

      sql = sqlite3_mprintf("DROP TABLE IF EXISTS \"%w\".\"%w\"",
                            p->schemaName, shadow[t]);
      if (sql == NULL)
        return SQLITE_NOMEM;
      rc = sqlite3_prepare_v2(p->db, sql, -1, &stmt, 0);
      sqlite3_free(sql);
      // Only step the statement when it was prepared successfully.
      failed = (rc != SQLITE_OK);
      if (!failed)
        failed = (sqlite3_step(stmt) != SQLITE_DONE);
      sqlite3_finalize(stmt);
      if (failed)
        return SQLITE_ERROR;
    }
  }
  return SQLITE_OK;
}
|
||||
|
||||
static size_t rescore_quantized_byte_size(struct VectorColumnDefinition *col) {
|
||||
switch (col->rescore.quantizer_type) {
|
||||
case VEC0_RESCORE_QUANTIZER_BIT:
|
||||
return col->dimensions / CHAR_BIT;
|
||||
case VEC0_RESCORE_QUANTIZER_INT8:
|
||||
return col->dimensions;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Insert a new chunk row into each _rescore_chunks{NN} table with a zeroblob.
|
||||
*/
|
||||
static int rescore_new_chunk(vec0_vtab *p, i64 chunk_rowid) {
|
||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE)
|
||||
continue;
|
||||
size_t quantized_size =
|
||||
rescore_quantized_byte_size(&p->vector_columns[i]);
|
||||
i64 blob_size = (i64)p->chunk_size * (i64)quantized_size;
|
||||
|
||||
char *zSql = sqlite3_mprintf(
|
||||
"INSERT INTO \"%w\".\"%w\"(_rowid_, rowid, vectors) VALUES (?, ?, ?)",
|
||||
p->schemaName, p->shadowRescoreChunksNames[i]);
|
||||
if (!zSql)
|
||||
return SQLITE_NOMEM;
|
||||
sqlite3_stmt *stmt;
|
||||
int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||
sqlite3_free(zSql);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_finalize(stmt);
|
||||
return rc;
|
||||
}
|
||||
sqlite3_bind_int64(stmt, 1, chunk_rowid);
|
||||
sqlite3_bind_int64(stmt, 2, chunk_rowid);
|
||||
sqlite3_bind_zeroblob64(stmt, 3, blob_size);
|
||||
rc = sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
if (rc != SQLITE_DONE)
|
||||
return rc;
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Quantization
|
||||
// ============================================================================
|
||||
|
||||
/**
 * Binary-quantize `dimensions` floats into a packed bit vector.
 * Bit (i % CHAR_BIT) of dst[i / CHAR_BIT] is set when src[i] >= 0.0f
 * (so NaN leaves the bit clear). The first dimensions / CHAR_BIT bytes of
 * dst are cleared before any bit is set.
 */
static void rescore_quantize_float_to_bit(const float *src, uint8_t *dst,
                                          size_t dimensions) {
  const size_t nbytes = dimensions / CHAR_BIT;
  size_t idx = 0;

  memset(dst, 0, nbytes);
  while (idx < dimensions) {
    if (src[idx] >= 0.0f) {
      const size_t byte_pos = idx / CHAR_BIT;
      const size_t bit_pos = idx % CHAR_BIT;
      dst[byte_pos] |= (1 << bit_pos);
    }
    idx++;
  }
}
|
||||
|
||||
/**
 * Quantize `dimensions` floats to int8 by mapping the interval [-1, 1]
 * linearly onto [-128, 127] in 255 steps. Results outside the int8 range
 * are clamped; NaN becomes 127 because the clamps are written as negated
 * comparisons.
 */
static void rescore_quantize_float_to_int8(const float *src, int8_t *dst,
                                           size_t dimensions) {
  float step = 2.0f / 255.0f;
  size_t idx = 0;
  while (idx < dimensions) {
    float v = (src[idx] - (-1.0f)) / step - 128.0f;
    v = (v <= 127.0f) ? v : 127.0f;
    v = (v >= -128.0f) ? v : -128.0f;
    dst[idx] = (int8_t)v;
    idx++;
  }
}
|
||||
|
||||
// ============================================================================
|
||||
// Insert path
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Quantize float vector to _rescore_chunks and store in _rescore_vectors.
|
||||
*/
|
||||
static int rescore_on_insert(vec0_vtab *p, i64 chunk_rowid, i64 chunk_offset,
|
||||
i64 rowid, void *vectorDatas[]) {
|
||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE)
|
||||
continue;
|
||||
|
||||
struct VectorColumnDefinition *col = &p->vector_columns[i];
|
||||
size_t qsize = rescore_quantized_byte_size(col);
|
||||
size_t fsize = vector_column_byte_size(*col);
|
||||
int rc;
|
||||
|
||||
// 1. Write quantized vector to _rescore_chunks blob
|
||||
{
|
||||
void *qbuf = sqlite3_malloc(qsize);
|
||||
if (!qbuf)
|
||||
return SQLITE_NOMEM;
|
||||
|
||||
switch (col->rescore.quantizer_type) {
|
||||
case VEC0_RESCORE_QUANTIZER_BIT:
|
||||
rescore_quantize_float_to_bit((const float *)vectorDatas[i],
|
||||
(uint8_t *)qbuf, col->dimensions);
|
||||
break;
|
||||
case VEC0_RESCORE_QUANTIZER_INT8:
|
||||
rescore_quantize_float_to_int8((const float *)vectorDatas[i],
|
||||
(int8_t *)qbuf, col->dimensions);
|
||||
break;
|
||||
}
|
||||
|
||||
sqlite3_blob *blob = NULL;
|
||||
rc = sqlite3_blob_open(p->db, p->schemaName,
|
||||
p->shadowRescoreChunksNames[i], "vectors",
|
||||
chunk_rowid, 1, &blob);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_free(qbuf);
|
||||
return rc;
|
||||
}
|
||||
rc = sqlite3_blob_write(blob, qbuf, qsize, chunk_offset * qsize);
|
||||
sqlite3_free(qbuf);
|
||||
int brc = sqlite3_blob_close(blob);
|
||||
if (rc != SQLITE_OK)
|
||||
return rc;
|
||||
if (brc != SQLITE_OK)
|
||||
return brc;
|
||||
}
|
||||
|
||||
// 2. Insert float vector into _rescore_vectors (rowid-keyed)
|
||||
{
|
||||
char *zSql = sqlite3_mprintf(
|
||||
"INSERT INTO \"%w\".\"%w\"(rowid, vector) VALUES (?, ?)",
|
||||
p->schemaName, p->shadowRescoreVectorsNames[i]);
|
||||
if (!zSql)
|
||||
return SQLITE_NOMEM;
|
||||
sqlite3_stmt *stmt;
|
||||
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||
sqlite3_free(zSql);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_finalize(stmt);
|
||||
return rc;
|
||||
}
|
||||
sqlite3_bind_int64(stmt, 1, rowid);
|
||||
sqlite3_bind_blob(stmt, 2, vectorDatas[i], fsize, SQLITE_TRANSIENT);
|
||||
rc = sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
if (rc != SQLITE_DONE)
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Delete path
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Zero out quantized vector in _rescore_chunks and delete from _rescore_vectors.
|
||||
*/
|
||||
static int rescore_on_delete(vec0_vtab *p, i64 chunk_id, u64 chunk_offset,
|
||||
i64 rowid) {
|
||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE)
|
||||
continue;
|
||||
int rc;
|
||||
|
||||
// 1. Zero out quantized data in _rescore_chunks
|
||||
{
|
||||
size_t qsize = rescore_quantized_byte_size(&p->vector_columns[i]);
|
||||
void *zeroBuf = sqlite3_malloc(qsize);
|
||||
if (!zeroBuf)
|
||||
return SQLITE_NOMEM;
|
||||
memset(zeroBuf, 0, qsize);
|
||||
|
||||
sqlite3_blob *blob = NULL;
|
||||
rc = sqlite3_blob_open(p->db, p->schemaName,
|
||||
p->shadowRescoreChunksNames[i], "vectors",
|
||||
chunk_id, 1, &blob);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_free(zeroBuf);
|
||||
return rc;
|
||||
}
|
||||
rc = sqlite3_blob_write(blob, zeroBuf, qsize, chunk_offset * qsize);
|
||||
sqlite3_free(zeroBuf);
|
||||
int brc = sqlite3_blob_close(blob);
|
||||
if (rc != SQLITE_OK)
|
||||
return rc;
|
||||
if (brc != SQLITE_OK)
|
||||
return brc;
|
||||
}
|
||||
|
||||
// 2. Delete from _rescore_vectors
|
||||
{
|
||||
char *zSql = sqlite3_mprintf(
|
||||
"DELETE FROM \"%w\".\"%w\" WHERE rowid = ?",
|
||||
p->schemaName, p->shadowRescoreVectorsNames[i]);
|
||||
if (!zSql)
|
||||
return SQLITE_NOMEM;
|
||||
sqlite3_stmt *stmt;
|
||||
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||
sqlite3_free(zSql);
|
||||
if (rc != SQLITE_OK)
|
||||
return rc;
|
||||
sqlite3_bind_int64(stmt, 1, rowid);
|
||||
rc = sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
if (rc != SQLITE_DONE)
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a chunk row from _rescore_chunks{NN} tables.
|
||||
* (_rescore_vectors rows were already deleted per-row in rescore_on_delete)
|
||||
*/
|
||||
static int rescore_delete_chunk(vec0_vtab *p, i64 chunk_id) {
|
||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||
if (!p->shadowRescoreChunksNames[i])
|
||||
continue;
|
||||
char *zSql = sqlite3_mprintf(
|
||||
"DELETE FROM \"%w\".\"%w\" WHERE rowid = ?",
|
||||
p->schemaName, p->shadowRescoreChunksNames[i]);
|
||||
if (!zSql)
|
||||
return SQLITE_NOMEM;
|
||||
sqlite3_stmt *stmt;
|
||||
int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||
sqlite3_free(zSql);
|
||||
if (rc != SQLITE_OK)
|
||||
return rc;
|
||||
sqlite3_bind_int64(stmt, 1, chunk_id);
|
||||
rc = sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
if (rc != SQLITE_DONE)
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// KNN rescore query
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Phase 1: Coarse scan of quantized chunks → top k*oversample candidates (rowids).
|
||||
* Phase 2: For each candidate, blob_open _rescore_vectors by rowid, read float
|
||||
* vector, compute float distance. Sort, return top k.
|
||||
*
|
||||
* Phase 2 is fast because _rescore_vectors has INTEGER PRIMARY KEY, so
|
||||
* sqlite3_blob_open/reopen addresses rows directly by rowid — no index lookup.
|
||||
*/
|
||||
static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur,
|
||||
struct VectorColumnDefinition *vector_column,
|
||||
int vectorColumnIdx, struct Array *arrayRowidsIn,
|
||||
struct Array *aMetadataIn, const char *idxStr, int argc,
|
||||
sqlite3_value **argv, void *queryVector, i64 k,
|
||||
struct vec0_query_knn_data *knn_data) {
|
||||
(void)pCur;
|
||||
(void)aMetadataIn;
|
||||
int rc = SQLITE_OK;
|
||||
int oversample = vector_column->rescore.oversample_search > 0
|
||||
? vector_column->rescore.oversample_search
|
||||
: vector_column->rescore.oversample;
|
||||
i64 k_oversample = k * oversample;
|
||||
if (k_oversample > 4096)
|
||||
k_oversample = 4096;
|
||||
|
||||
size_t qdim = vector_column->dimensions;
|
||||
size_t qsize = rescore_quantized_byte_size(vector_column);
|
||||
size_t fsize = vector_column_byte_size(*vector_column);
|
||||
|
||||
// Quantize the query vector
|
||||
void *quantizedQuery = sqlite3_malloc(qsize);
|
||||
if (!quantizedQuery)
|
||||
return SQLITE_NOMEM;
|
||||
|
||||
switch (vector_column->rescore.quantizer_type) {
|
||||
case VEC0_RESCORE_QUANTIZER_BIT:
|
||||
rescore_quantize_float_to_bit((const float *)queryVector,
|
||||
(uint8_t *)quantizedQuery, qdim);
|
||||
break;
|
||||
case VEC0_RESCORE_QUANTIZER_INT8:
|
||||
rescore_quantize_float_to_int8((const float *)queryVector,
|
||||
(int8_t *)quantizedQuery, qdim);
|
||||
break;
|
||||
}
|
||||
|
||||
// Phase 1: Scan quantized chunks for k*oversample candidates
|
||||
sqlite3_stmt *stmtChunks = NULL;
|
||||
rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_free(quantizedQuery);
|
||||
return rc;
|
||||
}
|
||||
|
||||
i64 *cand_rowids = sqlite3_malloc(k_oversample * sizeof(i64));
|
||||
f32 *cand_distances = sqlite3_malloc(k_oversample * sizeof(f32));
|
||||
i64 *tmp_rowids = sqlite3_malloc(k_oversample * sizeof(i64));
|
||||
f32 *tmp_distances = sqlite3_malloc(k_oversample * sizeof(f32));
|
||||
f32 *chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(f32));
|
||||
i32 *chunk_topk_idxs = sqlite3_malloc(k_oversample * sizeof(i32));
|
||||
u8 *b = sqlite3_malloc(p->chunk_size / CHAR_BIT);
|
||||
u8 *bTaken = sqlite3_malloc(p->chunk_size / CHAR_BIT);
|
||||
u8 *bmRowids = NULL;
|
||||
void *baseVectors = sqlite3_malloc((i64)p->chunk_size * (i64)qsize);
|
||||
|
||||
if (!cand_rowids || !cand_distances || !tmp_rowids || !tmp_distances ||
|
||||
!chunk_distances || !chunk_topk_idxs || !b || !bTaken || !baseVectors) {
|
||||
rc = SQLITE_NOMEM;
|
||||
goto cleanup;
|
||||
}
|
||||
memset(cand_rowids, 0, k_oversample * sizeof(i64));
|
||||
memset(cand_distances, 0, k_oversample * sizeof(f32));
|
||||
|
||||
if (arrayRowidsIn) {
|
||||
bmRowids = sqlite3_malloc(p->chunk_size / CHAR_BIT);
|
||||
if (!bmRowids) {
|
||||
rc = SQLITE_NOMEM;
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
|
||||
i64 cand_used = 0;
|
||||
|
||||
while (1) {
|
||||
rc = sqlite3_step(stmtChunks);
|
||||
if (rc == SQLITE_DONE)
|
||||
break;
|
||||
if (rc != SQLITE_ROW) {
|
||||
rc = SQLITE_ERROR;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
i64 chunk_id = sqlite3_column_int64(stmtChunks, 0);
|
||||
unsigned char *chunkValidity =
|
||||
(unsigned char *)sqlite3_column_blob(stmtChunks, 1);
|
||||
i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2);
|
||||
int validityBytes = sqlite3_column_bytes(stmtChunks, 1);
|
||||
int rowidsBytes = sqlite3_column_bytes(stmtChunks, 2);
|
||||
if (!chunkValidity || !chunkRowids) {
|
||||
rc = SQLITE_ERROR;
|
||||
goto cleanup;
|
||||
}
|
||||
// Validate blob sizes match chunk_size expectations
|
||||
if (validityBytes < (p->chunk_size + 7) / 8 ||
|
||||
rowidsBytes < p->chunk_size * (int)sizeof(i64)) {
|
||||
rc = SQLITE_ERROR;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
memset(chunk_distances, 0, p->chunk_size * sizeof(f32));
|
||||
memset(chunk_topk_idxs, 0, k_oversample * sizeof(i32));
|
||||
bitmap_copy(b, chunkValidity, p->chunk_size);
|
||||
|
||||
if (arrayRowidsIn) {
|
||||
bitmap_clear(bmRowids, p->chunk_size);
|
||||
for (int j = 0; j < p->chunk_size; j++) {
|
||||
if (!bitmap_get(chunkValidity, j))
|
||||
continue;
|
||||
i64 rid = chunkRowids[j];
|
||||
void *found = bsearch(&rid, arrayRowidsIn->z, arrayRowidsIn->length,
|
||||
sizeof(i64), _cmp);
|
||||
bitmap_set(bmRowids, j, found ? 1 : 0);
|
||||
}
|
||||
bitmap_and_inplace(b, bmRowids, p->chunk_size);
|
||||
}
|
||||
|
||||
// Read quantized vectors
|
||||
sqlite3_blob *blobQ = NULL;
|
||||
rc = sqlite3_blob_open(p->db, p->schemaName,
|
||||
p->shadowRescoreChunksNames[vectorColumnIdx],
|
||||
"vectors", chunk_id, 0, &blobQ);
|
||||
if (rc != SQLITE_OK)
|
||||
goto cleanup;
|
||||
rc = sqlite3_blob_read(blobQ, baseVectors,
|
||||
(i64)p->chunk_size * (i64)qsize, 0);
|
||||
sqlite3_blob_close(blobQ);
|
||||
if (rc != SQLITE_OK)
|
||||
goto cleanup;
|
||||
|
||||
// Compute quantized distances
|
||||
for (int j = 0; j < p->chunk_size; j++) {
|
||||
if (!bitmap_get(b, j))
|
||||
continue;
|
||||
f32 dist = FLT_MAX;
|
||||
switch (vector_column->rescore.quantizer_type) {
|
||||
case VEC0_RESCORE_QUANTIZER_BIT: {
|
||||
const u8 *base_j = ((u8 *)baseVectors) + (j * (qdim / CHAR_BIT));
|
||||
dist = distance_hamming(base_j, (u8 *)quantizedQuery, &qdim);
|
||||
break;
|
||||
}
|
||||
case VEC0_RESCORE_QUANTIZER_INT8: {
|
||||
const i8 *base_j = ((i8 *)baseVectors) + (j * qdim);
|
||||
switch (vector_column->distance_metric) {
|
||||
case VEC0_DISTANCE_METRIC_L2:
|
||||
dist = distance_l2_sqr_int8(base_j, (i8 *)quantizedQuery, &qdim);
|
||||
break;
|
||||
case VEC0_DISTANCE_METRIC_COSINE:
|
||||
dist = distance_cosine_int8(base_j, (i8 *)quantizedQuery, &qdim);
|
||||
break;
|
||||
case VEC0_DISTANCE_METRIC_L1:
|
||||
dist = (f32)distance_l1_int8(base_j, (i8 *)quantizedQuery, &qdim);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
chunk_distances[j] = dist;
|
||||
}
|
||||
|
||||
int used1;
|
||||
min_idx(chunk_distances, p->chunk_size, b, chunk_topk_idxs,
|
||||
min(k_oversample, p->chunk_size), bTaken, &used1);
|
||||
|
||||
i64 merged_used;
|
||||
merge_sorted_lists(cand_distances, cand_rowids, cand_used, chunk_distances,
|
||||
chunkRowids, chunk_topk_idxs,
|
||||
min(min(k_oversample, p->chunk_size), used1),
|
||||
tmp_distances, tmp_rowids, k_oversample, &merged_used);
|
||||
|
||||
for (i64 j = 0; j < merged_used; j++) {
|
||||
cand_rowids[j] = tmp_rowids[j];
|
||||
cand_distances[j] = tmp_distances[j];
|
||||
}
|
||||
cand_used = merged_used;
|
||||
}
|
||||
rc = SQLITE_OK;
|
||||
|
||||
// Phase 2: Rescore candidates using _rescore_vectors (rowid-keyed)
|
||||
if (cand_used == 0) {
|
||||
knn_data->current_idx = 0;
|
||||
knn_data->k = 0;
|
||||
knn_data->rowids = NULL;
|
||||
knn_data->distances = NULL;
|
||||
knn_data->k_used = 0;
|
||||
goto cleanup;
|
||||
}
|
||||
{
|
||||
f32 *float_distances = sqlite3_malloc(cand_used * sizeof(f32));
|
||||
void *fBuf = sqlite3_malloc(fsize);
|
||||
if (!float_distances || !fBuf) {
|
||||
sqlite3_free(float_distances);
|
||||
sqlite3_free(fBuf);
|
||||
rc = SQLITE_NOMEM;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
// Open blob on _rescore_vectors, then reopen for each candidate rowid.
|
||||
// blob_reopen is O(1) for INTEGER PRIMARY KEY tables.
|
||||
sqlite3_blob *blobFloat = NULL;
|
||||
rc = sqlite3_blob_open(p->db, p->schemaName,
|
||||
p->shadowRescoreVectorsNames[vectorColumnIdx],
|
||||
"vector", cand_rowids[0], 0, &blobFloat);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_free(float_distances);
|
||||
sqlite3_free(fBuf);
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
rc = sqlite3_blob_read(blobFloat, fBuf, fsize, 0);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_blob_close(blobFloat);
|
||||
sqlite3_free(float_distances);
|
||||
sqlite3_free(fBuf);
|
||||
goto cleanup;
|
||||
}
|
||||
float_distances[0] =
|
||||
vec0_distance_full(fBuf, queryVector, vector_column->dimensions,
|
||||
vector_column->element_type,
|
||||
vector_column->distance_metric);
|
||||
|
||||
for (i64 j = 1; j < cand_used; j++) {
|
||||
rc = sqlite3_blob_reopen(blobFloat, cand_rowids[j]);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_blob_close(blobFloat);
|
||||
sqlite3_free(float_distances);
|
||||
sqlite3_free(fBuf);
|
||||
goto cleanup;
|
||||
}
|
||||
rc = sqlite3_blob_read(blobFloat, fBuf, fsize, 0);
|
||||
if (rc != SQLITE_OK) {
|
||||
sqlite3_blob_close(blobFloat);
|
||||
sqlite3_free(float_distances);
|
||||
sqlite3_free(fBuf);
|
||||
goto cleanup;
|
||||
}
|
||||
float_distances[j] =
|
||||
vec0_distance_full(fBuf, queryVector, vector_column->dimensions,
|
||||
vector_column->element_type,
|
||||
vector_column->distance_metric);
|
||||
}
|
||||
sqlite3_blob_close(blobFloat);
|
||||
sqlite3_free(fBuf);
|
||||
|
||||
// Sort by float distance
|
||||
for (i64 a = 0; a + 1 < cand_used; a++) {
|
||||
i64 minIdx = a;
|
||||
for (i64 c = a + 1; c < cand_used; c++) {
|
||||
if (float_distances[c] < float_distances[minIdx])
|
||||
minIdx = c;
|
||||
}
|
||||
if (minIdx != a) {
|
||||
f32 td = float_distances[a];
|
||||
float_distances[a] = float_distances[minIdx];
|
||||
float_distances[minIdx] = td;
|
||||
i64 tr = cand_rowids[a];
|
||||
cand_rowids[a] = cand_rowids[minIdx];
|
||||
cand_rowids[minIdx] = tr;
|
||||
}
|
||||
}
|
||||
|
||||
i64 result_k = min(k, cand_used);
|
||||
i64 *out_rowids = sqlite3_malloc(result_k * sizeof(i64));
|
||||
f32 *out_distances = sqlite3_malloc(result_k * sizeof(f32));
|
||||
if (!out_rowids || !out_distances) {
|
||||
sqlite3_free(out_rowids);
|
||||
sqlite3_free(out_distances);
|
||||
sqlite3_free(float_distances);
|
||||
rc = SQLITE_NOMEM;
|
||||
goto cleanup;
|
||||
}
|
||||
for (i64 j = 0; j < result_k; j++) {
|
||||
out_rowids[j] = cand_rowids[j];
|
||||
out_distances[j] = float_distances[j];
|
||||
}
|
||||
|
||||
knn_data->current_idx = 0;
|
||||
knn_data->k = result_k;
|
||||
knn_data->rowids = out_rowids;
|
||||
knn_data->distances = out_distances;
|
||||
knn_data->k_used = result_k;
|
||||
|
||||
sqlite3_free(float_distances);
|
||||
}
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtChunks);
|
||||
sqlite3_free(quantizedQuery);
|
||||
sqlite3_free(cand_rowids);
|
||||
sqlite3_free(cand_distances);
|
||||
sqlite3_free(tmp_rowids);
|
||||
sqlite3_free(tmp_distances);
|
||||
sqlite3_free(chunk_distances);
|
||||
sqlite3_free(chunk_topk_idxs);
|
||||
sqlite3_free(b);
|
||||
sqlite3_free(bTaken);
|
||||
sqlite3_free(bmRowids);
|
||||
sqlite3_free(baseVectors);
|
||||
return rc;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle FTS5-style command dispatch for rescore parameters.
|
||||
* Returns SQLITE_OK if handled, SQLITE_EMPTY if not a rescore command.
|
||||
*/
|
||||
static int rescore_handle_command(vec0_vtab *p, const char *command) {
|
||||
if (strncmp(command, "oversample=", 11) == 0) {
|
||||
int val = atoi(command + 11);
|
||||
if (val < 1) {
|
||||
vtab_set_error(&p->base, "oversample must be >= 1");
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) {
|
||||
p->vector_columns[i].rescore.oversample_search = val;
|
||||
}
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
return SQLITE_EMPTY;
|
||||
}
|
||||
|
||||
#ifdef SQLITE_VEC_TEST
/*
 * Test-only entry points, compiled only when SQLITE_VEC_TEST is defined.
 * Each one forwards to the corresponding rescore helper so that tests and
 * fuzz targets can call it through an external symbol.
 */

/* Forwards to rescore_quantize_float_to_bit(). */
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) {
  rescore_quantize_float_to_bit(src, dst, dim);
}

/* Forwards to rescore_quantize_float_to_int8(). */
void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim) {
  rescore_quantize_float_to_int8(src, dst, dim);
}

/* Quantized byte size for a bit-quantized column with `dimensions` dims. */
size_t _test_rescore_quantized_byte_size_bit(size_t dimensions) {
  /* Zero-filled definition; only the two fields set below are non-zero. */
  struct VectorColumnDefinition column;
  memset(&column, 0, sizeof(column));
  column.rescore.quantizer_type = VEC0_RESCORE_QUANTIZER_BIT;
  column.dimensions = dimensions;
  return rescore_quantized_byte_size(&column);
}

/* Quantized byte size for an int8-quantized column with `dimensions` dims. */
size_t _test_rescore_quantized_byte_size_int8(size_t dimensions) {
  /* Zero-filled definition; only the two fields set below are non-zero. */
  struct VectorColumnDefinition column;
  memset(&column, 0, sizeof(column));
  column.rescore.quantizer_type = VEC0_RESCORE_QUANTIZER_INT8;
  column.dimensions = dimensions;
  return rescore_quantized_byte_size(&column);
}
#endif
|
||||
3729
sqlite-vec.c
3729
sqlite-vec.c
File diff suppressed because it is too large
Load diff
File diff suppressed because one or more lines are too long
|
|
@ -27,8 +27,8 @@
|
|||
OrderedDict({
|
||||
'chunk_id': 1,
|
||||
'size': 8,
|
||||
'validity': b'\x06',
|
||||
'rowids': b'\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
'validity': b'\x02',
|
||||
'rowids': b'\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
|
|
@ -37,7 +37,7 @@
|
|||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'data': b'\x06',
|
||||
'data': b'\x02',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
|
|
@ -46,7 +46,7 @@
|
|||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
|
|
@ -55,7 +55,7 @@
|
|||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x9a\x99\x99\x99\x99\x99\x01@ffffff\n@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x9a\x99\x99\x99\x99\x99\x01@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
|
|
@ -64,17 +64,13 @@
|
|||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\x00test2\x00\x00\x00\x00\x00\x00\x00\r\x00\x00\x00123456789012\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\x00test2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'v_metadatatext03': OrderedDict({
|
||||
'sql': 'select * from v_metadatatext03',
|
||||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'data': '1234567890123',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'v_rowids': OrderedDict({
|
||||
|
|
@ -86,12 +82,6 @@
|
|||
'chunk_id': 1,
|
||||
'chunk_offset': 1,
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'id': None,
|
||||
'chunk_id': 1,
|
||||
'chunk_offset': 2,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'v_vector_chunks00': OrderedDict({
|
||||
|
|
@ -99,7 +89,7 @@
|
|||
'rows': list([
|
||||
OrderedDict({
|
||||
'rowid': 1,
|
||||
'vectors': b'\x00\x00\x00\x00""""3333\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
'vectors': b'\x00\x00\x00\x00""""\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
|
|
@ -370,14 +360,6 @@
|
|||
'f': 2.2,
|
||||
't': 'test2',
|
||||
}),
|
||||
OrderedDict({
|
||||
'rowid': 3,
|
||||
'vector': b'3333',
|
||||
'b': 1,
|
||||
'n': 3,
|
||||
'f': 3.3,
|
||||
't': '1234567890123',
|
||||
}),
|
||||
]),
|
||||
})
|
||||
# ---
|
||||
|
|
|
|||
|
|
@ -1,5 +1,29 @@
|
|||
import pytest
|
||||
import sqlite3
|
||||
import os
|
||||
|
||||
|
||||
def _vec_debug():
    """Return the text produced by ``SELECT vec_debug()``.

    Opens a throwaway in-memory database, loads the ``dist/vec0`` extension
    into it, and returns the first column of the single result row. The
    connection is always closed, even if loading the extension fails.
    """
    db = sqlite3.connect(":memory:")
    try:
        db.enable_load_extension(True)
        db.load_extension("dist/vec0")
        db.enable_load_extension(False)
        return db.execute("SELECT vec_debug()").fetchone()[0]
    finally:
        # The original leaked this connection; release it explicitly.
        db.close()
|
||||
|
||||
|
||||
def _has_build_flag(flag):
    """Return True if ``flag`` occurs in the build-flags part of vec_debug().

    The build-flags part is everything after the last "Build flags:" marker
    (the whole text when the marker is absent). This is a substring test.
    """
    debug_text = _vec_debug()
    build_flags = debug_text.split("Build flags:")[-1]
    return flag in build_flags
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
    """Skip IVF test files when the loaded extension lacks the "ivf" flag.

    If the build flags contain "ivf", the collected items are left untouched.
    Otherwise every item whose file basename starts with "test-ivf" gets a
    skip marker.
    """
    if _has_build_flag("ivf"):
        return

    ivf_prefixes = ("test-ivf",)
    skip_ivf = pytest.mark.skip(reason="IVF not enabled (compile with -DSQLITE_VEC_EXPERIMENTAL_IVF_ENABLE=1)")
    for item in items:
        basename = item.fspath.basename
        # str.startswith accepts a tuple and matches if any prefix matches.
        if basename.startswith(ivf_prefixes):
            item.add_marker(skip_ivf)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
|
|||
|
|
@ -48,7 +48,6 @@ import json
|
|||
db = sqlite3.connect(":memory:")
|
||||
db.enable_load_extension(True)
|
||||
db.load_extension("../../dist/vec0")
|
||||
db.execute("select load_extension('../../dist/vec0', 'sqlite3_vec_fs_read_init')")
|
||||
db.enable_load_extension(False)
|
||||
|
||||
results = db.execute(
|
||||
|
|
@ -75,17 +74,21 @@ print(b)
|
|||
|
||||
db.execute('PRAGMA page_size=16384')
|
||||
|
||||
print("Loading into sqlite-vec vec0 table...")
|
||||
t0 = time.time()
|
||||
db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)")
|
||||
db.execute('insert into v select rowid, vector from vec_npy_each(vec_npy_file("dbpedia_openai_3_large_00.npy"))')
|
||||
print(time.time() - t0)
|
||||
|
||||
print("loading numpy array...")
|
||||
t0 = time.time()
|
||||
base = np.load('dbpedia_openai_3_large_00.npy')
|
||||
print(time.time() - t0)
|
||||
|
||||
print("Loading into sqlite-vec vec0 table...")
|
||||
t0 = time.time()
|
||||
db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)")
|
||||
with db:
|
||||
db.executemany(
|
||||
"insert into v(rowid, a) values (?, ?)",
|
||||
[(i, row.tobytes()) for i, row in enumerate(base)],
|
||||
)
|
||||
print(time.time() - t0)
|
||||
|
||||
np.random.seed(1)
|
||||
queries = base[np.random.choice(base.shape[0], 20, replace=False), :]
|
||||
|
||||
|
|
|
|||
BIN
tests/fixtures/legacy-v0.1.6.db
vendored
Normal file
BIN
tests/fixtures/legacy-v0.1.6.db
vendored
Normal file
Binary file not shown.
5
tests/fuzz/.gitignore
vendored
5
tests/fuzz/.gitignore
vendored
|
|
@ -1,2 +1,7 @@
|
|||
*.dSYM
|
||||
targets/
|
||||
corpus/
|
||||
crash-*
|
||||
leak-*
|
||||
timeout-*
|
||||
*.log
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ FUZZ_LDFLAGS ?= $(shell \
|
|||
echo "-Wl,-ld_classic"; \
|
||||
fi)
|
||||
|
||||
FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -g $(FUZZ_LDFLAGS)
|
||||
FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -DSQLITE_VEC_ENABLE_DISKANN=1 -g $(FUZZ_LDFLAGS)
|
||||
FUZZ_SRCS = ../../vendor/sqlite3.c ../../sqlite-vec.c
|
||||
|
||||
TARGET_DIR = ./targets
|
||||
|
|
@ -72,10 +72,94 @@ $(TARGET_DIR)/vec_mismatch: vec-mismatch.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
|||
$(TARGET_DIR)/vec0_delete_completeness: vec0-delete-completeness.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_operations: rescore-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_create: rescore-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_quantize: rescore-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_shadow_corrupt: rescore-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_knn_deep: rescore-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_quantize_edge: rescore-quantize-edge.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/ivf_create: ivf-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/ivf_operations: ivf-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/ivf_quantize: ivf-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/ivf_kmeans: ivf-kmeans.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/ivf_shadow_corrupt: ivf-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/ivf_knn_deep: ivf-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/ivf_cell_overflow: ivf-cell-overflow.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
# Build the ivf_rescore fuzz target. The rule previously had no recipe, so the
# target could not be built; the recipe mirrors the sibling ivf_* rules.
$(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

|
||||
$(TARGET_DIR)/diskann_operations: diskann-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_create: diskann-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_graph_corrupt: diskann-graph-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_deep_search: diskann-deep-search.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_blob_truncate: diskann-blob-truncate.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_delete_stress: diskann-delete-stress.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_buffer_flush: diskann-buffer-flush.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_int8_quant: diskann-int8-quant.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_prune_direct: diskann-prune-direct.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
$(TARGET_DIR)/diskann_command_inject: diskann-command-inject.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||
|
||||
# Every fuzz target built by `all`. Each name must have a matching
# $(TARGET_DIR)/<name> rule above. Every line except the last needs a trailing
# backslash: a missing one ends the variable early and leaves the remaining
# names as a stray line.
FUZZ_TARGETS = vec0_create exec json numpy \
  shadow_corrupt vec0_operations scalar_functions \
  vec0_create_full metadata_columns vec_each vec_mismatch \
  vec0_delete_completeness \
  rescore_operations rescore_create rescore_quantize \
  rescore_shadow_corrupt rescore_knn_deep \
  rescore_quantize_edge rescore_interleave \
  ivf_create ivf_operations \
  ivf_quantize ivf_kmeans ivf_shadow_corrupt \
  ivf_knn_deep ivf_cell_overflow ivf_rescore \
  diskann_operations diskann_create diskann_graph_corrupt \
  diskann_deep_search diskann_blob_truncate \
  diskann_delete_stress diskann_buffer_flush \
  diskann_int8_quant diskann_prune_direct \
  diskann_command_inject
|
||||
|
||||
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
|
||||
|
||||
|
|
|
|||
250
tests/fuzz/diskann-blob-truncate.c
Normal file
250
tests/fuzz/diskann-blob-truncate.c
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN shadow table blob size mismatches.
|
||||
*
|
||||
* The critical vulnerability: diskann_node_read() copies whatever blob size
|
||||
* SQLite returns, but diskann_search/insert/delete index into those blobs
|
||||
* using cfg->n_neighbors * sizeof(i64) etc. If the blob is truncated,
|
||||
* extended, or has wrong size, this causes out-of-bounds reads/writes.
|
||||
*
|
||||
* This fuzzer:
|
||||
* 1. Creates a valid DiskANN graph with several nodes
|
||||
* 2. Uses fuzz data to directly write malformed blobs to shadow tables:
|
||||
* - Truncated neighbor_ids (fewer bytes than n_neighbors * 8)
|
||||
* - Truncated validity bitmaps
|
||||
* - Oversized blobs with garbage trailing data
|
||||
* - Zero-length blobs
|
||||
* - Blobs with valid headers but corrupted neighbor rowids
|
||||
* 3. Runs INSERT, DELETE, and KNN operations that traverse the corrupted graph
|
||||
*
|
||||
* Key code paths targeted:
|
||||
* - diskann_node_read with mismatched blob sizes
|
||||
* - diskann_validity_get / diskann_neighbor_id_get on truncated blobs
|
||||
* - diskann_add_reverse_edge reading corrupted neighbor data
|
||||
* - diskann_repair_reverse_edges traversing corrupted neighbor lists
|
||||
* - diskann_search iterating neighbors from corrupted blobs
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/*
 * Consume one byte from the fuzz input.
 *
 * While input remains, returns the byte at *data, advances *data by one and
 * decrements *size. Once *size is 0, returns `def` and leaves both unchanged.
 */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size > 0) {
    const uint8_t *cursor = *data;
    *data = cursor + 1;
    *size -= 1;
    return *cursor;
  }
  return def;
}
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 32) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* Use binary quantizer, float[16], n_neighbors=8 for predictable blob sizes:
|
||||
* validity: 8/8 = 1 byte
|
||||
* neighbor_ids: 8 * 8 = 64 bytes
|
||||
* qvecs: 8 * (16/8) = 16 bytes (binary: 2 bytes per qvec)
|
||||
*/
|
||||
rc = sqlite3_exec(db,
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||
NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
/* Insert 12 vectors to create a valid graph structure */
|
||||
{
|
||||
sqlite3_stmt *stmt;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||
for (int i = 1; i <= 12; i++) {
|
||||
float vec[16];
|
||||
for (int j = 0; j < 16; j++) {
|
||||
vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||
}
|
||||
sqlite3_reset(stmt);
|
||||
sqlite3_bind_int64(stmt, 1, i);
|
||||
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmt);
|
||||
}
|
||||
sqlite3_finalize(stmt);
|
||||
}
|
||||
|
||||
/* Now corrupt shadow table blobs using fuzz data */
|
||||
const char *columns[] = {
|
||||
"neighbors_validity",
|
||||
"neighbor_ids",
|
||||
"neighbor_quantized_vectors"
|
||||
};
|
||||
|
||||
/* Expected sizes for n_neighbors=8, dims=16, binary quantizer */
|
||||
int expected_sizes[] = {1, 64, 16};
|
||||
|
||||
while (size >= 4) {
|
||||
int target_row = (fuzz_byte(&data, &size, 0) % 12) + 1;
|
||||
int col_idx = fuzz_byte(&data, &size, 0) % 3;
|
||||
uint8_t corrupt_mode = fuzz_byte(&data, &size, 0) % 6;
|
||||
uint8_t extra = fuzz_byte(&data, &size, 0);
|
||||
|
||||
char sqlbuf[256];
|
||||
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||
"UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?",
|
||||
columns[col_idx]);
|
||||
|
||||
sqlite3_stmt *writeStmt;
|
||||
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL);
|
||||
if (rc != SQLITE_OK) continue;
|
||||
|
||||
int expected = expected_sizes[col_idx];
|
||||
unsigned char *blob = NULL;
|
||||
int blob_size = 0;
|
||||
|
||||
switch (corrupt_mode) {
|
||||
case 0: {
|
||||
/* Truncated blob: 0 to expected-1 bytes */
|
||||
blob_size = extra % expected;
|
||||
if (blob_size == 0) blob_size = 0; /* zero-length is interesting */
|
||||
blob = sqlite3_malloc(blob_size > 0 ? blob_size : 1);
|
||||
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||
for (int i = 0; i < blob_size; i++) {
|
||||
blob[i] = fuzz_byte(&data, &size, 0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
/* Oversized blob: expected + extra bytes */
|
||||
blob_size = expected + (extra % 64);
|
||||
blob = sqlite3_malloc(blob_size);
|
||||
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||
for (int i = 0; i < blob_size; i++) {
|
||||
blob[i] = fuzz_byte(&data, &size, 0xFF);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
/* Zero-length blob */
|
||||
blob_size = 0;
|
||||
blob = NULL;
|
||||
sqlite3_bind_zeroblob(writeStmt, 1, 0);
|
||||
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||
sqlite3_step(writeStmt);
|
||||
sqlite3_finalize(writeStmt);
|
||||
continue;
|
||||
}
|
||||
case 3: {
|
||||
/* Correct size but all-ones validity (all slots "valid") with
|
||||
* garbage neighbor IDs -- forces reading non-existent nodes */
|
||||
blob_size = expected;
|
||||
blob = sqlite3_malloc(blob_size);
|
||||
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||
memset(blob, 0xFF, blob_size);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
/* neighbor_ids with very large rowid values (near INT64_MAX) */
|
||||
blob_size = expected;
|
||||
blob = sqlite3_malloc(blob_size);
|
||||
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||
memset(blob, 0x7F, blob_size); /* fills with large positive values */
|
||||
break;
|
||||
}
|
||||
case 5: {
|
||||
/* neighbor_ids with negative rowid values (rowid=0 is sentinel) */
|
||||
blob_size = expected;
|
||||
blob = sqlite3_malloc(blob_size);
|
||||
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||
memset(blob, 0x80, blob_size); /* fills with large negative values */
|
||||
/* Flip some bytes from fuzz data */
|
||||
for (int i = 0; i < blob_size && size > 0; i++) {
|
||||
blob[i] ^= fuzz_byte(&data, &size, 0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (blob) {
|
||||
sqlite3_bind_blob(writeStmt, 1, blob, blob_size, SQLITE_TRANSIENT);
|
||||
} else {
|
||||
sqlite3_bind_blob(writeStmt, 1, "", 0, SQLITE_STATIC);
|
||||
}
|
||||
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||
sqlite3_step(writeStmt);
|
||||
sqlite3_finalize(writeStmt);
|
||||
sqlite3_free(blob);
|
||||
}
|
||||
|
||||
/* Exercise the corrupted graph with various operations */
|
||||
|
||||
/* KNN query */
|
||||
{
|
||||
float qvec[16];
|
||||
for (int j = 0; j < 16; j++) qvec[j] = (float)j * 0.1f;
|
||||
sqlite3_stmt *knnStmt;
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5",
|
||||
-1, &knnStmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knnStmt);
|
||||
}
|
||||
}
|
||||
|
||||
/* Insert into corrupted graph (triggers add_reverse_edge on corrupted nodes) */
|
||||
{
|
||||
float vec[16];
|
||||
for (int j = 0; j < 16; j++) vec[j] = 0.5f;
|
||||
sqlite3_stmt *stmt;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||
if (stmt) {
|
||||
sqlite3_bind_int64(stmt, 1, 100);
|
||||
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
}
|
||||
}
|
||||
|
||||
/* Delete from corrupted graph (triggers repair_reverse_edges) */
|
||||
{
|
||||
sqlite3_stmt *stmt;
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmt, NULL);
|
||||
if (stmt) {
|
||||
sqlite3_bind_int64(stmt, 1, 5);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
}
|
||||
}
|
||||
|
||||
/* Another KNN to traverse the post-mutation graph */
|
||||
{
|
||||
float qvec[16];
|
||||
for (int j = 0; j < 16; j++) qvec[j] = -0.5f + (float)j * 0.07f;
|
||||
sqlite3_stmt *knnStmt;
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 12",
|
||||
-1, &knnStmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knnStmt);
|
||||
}
|
||||
}
|
||||
|
||||
/* Full scan */
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
164
tests/fuzz/diskann-buffer-flush.c
Normal file
164
tests/fuzz/diskann-buffer-flush.c
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN buffered insert and flush paths.
|
||||
*
|
||||
* When buffer_threshold > 0, inserts go into a flat buffer table and
|
||||
* are flushed into the graph in batch. This fuzzer exercises:
|
||||
*
|
||||
* - diskann_buffer_write / diskann_buffer_delete / diskann_buffer_exists
|
||||
* - diskann_flush_buffer (batch graph insertion)
|
||||
* - diskann_insert with buffer_threshold (batching logic)
|
||||
* - Buffer-graph merge in vec0Filter_knn_diskann (unflushed vectors
|
||||
* must be scanned during KNN and merged with graph results)
|
||||
* - Delete of a buffered (not yet flushed) vector
|
||||
* - Delete of a graph vector while buffer has pending inserts
|
||||
* - Interaction: insert to buffer, query (triggers buffer scan), flush,
|
||||
* query again (now from graph)
|
||||
*
|
||||
* The buffer merge path in vec0Filter_knn_diskann is particularly
|
||||
* interesting because it does a brute-force scan of buffer vectors and
|
||||
* merges with the top-k from graph search.
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* Pop one byte off the front of the fuzz input.
 *
 * data / size: cursor and remaining length; both are advanced in place when
 *              a byte is available.
 * def:         value handed back (without touching the cursor) once the
 *              input is exhausted.
 */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size > 0) {
    const uint8_t *cursor = *data;
    *data = cursor + 1;
    *size -= 1;
    return *cursor;
  }
  return def;
}
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 16) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* buffer_threshold: small (3-8) to trigger frequent flushes */
|
||||
int buf_threshold = 3 + (fuzz_byte(&data, &size, 0) % 6);
|
||||
int dims = 8;
|
||||
|
||||
char sql[512];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] INDEXED BY diskann("
|
||||
"neighbor_quantizer=binary, n_neighbors=8, "
|
||||
"search_list_size=16, buffer_threshold=%d"
|
||||
"))", dims, buf_threshold);
|
||||
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||
-1, &stmtKnn, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||
|
||||
float vec[8];
|
||||
int next_rowid = 1;
|
||||
|
||||
while (size >= 2) {
|
||||
uint8_t op = fuzz_byte(&data, &size, 0) % 6;
|
||||
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||
|
||||
switch (op) {
|
||||
case 0: { /* Insert: accumulates in buffer until threshold */
|
||||
int64_t rowid = next_rowid++;
|
||||
if (next_rowid > 64) next_rowid = 1; /* wrap around for reuse */
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 1: { /* KNN query while buffer may have unflushed vectors */
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||
}
|
||||
int k = (param % 10) + 1;
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, k);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 2: { /* Delete a potentially-buffered vector */
|
||||
int64_t rowid = (int64_t)(param % 64) + 1;
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 3: { /* Insert several at once to trigger flush mid-batch */
|
||||
for (int i = 0; i < buf_threshold + 1 && size >= 2; i++) {
|
||||
int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 64) + 1;
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 4: { /* Insert then immediately delete (still in buffer) */
|
||||
int64_t rowid = (int64_t)(param % 64) + 1;
|
||||
for (int j = 0; j < dims; j++) vec[j] = 0.1f * param;
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 5: { /* Query with k=0 and k=1 (boundary) */
|
||||
for (int j = 0; j < dims; j++) vec[j] = 0.0f;
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, param % 2); /* k=0 or k=1 */
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Final query to exercise post-operation state */
|
||||
{
|
||||
float qvec[8] = {1.0f, -1.0f, 0.5f, -0.5f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, 20);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
158
tests/fuzz/diskann-command-inject.c
Normal file
158
tests/fuzz/diskann-command-inject.c
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN runtime command dispatch (diskann_handle_command).
|
||||
*
|
||||
* The command handler parses strings like "search_list_size_search=42" and
|
||||
* modifies live DiskANN config. This fuzzer exercises:
|
||||
*
|
||||
* - atoi on fuzz-controlled strings (integer overflow, negative, non-numeric)
|
||||
* - strncmp boundary with fuzz data (near-matches to valid commands)
|
||||
* - Changing search_list_size mid-operation (affects subsequent queries)
|
||||
* - Setting search_list_size to 1 (minimum - single-candidate beam search)
|
||||
* - Setting search_list_size very large (memory pressure)
|
||||
* - Interleaving command changes with inserts and queries
|
||||
*
|
||||
* Also exercises the vtable command path: INSERT INTO v(v) VALUES (?).
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* Read the next fuzz byte and advance the input cursor.
 * When no bytes remain, the cursor is left untouched and `def` is returned. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t result = def;
  if (*size != 0) {
    result = (*data)[0];
    *data += 1;
    *size -= 1;
  }
  return result;
}
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 20) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
rc = sqlite3_exec(db,
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||
NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
/* Insert some vectors first */
|
||||
{
|
||||
sqlite3_stmt *stmt;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||
for (int i = 1; i <= 8; i++) {
|
||||
float vec[8];
|
||||
for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||
sqlite3_reset(stmt);
|
||||
sqlite3_bind_int64(stmt, 1, i);
|
||||
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmt);
|
||||
}
|
||||
sqlite3_finalize(stmt);
|
||||
}
|
||||
|
||||
sqlite3_stmt *stmtCmd = NULL;
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
|
||||
/* Commands are dispatched via INSERT INTO t(t) VALUES ('cmd_string') */
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v) VALUES (?)", -1, &stmtCmd, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||
-1, &stmtKnn, NULL);
|
||||
|
||||
if (!stmtCmd || !stmtInsert || !stmtKnn) goto cleanup;
|
||||
|
||||
/* Fuzz-driven command + operation interleaving */
|
||||
while (size >= 2) {
|
||||
uint8_t op = fuzz_byte(&data, &size, 0) % 5;
|
||||
|
||||
switch (op) {
|
||||
case 0: { /* Send fuzz command string */
|
||||
int cmd_len = fuzz_byte(&data, &size, 0) % 64;
|
||||
char cmd[65];
|
||||
for (int i = 0; i < cmd_len && size > 0; i++) {
|
||||
cmd[i] = (char)fuzz_byte(&data, &size, 0);
|
||||
}
|
||||
cmd[cmd_len] = '\0';
|
||||
sqlite3_reset(stmtCmd);
|
||||
sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtCmd); /* May fail -- that's expected */
|
||||
break;
|
||||
}
|
||||
case 1: { /* Send valid-looking command with fuzz value */
|
||||
const char *prefixes[] = {
|
||||
"search_list_size=",
|
||||
"search_list_size_search=",
|
||||
"search_list_size_insert=",
|
||||
};
|
||||
int prefix_idx = fuzz_byte(&data, &size, 0) % 3;
|
||||
int val = (int)(int8_t)fuzz_byte(&data, &size, 0);
|
||||
|
||||
char cmd[128];
|
||||
snprintf(cmd, sizeof(cmd), "%s%d", prefixes[prefix_idx], val);
|
||||
sqlite3_reset(stmtCmd);
|
||||
sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtCmd);
|
||||
break;
|
||||
}
|
||||
case 2: { /* KNN query (uses whatever search_list_size is set) */
|
||||
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
qvec[0] = (float)((int8_t)fuzz_byte(&data, &size, 127)) / 10.0f;
|
||||
int k = fuzz_byte(&data, &size, 3) % 10 + 1;
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, k);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 3: { /* Insert (uses whatever search_list_size_insert is set) */
|
||||
int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 32) + 1;
|
||||
float vec[8];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 4: { /* Set search_list_size to extreme values */
|
||||
const char *extreme_cmds[] = {
|
||||
"search_list_size=1",
|
||||
"search_list_size=2",
|
||||
"search_list_size=1000",
|
||||
"search_list_size_search=1",
|
||||
"search_list_size_insert=1",
|
||||
};
|
||||
int idx = fuzz_byte(&data, &size, 0) % 5;
|
||||
sqlite3_reset(stmtCmd);
|
||||
sqlite3_bind_text(stmtCmd, 1, extreme_cmds[idx], -1, SQLITE_STATIC);
|
||||
sqlite3_step(stmtCmd);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtCmd);
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
44
tests/fuzz/diskann-create.c
Normal file
44
tests/fuzz/diskann-create.c
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN CREATE TABLE config parsing.
|
||||
* Feeds fuzz data as the INDEXED BY diskann(...) option string.
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size > 4096) return 0; /* Limit input size */
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmt;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||
assert(s);
|
||||
sqlite3_str_appendall(s,
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[64] INDEXED BY diskann(");
|
||||
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||
sqlite3_str_appendall(s, "))");
|
||||
const char *zSql = sqlite3_str_finish(s);
|
||||
assert(zSql);
|
||||
|
||||
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||
sqlite3_free((char *)zSql);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_step(stmt);
|
||||
}
|
||||
sqlite3_finalize(stmt);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
187
tests/fuzz/diskann-deep-search.c
Normal file
187
tests/fuzz/diskann-deep-search.c
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN greedy beam search deep paths.
|
||||
*
|
||||
* Builds a graph with enough nodes to force multi-hop traversal, then
|
||||
* uses fuzz data to control: query vector values, k, search_list_size
|
||||
* overrides, and interleaved insert/delete/query sequences that stress
|
||||
* the candidate list growth, visited set hash collisions, and the
|
||||
* re-ranking logic.
|
||||
*
|
||||
* Key code paths targeted:
|
||||
* - diskann_candidate_list_insert (sorted insert, dedup, eviction)
|
||||
* - diskann_visited_set (hash collisions, capacity)
|
||||
* - diskann_search (full beam search loop, re-ranking with exact dist)
|
||||
* - diskann_distance_quantized_precomputed (both binary and int8)
|
||||
* - Buffer merge in vec0Filter_knn_diskann
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* Consume one byte from fuzz input, or return default. */
|
||||
/* Take the next byte from the fuzz stream, moving the cursor forward by one.
 * An exhausted stream yields `def` and leaves cursor and length unchanged. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  const size_t remaining = *size;
  if (remaining == 0) {
    return def;
  }
  const uint8_t next = *(*data);
  *data = *data + 1;
  *size = remaining - 1;
  return next;
}
|
||||
|
||||
/* Build a 16-bit value from the next two fuzz bytes: the first byte read is
 * the low half, the second the high half. Missing bytes count as 0. */
static uint16_t fuzz_u16(const uint8_t **data, size_t *size) {
  uint16_t value = fuzz_byte(data, size, 0);
  value |= (uint16_t)(fuzz_byte(data, size, 0) << 8);
  return value;
}
|
||||
|
||||
/* Turn the next fuzz byte into a float: the byte is reinterpreted as a
 * signed 8-bit value and divided by 10 (0 is used when input is exhausted). */
static float fuzz_float(const uint8_t **data, size_t *size) {
  const int8_t raw = (int8_t)fuzz_byte(data, size, 0);
  return (float)raw / 10.0f;
}
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 32) return 0;
|
||||
|
||||
/* Use first bytes to pick quantizer type and dimensions */
|
||||
uint8_t quantizer_choice = fuzz_byte(&data, &size, 0) % 2;
|
||||
const char *quantizer = quantizer_choice ? "int8" : "binary";
|
||||
|
||||
/* Dimensions must be divisible by 8. Pick from {8, 16, 32} */
|
||||
int dim_choices[] = {8, 16, 32};
|
||||
int dims = dim_choices[fuzz_byte(&data, &size, 0) % 3];
|
||||
|
||||
/* n_neighbors: 8 or 16 -- small to force full-neighbor scenarios quickly */
|
||||
int n_neighbors = (fuzz_byte(&data, &size, 0) % 2) ? 16 : 8;
|
||||
|
||||
/* search_list_size: small so beam search terminates quickly but still exercises loops */
|
||||
int search_list_size = 8 + (fuzz_byte(&data, &size, 0) % 24);
|
||||
|
||||
/* alpha: vary to test RobustPrune pruning logic */
|
||||
float alpha_choices[] = {1.0f, 1.2f, 1.5f, 2.0f};
|
||||
float alpha = alpha_choices[fuzz_byte(&data, &size, 0) % 4];
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
char sql[512];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] INDEXED BY diskann("
|
||||
"neighbor_quantizer=%s, n_neighbors=%d, "
|
||||
"search_list_size=%d"
|
||||
"))", dims, quantizer, n_neighbors, search_list_size);
|
||||
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
|
||||
char knn_sql[256];
|
||||
snprintf(knn_sql, sizeof(knn_sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?");
|
||||
sqlite3_prepare_v2(db, knn_sql, -1, &stmtKnn, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||
|
||||
/* Phase 1: Seed the graph with enough nodes to create multi-hop structure.
|
||||
* Insert 2*n_neighbors nodes so the graph is dense enough for search
|
||||
* to actually traverse multiple hops. */
|
||||
int seed_count = n_neighbors * 2;
|
||||
if (seed_count > 64) seed_count = 64; /* Bound for performance */
|
||||
{
|
||||
float *vec = malloc(dims * sizeof(float));
|
||||
if (!vec) goto cleanup;
|
||||
for (int i = 1; i <= seed_count; i++) {
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = fuzz_float(&data, &size);
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
}
|
||||
free(vec);
|
||||
}
|
||||
|
||||
/* Phase 2: Fuzz-driven operations on the seeded graph */
|
||||
float *vec = malloc(dims * sizeof(float));
|
||||
if (!vec) goto cleanup;
|
||||
|
||||
while (size >= 2) {
|
||||
uint8_t op = fuzz_byte(&data, &size, 0) % 5;
|
||||
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||
|
||||
switch (op) {
|
||||
case 0: { /* INSERT with fuzz-controlled vector and rowid */
|
||||
int64_t rowid = (int64_t)(param % 128) + 1;
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = fuzz_float(&data, &size);
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 1: { /* DELETE */
|
||||
int64_t rowid = (int64_t)(param % 128) + 1;
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 2: { /* KNN with fuzz query vector and variable k */
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = fuzz_float(&data, &size);
|
||||
}
|
||||
int k = (param % 20) + 1;
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, k);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 3: { /* KNN with k > number of nodes (boundary) */
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = fuzz_float(&data, &size);
|
||||
}
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, 1000); /* k >> graph size */
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 4: { /* INSERT duplicate rowid (triggers OR REPLACE path) */
|
||||
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = (float)(param + j) / 50.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
free(vec);
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
175
tests/fuzz/diskann-delete-stress.c
Normal file
175
tests/fuzz/diskann-delete-stress.c
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN delete path and graph connectivity maintenance.
|
||||
*
|
||||
* The delete path is the most complex graph mutation:
|
||||
* 1. Read deleted node's neighbor list
|
||||
* 2. For each neighbor, remove deleted node from their list
|
||||
* 3. Try to fill the gap with one of deleted node's other neighbors
|
||||
* 4. Handle medoid deletion (pick new medoid)
|
||||
*
|
||||
* Edge cases this targets:
|
||||
* - Delete the medoid (entry point) -- forces medoid reassignment
|
||||
* - Delete all nodes except one -- graph degenerates
|
||||
* - Delete nodes in a chain -- cascading dangling edges
|
||||
* - Re-insert at deleted rowids -- stale graph edges to old data
|
||||
* - Delete nonexistent rowids -- should be no-op
|
||||
* - Insert-delete-insert same rowid rapidly
|
||||
* - Delete when graph has exactly n_neighbors entries (full nodes)
|
||||
*
|
||||
* Key code paths:
|
||||
* - diskann_delete -> diskann_repair_reverse_edges
|
||||
* - diskann_medoid_handle_delete
|
||||
* - diskann_node_clear_neighbor
|
||||
* - Interaction between delete and concurrent search
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* Fetch one byte of fuzz input and step past it.
 * Returns `def`, leaving the cursor alone, when nothing is left to read. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t value;
  if (*size == 0) {
    value = def;
  } else {
    value = **data;
    ++(*data);
    --(*size);
  }
  return value;
}
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 20) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* int8 quantizer to exercise that distance code path */
|
||||
uint8_t quant = fuzz_byte(&data, &size, 0) % 2;
|
||||
const char *qname = quant ? "int8" : "binary";
|
||||
|
||||
char sql[256];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[8] INDEXED BY diskann(neighbor_quantizer=%s, n_neighbors=8))",
|
||||
qname);
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||
-1, &stmtKnn, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||
|
||||
/* Phase 1: Build a graph of exactly n_neighbors+2 = 10 nodes.
|
||||
* This makes every node nearly full, maximizing the chance that
|
||||
* inserts trigger the "full node" path in add_reverse_edge. */
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
float vec[8];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(i*13+j*7))) / 20.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
}
|
||||
|
||||
/* Phase 2: Fuzz-driven delete-heavy workload */
|
||||
while (size >= 2) {
|
||||
uint8_t op = fuzz_byte(&data, &size, 0);
|
||||
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||
|
||||
switch (op % 6) {
|
||||
case 0: /* Delete existing node */
|
||||
case 1: { /* (weighted toward deletes) */
|
||||
int64_t rowid = (int64_t)(param % 16) + 1;
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 2: { /* Delete then immediately re-insert same rowid */
|
||||
int64_t rowid = (int64_t)(param % 10) + 1;
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
|
||||
float vec[8];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(rowid+j))) / 15.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 3: { /* KNN query on potentially sparse/empty graph */
|
||||
float qvec[8];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
qvec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||
}
|
||||
int k = (param % 15) + 1;
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, k);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 4: { /* Insert new node */
|
||||
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||
float vec[8];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 5: { /* Delete ALL remaining nodes, then insert fresh */
|
||||
for (int i = 1; i <= 32; i++) {
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, i);
|
||||
sqlite3_step(stmtDelete);
|
||||
}
|
||||
/* Now insert one node into empty graph */
|
||||
float vec[8] = {1.0f, 0, 0, 0, 0, 0, 0, 0};
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, 1);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Final KNN on whatever state the graph is in */
|
||||
{
|
||||
float qvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, 10);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
123
tests/fuzz/diskann-graph-corrupt.c
Normal file
123
tests/fuzz/diskann-graph-corrupt.c
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN shadow table corruption resilience.
|
||||
* Creates and populates a DiskANN table, then corrupts shadow table blobs
|
||||
* using fuzz data and runs queries.
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 16) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
rc = sqlite3_exec(db,
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||
NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
/* Insert a few vectors to create graph structure */
|
||||
{
|
||||
sqlite3_stmt *stmt;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
float vec[8];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||
}
|
||||
sqlite3_reset(stmt);
|
||||
sqlite3_bind_int64(stmt, 1, i);
|
||||
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmt);
|
||||
}
|
||||
sqlite3_finalize(stmt);
|
||||
}
|
||||
|
||||
/* Corrupt shadow table data using fuzz bytes */
|
||||
size_t offset = 0;
|
||||
|
||||
/* Determine which row and column to corrupt */
|
||||
int target_row = (data[offset++] % 10) + 1;
|
||||
int corrupt_type = data[offset++] % 3; /* 0=validity, 1=neighbor_ids, 2=qvecs */
|
||||
|
||||
const char *column_name;
|
||||
switch (corrupt_type) {
|
||||
case 0: column_name = "neighbors_validity"; break;
|
||||
case 1: column_name = "neighbor_ids"; break;
|
||||
default: column_name = "neighbor_quantized_vectors"; break;
|
||||
}
|
||||
|
||||
/* Read the blob, corrupt it, write it back */
|
||||
{
|
||||
sqlite3_stmt *readStmt;
|
||||
char sqlbuf[256];
|
||||
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||
"SELECT %s FROM v_diskann_nodes00 WHERE rowid = ?", column_name);
|
||||
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &readStmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_int64(readStmt, 1, target_row);
|
||||
if (sqlite3_step(readStmt) == SQLITE_ROW) {
|
||||
const void *blob = sqlite3_column_blob(readStmt, 0);
|
||||
int blobSize = sqlite3_column_bytes(readStmt, 0);
|
||||
if (blob && blobSize > 0) {
|
||||
unsigned char *corrupt = sqlite3_malloc(blobSize);
|
||||
if (corrupt) {
|
||||
memcpy(corrupt, blob, blobSize);
|
||||
/* Apply fuzz bytes as XOR corruption */
|
||||
size_t remaining = size - offset;
|
||||
for (size_t i = 0; i < remaining && i < (size_t)blobSize; i++) {
|
||||
corrupt[i % blobSize] ^= data[offset + i];
|
||||
}
|
||||
/* Write back */
|
||||
sqlite3_stmt *writeStmt;
|
||||
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||
"UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?", column_name);
|
||||
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(writeStmt, 1, corrupt, blobSize, SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||
sqlite3_step(writeStmt);
|
||||
sqlite3_finalize(writeStmt);
|
||||
}
|
||||
sqlite3_free(corrupt);
|
||||
}
|
||||
}
|
||||
}
|
||||
sqlite3_finalize(readStmt);
|
||||
}
|
||||
}
|
||||
|
||||
/* Run queries on corrupted graph -- should not crash */
|
||||
{
|
||||
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
sqlite3_stmt *knnStmt;
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5",
|
||||
-1, &knnStmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knnStmt);
|
||||
}
|
||||
}
|
||||
|
||||
/* Full scan */
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
164
tests/fuzz/diskann-int8-quant.c
Normal file
164
tests/fuzz/diskann-int8-quant.c
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN int8 quantizer edge cases.
|
||||
*
|
||||
* The binary quantizer is simple (sign bit), but the int8 quantizer has
|
||||
* interesting arithmetic:
|
||||
* i8_val = (i8)(((src - (-1.0f)) / step) - 128.0f)
|
||||
* where step = 2.0f / 255.0f
|
||||
*
|
||||
* Edge cases in this formula:
|
||||
* - src values outside [-1, 1] cause clamping issues (no explicit clamp!)
|
||||
* - src = NaN, +Inf, -Inf (from corrupted vectors or div-by-zero)
|
||||
* - src very close to boundaries (-1.0, 1.0) -- rounding
|
||||
* - The cast to i8 can overflow for extreme src values
|
||||
*
|
||||
* Also exercises int8 distance functions:
|
||||
* - distance_l2_sqr_int8: accumulates squared differences, possible overflow
|
||||
* - distance_cosine_int8: dot product with normalization
|
||||
* - distance_l1_int8: absolute differences
|
||||
*
|
||||
* This fuzzer also tests the cosine distance metric path which the
|
||||
* other fuzzers (using L2 default) don't cover.
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/*
 * Pop one byte from the fuzz stream.
 *
 * On success the cursor `*data` advances by one and `*size` shrinks by one.
 * When the stream is already empty, `def` is returned and neither the
 * cursor nor the remaining size is modified.
 */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size > 0) {
    const uint8_t *cursor = *data;
    *data = cursor + 1;
    *size -= 1;
    return *cursor;
  }
  return def;
}
|
||||
|
||||
/*
 * Produce one float from (up to) two fuzz bytes.
 *
 * The first byte, modulo 8, selects a generator; the second byte is the raw
 * value fed to it. Missing bytes read as 0.
 *
 *   0: signed byte / 10      1: signed byte * 100    2: signed byte / 1000
 *   3: exactly -1.0          4: exactly 1.0          5: exactly 0.0
 *   6: byte / 255  (0..1)    7: -(byte) / 255  (-1..0)
 */
static float fuzz_extreme_float(const uint8_t **data, size_t *size) {
  const uint8_t mode = fuzz_byte(data, size, 0) % 8;
  const uint8_t raw = fuzz_byte(data, size, 0);
  const float asSigned = (float)((int8_t)raw);
  const float asUnsigned = (float)raw;

  if (mode == 0) return asSigned / 10.0f;
  if (mode == 1) return asSigned * 100.0f;
  if (mode == 2) return asSigned / 1000.0f;
  if (mode == 3) return -1.0f;
  if (mode == 4) return 1.0f;
  if (mode == 5) return 0.0f;
  if (mode == 6) return asUnsigned / 255.0f;
  /* mode == 7 is the only value left after the modulo above. */
  return -asUnsigned / 255.0f;
}
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 40) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* Test both distance metrics with int8 quantizer */
|
||||
uint8_t metric_choice = fuzz_byte(&data, &size, 0) % 2;
|
||||
const char *metric = metric_choice ? "cosine" : "L2";
|
||||
|
||||
int dims = 8 + (fuzz_byte(&data, &size, 0) % 3) * 8; /* 8, 16, or 24 */
|
||||
|
||||
char sql[512];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] distance_metric=%s "
|
||||
"INDEXED BY diskann(neighbor_quantizer=int8, n_neighbors=8, search_list_size=16))",
|
||||
dims, metric);
|
||||
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_stmt *stmtInsert = NULL, *stmtKnn = NULL, *stmtDelete = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||
-1, &stmtKnn, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtKnn || !stmtDelete) goto cleanup;
|
||||
|
||||
/* Insert vectors with extreme float values to stress quantization */
|
||||
float *vec = malloc(dims * sizeof(float));
|
||||
if (!vec) goto cleanup;
|
||||
|
||||
for (int i = 1; i <= 16; i++) {
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = fuzz_extreme_float(&data, &size);
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
}
|
||||
|
||||
/* Fuzz-driven operations */
|
||||
while (size >= 2) {
|
||||
uint8_t op = fuzz_byte(&data, &size, 0) % 4;
|
||||
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||
|
||||
switch (op) {
|
||||
case 0: { /* KNN with extreme query values */
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = fuzz_extreme_float(&data, &size);
|
||||
}
|
||||
int k = (param % 10) + 1;
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, k);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 1: { /* Insert with extreme values */
|
||||
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||
for (int j = 0; j < dims; j++) {
|
||||
vec[j] = fuzz_extreme_float(&data, &size);
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 2: { /* Delete */
|
||||
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 3: { /* KNN with all-zero or all-same-value query */
|
||||
float val = (param % 3 == 0) ? 0.0f :
|
||||
(param % 3 == 1) ? 1.0f : -1.0f;
|
||||
for (int j = 0; j < dims; j++) vec[j] = val;
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int(stmtKnn, 2, 5);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(vec);
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
100
tests/fuzz/diskann-operations.c
Normal file
100
tests/fuzz/diskann-operations.c
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN insert/delete/query operation sequences.
|
||||
* Uses fuzz bytes to drive random operations on a DiskANN-indexed table.
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 6) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_stmt *stmtDelete = NULL;
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
sqlite3_stmt *stmtScan = NULL;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
rc = sqlite3_exec(db,
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||
NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 3",
|
||||
-1, &stmtKnn, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||
|
||||
size_t i = 0;
|
||||
while (i + 2 <= size) {
|
||||
uint8_t op = data[i++] % 4;
|
||||
uint8_t rowid_byte = data[i++];
|
||||
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||
|
||||
switch (op) {
|
||||
case 0: {
|
||||
/* INSERT: consume 32 bytes for 8 floats, or use what's left */
|
||||
float vec[8] = {0};
|
||||
for (int j = 0; j < 8 && i < size; j++, i++) {
|
||||
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
/* DELETE */
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
/* KNN query */
|
||||
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
/* Full scan */
|
||||
sqlite3_reset(stmtScan);
|
||||
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Final operations -- must not crash regardless of prior state */
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_finalize(stmtScan);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
131
tests/fuzz/diskann-prune-direct.c
Normal file
131
tests/fuzz/diskann-prune-direct.c
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
/**
|
||||
* Fuzz target for DiskANN RobustPrune algorithm (diskann_prune_select).
|
||||
*
|
||||
* diskann_prune_select is exposed for testing and takes:
|
||||
* - inter_distances: flattened NxN matrix of inter-candidate distances
|
||||
* - p_distances: N distances from node p to each candidate
|
||||
* - num_candidates, alpha, max_neighbors
|
||||
*
|
||||
* This is a pure function that doesn't need a database, so we can
|
||||
* call it directly with fuzz-controlled inputs. This gives the fuzzer
|
||||
* maximum speed (no SQLite overhead) to explore:
|
||||
*
|
||||
* - alpha boundary: alpha=0 (prunes nothing), alpha=very large (prunes all)
|
||||
* - max_neighbors = 0, 1, num_candidates, > num_candidates
|
||||
* - num_candidates = 0, 1, large
|
||||
* - Distance matrices with: all zeros, all same, arbitrary values in [0, 25.5]
*   (this harness does not generate negative, NaN or Inf distances)
|
||||
* - Non-symmetric distance matrices (should still work)
|
||||
* - Memory: num_candidates up to 32, i.e. an inter-distance matrix of up to 32x32 floats
|
||||
*
|
||||
* Key code paths:
|
||||
* - diskann_prune_select alpha-pruning loop
|
||||
* - Boundary: selectedCount reaches max_neighbors exactly
|
||||
* - All candidates pruned before max_neighbors reached
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* Declare the test-exposed function.
|
||||
* diskann_prune_select is not static -- it's a public symbol. */
|
||||
extern int diskann_prune_select(
|
||||
const float *inter_distances, const float *p_distances,
|
||||
int num_candidates, float alpha, int max_neighbors,
|
||||
int *outSelected, int *outCount);
|
||||
|
||||
/*
 * Read the next byte of the fuzz stream, or `def` if none is left.
 * Consuming a byte moves `*data` forward and decrements `*size`; an empty
 * stream is left untouched.
 */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t value = def;
  if (*size != 0) {
    value = (*data)[0];
    *data += 1;
    --*size;
  }
  return value;
}
|
||||
|
||||
/*
 * libFuzzer entry point: calls diskann_prune_select directly with
 * fuzz-derived inputs and checks basic invariants of its outputs.
 *
 * Stream layout (inputs shorter than 8 bytes are rejected):
 *   1 byte   -> number of candidates n, 0..32
 *   1 byte   -> max_neighbors, 0..16
 *   1 byte   -> index into a fixed table of alpha values
 *   n bytes  -> distances from p to each candidate (byte / 10), kept sorted
 *   n*n bytes-> inter-candidate distances (byte / 10), diagonal forced to 0
 * Bytes missing from the stream are replaced by position-derived defaults.
 *
 * Always returns 0; invariant violations abort via assert().
 */
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  static const float kAlphaChoices[8] = {0.0f, 0.5f, 1.0f, 1.2f,
                                         1.5f, 2.0f, 10.0f, 100.0f};

  if (size < 8) return 0;

  const int n = fuzz_byte(&data, &size, 0) % 33;
  const int max_neighbors = fuzz_byte(&data, &size, 0) % 17;
  const float alpha = kAlphaChoices[fuzz_byte(&data, &size, 0) % 8];

  int outCount = -1;
  int rc;

  if (n == 0) {
    /* Empty candidate set: all array arguments are NULL. */
    rc = diskann_prune_select(NULL, NULL, 0, alpha, max_neighbors,
                              NULL, &outCount);
    assert(rc == 0 /* SQLITE_OK */);
    assert(outCount == 0);
    return 0;
  }

  float *inter_distances = malloc(n * n * sizeof(float));
  float *p_distances = malloc(n * sizeof(float));
  int *outSelected = malloc(n * sizeof(int));
  if (!inter_distances || !p_distances || !outSelected) {
    free(inter_distances);
    free(p_distances);
    free(outSelected);
    return 0;
  }

  /* Read each p-distance and insert it into the already-sorted prefix, so
   * the array ends up ascending (the order the original harness feeds to
   * diskann_prune_select). */
  for (int i = 0; i < n; i++) {
    const float value =
        (float)fuzz_byte(&data, &size, (uint8_t)(i * 10)) / 10.0f;
    int slot = i;
    while (slot > 0 && p_distances[slot - 1] > value) {
      p_distances[slot] = p_distances[slot - 1];
      slot--;
    }
    p_distances[slot] = value;
  }

  /* Row-major inter-distance matrix. A byte is consumed for every cell,
   * including the diagonal, whose value is then forced to zero. */
  for (int row = 0; row < n; row++) {
    for (int col = 0; col < n; col++) {
      const int cell = row * n + col;
      const uint8_t raw = fuzz_byte(&data, &size, (uint8_t)(cell % 256));
      inter_distances[cell] = (row == col) ? 0.0f : (float)raw / 10.0f;
    }
  }

  rc = diskann_prune_select(inter_distances, p_distances,
                            n, alpha, max_neighbors,
                            outSelected, &outCount);

  /* Sanity checks on the reported count. */
  assert(rc == 0);
  assert(outCount >= 0);
  assert(outCount <= max_neighbors || max_neighbors == 0);
  assert(outCount <= n);

  /* The number of non-zero flags must equal the reported count. */
  int flagged = 0;
  for (int i = 0; i < n; i++) {
    flagged += outSelected[i] ? 1 : 0;
  }
  assert(flagged == outCount);

  free(inter_distances);
  free(p_distances);
  free(outSelected);
  return 0;
}
|
||||
10
tests/fuzz/diskann.dict
Normal file
10
tests/fuzz/diskann.dict
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"neighbor_quantizer"
|
||||
"binary"
|
||||
"int8"
|
||||
"n_neighbors"
|
||||
"search_list_size"
|
||||
"search_list_size_search"
|
||||
"search_list_size_insert"
|
||||
"alpha"
|
||||
"="
|
||||
","
|
||||
192
tests/fuzz/ivf-cell-overflow.c
Normal file
192
tests/fuzz/ivf-cell-overflow.c
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
/**
|
||||
* Fuzz target: IVF cell overflow and boundary conditions.
|
||||
*
|
||||
* Pushes cells past VEC0_IVF_CELL_MAX_VECTORS (64) to trigger cell
|
||||
* splitting, then exercises blob I/O at slot boundaries.
|
||||
*
|
||||
* Targets:
|
||||
* - Cell splitting when n_vectors reaches cap (64)
|
||||
* - Blob offset arithmetic: slot * vecSize, slot / 8, slot % 8
|
||||
* - Validity bitmap at byte boundaries (slot 7->8, 15->16, etc.)
|
||||
* - Insert into full cell -> create new cell path
|
||||
* - Delete from various slot positions (first, last, middle)
|
||||
* - Multiple cells per centroid
|
||||
* - assign-vectors command with multi-cell centroids
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 8) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
// Use small dimensions for speed but enough vectors to overflow cells
|
||||
int dim = (data[0] % 8) + 2; // 2..9
|
||||
int nlist = (data[1] % 4) + 1; // 1..4
|
||||
// We need >64 vectors to overflow a cell
|
||||
int num_vecs = (data[2] % 64) + 65; // 65..128
|
||||
int delete_pattern = data[3]; // Controls which vectors to delete
|
||||
|
||||
const uint8_t *payload = data + 4;
|
||||
size_t payload_size = size - 4;
|
||||
|
||||
char sql[256];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))",
|
||||
dim, nlist, nlist);
|
||||
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
// Insert enough vectors to overflow at least one cell
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||
|
||||
size_t offset = 0;
|
||||
for (int i = 0; i < num_vecs; i++) {
|
||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!vec) break;
|
||||
for (int d = 0; d < dim; d++) {
|
||||
if (offset < payload_size) {
|
||||
vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f;
|
||||
} else {
|
||||
// Cluster vectors near specific centroids to ensure some cells overflow
|
||||
int cluster = i % nlist;
|
||||
vec[d] = (float)cluster + (float)(i % 10) * 0.01f + d * 0.001f;
|
||||
}
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
sqlite3_free(vec);
|
||||
}
|
||||
sqlite3_finalize(stmtInsert);
|
||||
|
||||
// Train to assign vectors to centroids (triggers cell building)
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Delete vectors at boundary positions based on fuzz data
|
||||
// This tests validity bitmap manipulation at different slot positions
|
||||
for (int i = 0; i < num_vecs; i++) {
|
||||
int byte_idx = i / 8;
|
||||
if (byte_idx < (int)payload_size && (payload[byte_idx] & (1 << (i % 8)))) {
|
||||
// Use delete_pattern to thin deletions
|
||||
if ((delete_pattern + i) % 3 == 0) {
|
||||
char delsql[64];
|
||||
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i + 1);
|
||||
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Insert more vectors after deletions (into cells with holes)
|
||||
{
|
||||
sqlite3_stmt *si = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||
if (si) {
|
||||
for (int i = 0; i < 10; i++) {
|
||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!vec) break;
|
||||
for (int d = 0; d < dim; d++)
|
||||
vec[d] = (float)(i + 200) * 0.01f;
|
||||
sqlite3_reset(si);
|
||||
sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1));
|
||||
sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(si);
|
||||
sqlite3_free(vec);
|
||||
}
|
||||
sqlite3_finalize(si);
|
||||
}
|
||||
}
|
||||
|
||||
// KNN query that must scan multiple cells per centroid
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) qvec[d] = 0.0f;
|
||||
sqlite3_stmt *sk = NULL;
|
||||
snprintf(sql, sizeof(sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 20");
|
||||
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
// Test assign-vectors with multi-cell state
|
||||
// First clear centroids
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Set centroids manually, then assign
|
||||
for (int c = 0; c < nlist; c++) {
|
||||
float *cvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!cvec) break;
|
||||
for (int d = 0; d < dim; d++) cvec[d] = (float)c + d * 0.1f;
|
||||
|
||||
char cmd[128];
|
||||
snprintf(cmd, sizeof(cmd),
|
||||
"INSERT INTO v(v, emb) VALUES ('set-centroid:%d', ?)", c);
|
||||
sqlite3_stmt *sc = NULL;
|
||||
sqlite3_prepare_v2(db, cmd, -1, &sc, NULL);
|
||||
if (sc) {
|
||||
sqlite3_bind_blob(sc, 1, cvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(sc);
|
||||
sqlite3_finalize(sc);
|
||||
}
|
||||
sqlite3_free(cvec);
|
||||
}
|
||||
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('assign-vectors')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Final query after assign-vectors
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) qvec[d] = 1.0f;
|
||||
sqlite3_stmt *sk = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||
-1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
// Full scan
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
36
tests/fuzz/ivf-create.c
Normal file
36
tests/fuzz/ivf-create.c
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmt;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||
assert(s);
|
||||
sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(");
|
||||
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||
sqlite3_str_appendall(s, "))");
|
||||
const char *zSql = sqlite3_str_finish(s);
|
||||
assert(zSql);
|
||||
|
||||
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||
sqlite3_free((void *)zSql);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_step(stmt);
|
||||
}
|
||||
sqlite3_finalize(stmt);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
16
tests/fuzz/ivf-create.dict
Normal file
16
tests/fuzz/ivf-create.dict
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"nlist"
|
||||
"nprobe"
|
||||
"quantizer"
|
||||
"oversample"
|
||||
"binary"
|
||||
"int8"
|
||||
"none"
|
||||
"="
|
||||
","
|
||||
"("
|
||||
")"
|
||||
"0"
|
||||
"1"
|
||||
"128"
|
||||
"65536"
|
||||
"65537"
|
||||
180
tests/fuzz/ivf-kmeans.c
Normal file
180
tests/fuzz/ivf-kmeans.c
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
/**
|
||||
* Fuzz target: IVF k-means clustering.
|
||||
*
|
||||
* Builds a table, inserts fuzz-controlled vectors, then runs
|
||||
* compute-centroids with fuzz-controlled parameters (nlist, max_iter, seed).
|
||||
* Targets:
|
||||
* - kmeans with N < k (clamping), N == 1, k == 1
|
||||
* - kmeans with duplicate/identical vectors (all distances zero)
|
||||
* - kmeans with NaN/Inf vectors
|
||||
* - Empty cluster reassignment path (farthest-point heuristic)
|
||||
* - Large nlist relative to N
|
||||
* - The compute-centroids:{json} command parsing
|
||||
* - clear-centroids followed by compute-centroids (round-trip)
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 10) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
// Parse fuzz header
|
||||
// Byte 0-1: dimension (1..128)
|
||||
// Byte 2: nlist for CREATE (1..64)
|
||||
// Byte 3: nlist override for compute-centroids (0 = use default)
|
||||
// Byte 4: max_iter (1..50)
|
||||
// Byte 5-8: seed
|
||||
// Byte 9: num_vectors (1..64)
|
||||
// Remaining: vector float data
|
||||
|
||||
int dim = (data[0] | (data[1] << 8)) % 128 + 1;
|
||||
int nlist_create = (data[2] % 64) + 1;
|
||||
int nlist_override = data[3] % 65; // 0 means use table default
|
||||
int max_iter = (data[4] % 50) + 1;
|
||||
uint32_t seed = (uint32_t)data[5] | ((uint32_t)data[6] << 8) |
|
||||
((uint32_t)data[7] << 16) | ((uint32_t)data[8] << 24);
|
||||
int num_vecs = (data[9] % 64) + 1;
|
||||
|
||||
const uint8_t *payload = data + 10;
|
||||
size_t payload_size = size - 10;
|
||||
|
||||
char sql[256];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))",
|
||||
dim, nlist_create, nlist_create);
|
||||
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
// Insert vectors
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||
|
||||
size_t offset = 0;
|
||||
for (int i = 0; i < num_vecs; i++) {
|
||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!vec) break;
|
||||
|
||||
for (int d = 0; d < dim; d++) {
|
||||
if (offset + 4 <= payload_size) {
|
||||
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||
offset += 4;
|
||||
} else if (offset < payload_size) {
|
||||
// Scale to interesting range including values > 1, < -1
|
||||
vec[d] = ((float)(int8_t)payload[offset++]) / 5.0f;
|
||||
} else {
|
||||
// Reuse earlier bytes to fill remaining dimensions
|
||||
vec[d] = (float)(i * dim + d) * 0.01f;
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
sqlite3_free(vec);
|
||||
}
|
||||
sqlite3_finalize(stmtInsert);
|
||||
|
||||
// Exercise compute-centroids with JSON options
|
||||
{
|
||||
char cmd[256];
|
||||
snprintf(cmd, sizeof(cmd),
|
||||
"INSERT INTO v(rowid) VALUES "
|
||||
"('compute-centroids:{\"nlist\":%d,\"max_iterations\":%d,\"seed\":%u}')",
|
||||
nlist_override, max_iter, seed);
|
||||
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
||||
}
|
||||
|
||||
// KNN query after training
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) {
|
||||
qvec[d] = (d < 3) ? 1.0f : 0.0f;
|
||||
}
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||
-1, &stmtKnn, NULL);
|
||||
if (stmtKnn) {
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(stmtKnn);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear centroids and re-compute to test round-trip
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Insert a few more vectors in untrained state
|
||||
{
|
||||
sqlite3_stmt *si = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||
if (si) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!vec) break;
|
||||
for (int d = 0; d < dim; d++) vec[d] = (float)(i + 100) * 0.1f;
|
||||
sqlite3_reset(si);
|
||||
sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1));
|
||||
sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(si);
|
||||
sqlite3_free(vec);
|
||||
}
|
||||
sqlite3_finalize(si);
|
||||
}
|
||||
}
|
||||
|
||||
// Re-train
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Delete some rows after training, then query
|
||||
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 2", NULL, NULL, NULL);
|
||||
|
||||
// Query after deletes
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 10",
|
||||
-1, &stmtKnn, NULL);
|
||||
if (stmtKnn) {
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(stmtKnn);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
199
tests/fuzz/ivf-knn-deep.c
Normal file
199
tests/fuzz/ivf-knn-deep.c
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Fuzz target: IVF KNN search deep paths.
|
||||
*
|
||||
* Exercises the full KNN pipeline with fuzz-controlled:
|
||||
* - nprobe values (including > nlist, =1, =nlist)
|
||||
* - Query vectors (including adversarial floats)
|
||||
* - Mix of trained/untrained state
|
||||
* - Oversample + rescore path (quantizer=int8 with oversample>1)
|
||||
* - Multiple interleaved KNN queries
|
||||
* - Candidate array realloc path (many vectors in probed cells)
|
||||
*
|
||||
* Targets:
|
||||
* - ivf_scan_cells_from_stmt: candidate realloc, distance computation
|
||||
* - ivf_query_knn: centroid sorting, nprobe selection
|
||||
* - Oversample rescore: re-ranking with full-precision vectors
|
||||
* - qsort with NaN distances
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* Decode the two bytes at p as a little-endian unsigned 16-bit integer. */
static uint16_t read_u16(const uint8_t *p) {
  const unsigned low_byte = p[0];
  const unsigned high_byte = p[1];
  return (uint16_t)((high_byte << 8) | low_byte);
}
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 16) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
// Header
|
||||
int dim = (data[0] % 32) + 2; // 2..33
|
||||
int nlist = (data[1] % 16) + 1; // 1..16
|
||||
int nprobe_initial = (data[2] % 20) + 1; // 1..20 (can be > nlist)
|
||||
int quantizer_type = data[3] % 3; // 0=none, 1=int8, 2=binary
|
||||
int oversample = (data[4] % 4) + 1; // 1..4
|
||||
int num_vecs = (data[5] % 80) + 4; // 4..83
|
||||
int num_queries = (data[6] % 8) + 1; // 1..8
|
||||
int k_limit = (data[7] % 20) + 1; // 1..20
|
||||
|
||||
const uint8_t *payload = data + 8;
|
||||
size_t payload_size = size - 8;
|
||||
|
||||
// For binary quantizer, dimension must be multiple of 8
|
||||
if (quantizer_type == 2) {
|
||||
dim = ((dim + 7) / 8) * 8;
|
||||
if (dim == 0) dim = 8;
|
||||
}
|
||||
|
||||
const char *qname;
|
||||
switch (quantizer_type) {
|
||||
case 1: qname = "int8"; break;
|
||||
case 2: qname = "binary"; break;
|
||||
default: qname = "none"; break;
|
||||
}
|
||||
|
||||
// Oversample only valid with quantization
|
||||
if (quantizer_type == 0) oversample = 1;
|
||||
|
||||
// Cap nprobe to nlist for CREATE (parser rejects nprobe > nlist)
|
||||
int nprobe_create = nprobe_initial <= nlist ? nprobe_initial : nlist;
|
||||
|
||||
char sql[512];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s%s))",
|
||||
dim, nlist, nprobe_create, qname,
|
||||
oversample > 1 ? ", oversample=2" : "");
|
||||
|
||||
// If that fails (e.g. oversample with none), try without oversample
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) {
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))",
|
||||
dim, nlist, nprobe_create, qname);
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
}
|
||||
|
||||
// Insert vectors
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||
|
||||
size_t offset = 0;
|
||||
for (int i = 0; i < num_vecs; i++) {
|
||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!vec) break;
|
||||
for (int d = 0; d < dim; d++) {
|
||||
if (offset < payload_size) {
|
||||
vec[d] = ((float)(int8_t)payload[offset++]) / 20.0f;
|
||||
} else {
|
||||
vec[d] = (float)((i * dim + d) % 256 - 128) / 128.0f;
|
||||
}
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
sqlite3_free(vec);
|
||||
}
|
||||
sqlite3_finalize(stmtInsert);
|
||||
|
||||
// Query BEFORE training (flat scan path)
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||
sqlite3_stmt *sk = NULL;
|
||||
snprintf(sql, sizeof(sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
// Train
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Change nprobe at runtime (can exceed nlist -- tests clamping in query)
|
||||
{
|
||||
char cmd[64];
|
||||
snprintf(cmd, sizeof(cmd),
|
||||
"INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe_initial);
|
||||
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
||||
}
|
||||
|
||||
// Multiple KNN queries with different fuzz-derived query vectors
|
||||
for (int q = 0; q < num_queries; q++) {
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!qvec) break;
|
||||
for (int d = 0; d < dim; d++) {
|
||||
if (offset < payload_size) {
|
||||
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||
} else {
|
||||
qvec[d] = (q == 0) ? 1.0f : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_stmt *sk = NULL;
|
||||
snprintf(sql, sizeof(sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
|
||||
// Delete half the vectors then query again
|
||||
for (int i = 1; i <= num_vecs / 2; i++) {
|
||||
char delsql[64];
|
||||
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i);
|
||||
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||
}
|
||||
|
||||
// Query after mass deletion
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) qvec[d] = -0.5f;
|
||||
sqlite3_stmt *sk = NULL;
|
||||
snprintf(sql, sizeof(sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
121
tests/fuzz/ivf-operations.c
Normal file
121
tests/fuzz/ivf-operations.c
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 6) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_stmt *stmtDelete = NULL;
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
sqlite3_stmt *stmtScan = NULL;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
rc = sqlite3_exec(db,
|
||||
"CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(nlist=4, nprobe=4))",
|
||||
NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 3",
|
||||
-1, &stmtKnn, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||
|
||||
size_t i = 0;
|
||||
while (i + 2 <= size) {
|
||||
uint8_t op = data[i++] % 7;
|
||||
uint8_t rowid_byte = data[i++];
|
||||
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||
|
||||
switch (op) {
|
||||
case 0: {
|
||||
// INSERT: consume 16 bytes for 4 floats, or use what's left
|
||||
float vec[4] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
for (int j = 0; j < 4 && i < size; j++, i++) {
|
||||
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
// DELETE
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
// KNN query with a fixed query vector
|
||||
float qvec[4] = {1.0f, 0.0f, 0.0f, 0.0f};
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
// Full scan
|
||||
sqlite3_reset(stmtScan);
|
||||
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
// compute-centroids command
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 5: {
|
||||
// clear-centroids command
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 6: {
|
||||
// nprobe=N command
|
||||
if (i < size) {
|
||||
uint8_t n = data[i++];
|
||||
int nprobe = (n % 4) + 1;
|
||||
char buf[64];
|
||||
snprintf(buf, sizeof(buf),
|
||||
"INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe);
|
||||
sqlite3_exec(db, buf, NULL, NULL, NULL);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final operations — must not crash regardless of prior state
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_finalize(stmtScan);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
129
tests/fuzz/ivf-quantize.c
Normal file
129
tests/fuzz/ivf-quantize.c
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Fuzz target: IVF quantization functions.
|
||||
*
|
||||
* Exercises ivf_quantize_int8 and ivf_quantize_binary through SQL with
|
||||
* fuzz-controlled dimensions and float data. Targets:
|
||||
* - ivf_quantize_int8: clamping, int8 overflow boundary
|
||||
* - ivf_quantize_binary: D not divisible by 8, memset(D/8) undercount
|
||||
* - Round-trip through CREATE TABLE + INSERT with quantized IVF
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 8) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
// Byte 0: quantizer type (0=int8, 1=binary)
|
||||
// Byte 1: dimension (1..64, but we test edge cases)
|
||||
// Byte 2: nlist (1..8)
|
||||
// Byte 3: num_vectors to insert (1..32)
|
||||
// Remaining: float data
|
||||
int qtype = data[0] % 2;
|
||||
int dim = (data[1] % 64) + 1;
|
||||
int nlist = (data[2] % 8) + 1;
|
||||
int num_vecs = (data[3] % 32) + 1;
|
||||
const uint8_t *payload = data + 4;
|
||||
size_t payload_size = size - 4;
|
||||
|
||||
// For binary quantizer, D must be multiple of 8 to avoid the D/8 bug
|
||||
// in production. But we explicitly want to test non-multiples too to
|
||||
// find the bug. Use dim as-is.
|
||||
const char *quantizer = qtype ? "binary" : "int8";
|
||||
|
||||
// Binary quantizer needs D multiple of 8 in current code, but let's
|
||||
// test both valid and invalid dimensions to see what happens.
|
||||
// For binary with non-multiple-of-8, the code does memset(dst, 0, D/8)
|
||||
// which underallocates when D%8 != 0.
|
||||
char sql[256];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))",
|
||||
dim, nlist, nlist, quantizer);
|
||||
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
// Insert vectors with fuzz-controlled float values
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||
|
||||
size_t offset = 0;
|
||||
for (int i = 0; i < num_vecs && offset < payload_size; i++) {
|
||||
// Build float vector from fuzz data
|
||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!vec) break;
|
||||
|
||||
for (int d = 0; d < dim; d++) {
|
||||
if (offset + 4 <= payload_size) {
|
||||
// Use raw bytes as float -- can produce NaN, Inf, denormals
|
||||
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||
offset += 4;
|
||||
} else if (offset < payload_size) {
|
||||
// Partial: use byte as scaled value
|
||||
vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f;
|
||||
} else {
|
||||
vec[d] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
sqlite3_free(vec);
|
||||
}
|
||||
sqlite3_finalize(stmtInsert);
|
||||
|
||||
// Trigger compute-centroids to exercise kmeans + quantization together
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// KNN query with fuzz-derived query vector
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) {
|
||||
if (offset < payload_size) {
|
||||
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||
} else {
|
||||
qvec[d] = 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||
-1, &stmtKnn, NULL);
|
||||
if (stmtKnn) {
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(stmtKnn);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
// Full scan
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
182
tests/fuzz/ivf-rescore.c
Normal file
182
tests/fuzz/ivf-rescore.c
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
/**
|
||||
* Fuzz target: IVF oversample + rescore path.
|
||||
*
|
||||
* Specifically targets the code path where quantizer != none AND
|
||||
* oversample > 1, which triggers:
|
||||
* 1. Quantized KNN scan to collect oversample*k candidates
|
||||
* 2. Full-precision vector lookup from _ivf_vectors table
|
||||
* 3. Re-scoring with float32 distances
|
||||
* 4. Re-sort and truncation
|
||||
*
|
||||
* This path has the most complex memory management in the KNN query:
|
||||
* - Two separate distance computations (quantized + float)
|
||||
* - Cross-table lookups (cells + vectors KV store)
|
||||
* - Candidate array resizing
|
||||
* - qsort over partially re-scored arrays
|
||||
*
|
||||
* Also tests the int8 + binary quantization round-trip fidelity
|
||||
* under adversarial float inputs.
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 12) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
// Header
|
||||
int quantizer_type = (data[0] % 2) + 1; // 1=int8, 2=binary (never none)
|
||||
int dim = (data[1] % 32) + 8; // 8..39
|
||||
int nlist = (data[2] % 8) + 1; // 1..8
|
||||
int oversample = (data[3] % 4) + 2; // 2..5 (always > 1)
|
||||
int num_vecs = (data[4] % 60) + 8; // 8..67
|
||||
int k_limit = (data[5] % 15) + 1; // 1..15
|
||||
|
||||
const uint8_t *payload = data + 6;
|
||||
size_t payload_size = size - 6;
|
||||
|
||||
// Binary quantizer needs D multiple of 8
|
||||
if (quantizer_type == 2) {
|
||||
dim = ((dim + 7) / 8) * 8;
|
||||
}
|
||||
|
||||
const char *qname = (quantizer_type == 1) ? "int8" : "binary";
|
||||
|
||||
char sql[512];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s, oversample=%d))",
|
||||
dim, nlist, nlist, qname, oversample);
|
||||
|
||||
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
// Insert vectors with diverse values
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||
|
||||
size_t offset = 0;
|
||||
for (int i = 0; i < num_vecs; i++) {
|
||||
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!vec) break;
|
||||
for (int d = 0; d < dim; d++) {
|
||||
if (offset + 4 <= payload_size) {
|
||||
// Use raw bytes as float for adversarial values
|
||||
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||
offset += 4;
|
||||
// Sanitize: replace NaN/Inf with bounded values to avoid
|
||||
// poisoning the entire computation. We want edge values,
|
||||
// not complete nonsense.
|
||||
if (isnan(vec[d]) || isinf(vec[d])) {
|
||||
vec[d] = (vec[d] > 0) ? 1e6f : -1e6f;
|
||||
if (isnan(vec[d])) vec[d] = 0.0f;
|
||||
}
|
||||
} else if (offset < payload_size) {
|
||||
vec[d] = ((float)(int8_t)payload[offset++]) / 30.0f;
|
||||
} else {
|
||||
vec[d] = (float)(i * dim + d) * 0.001f;
|
||||
}
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
sqlite3_free(vec);
|
||||
}
|
||||
sqlite3_finalize(stmtInsert);
|
||||
|
||||
// Train
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Multiple KNN queries to exercise rescore path
|
||||
for (int q = 0; q < 4; q++) {
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (!qvec) break;
|
||||
for (int d = 0; d < dim; d++) {
|
||||
if (offset < payload_size) {
|
||||
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||
} else {
|
||||
qvec[d] = (q == 0) ? 1.0f : (q == 1) ? -1.0f : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_stmt *sk = NULL;
|
||||
snprintf(sql, sizeof(sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
|
||||
// Delete some vectors, then query again (rescore with missing _ivf_vectors rows)
|
||||
for (int i = 1; i <= num_vecs / 3; i++) {
|
||||
char delsql[64];
|
||||
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i);
|
||||
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||
}
|
||||
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||
sqlite3_stmt *sk = NULL;
|
||||
snprintf(sql, sizeof(sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
// Retrain after deletions
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Query after retrain
|
||||
{
|
||||
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||
if (qvec) {
|
||||
for (int d = 0; d < dim; d++) qvec[d] = -0.3f;
|
||||
sqlite3_stmt *sk = NULL;
|
||||
snprintf(sql, sizeof(sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
sqlite3_free(qvec);
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
228
tests/fuzz/ivf-shadow-corrupt.c
Normal file
228
tests/fuzz/ivf-shadow-corrupt.c
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
/**
|
||||
* Fuzz target: IVF shadow table corruption.
|
||||
*
|
||||
* Creates a trained IVF table, then corrupts IVF shadow table blobs
|
||||
* (centroids, cells validity/rowids/vectors, rowid_map) with fuzz data.
|
||||
* Then exercises all read/write paths. Must not crash.
|
||||
*
|
||||
* Targets:
|
||||
* - Cell validity bitmap with wrong size
|
||||
* - Cell rowids blob with wrong size/alignment
|
||||
* - Cell vectors blob with wrong size
|
||||
* - Centroid blob with wrong size
|
||||
* - n_vectors inconsistent with validity bitmap
|
||||
* - Missing rowid_map entries
|
||||
* - KNN scan over corrupted cells
|
||||
* - Insert/delete with corrupted rowid_map
|
||||
*/
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 4) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
// Create IVF table and insert enough vectors to train
|
||||
rc = sqlite3_exec(db,
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[8] indexed by ivf(nlist=2, nprobe=2))",
|
||||
NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
// Insert 10 vectors
|
||||
{
|
||||
sqlite3_stmt *si = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||
if (!si) { sqlite3_close(db); return 0; }
|
||||
for (int i = 0; i < 10; i++) {
|
||||
float vec[8];
|
||||
for (int d = 0; d < 8; d++) {
|
||||
vec[d] = (float)(i * 8 + d) * 0.1f;
|
||||
}
|
||||
sqlite3_reset(si);
|
||||
sqlite3_bind_int64(si, 1, i + 1);
|
||||
sqlite3_bind_blob(si, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(si);
|
||||
}
|
||||
sqlite3_finalize(si);
|
||||
}
|
||||
|
||||
// Train
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// Now corrupt shadow tables based on fuzz input
|
||||
uint8_t target = data[0] % 10;
|
||||
const uint8_t *payload = data + 1;
|
||||
int payload_size = (int)(size - 1);
|
||||
|
||||
// Limit payload to avoid huge allocations
|
||||
if (payload_size > 4096) payload_size = 4096;
|
||||
|
||||
sqlite3_stmt *stmt = NULL;
|
||||
|
||||
switch (target) {
|
||||
case 0: {
|
||||
// Corrupt cell validity blob
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_ivf_cells00 SET validity = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
// Corrupt cell rowids blob
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_ivf_cells00 SET rowids = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
// Corrupt cell vectors blob
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_ivf_cells00 SET vectors = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
// Corrupt centroid blob
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_ivf_centroids00 SET centroid = ? WHERE centroid_id = 0",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
// Set n_vectors to a bogus value (larger than cell capacity)
|
||||
int bogus_n = 99999;
|
||||
if (payload_size >= 4) {
|
||||
memcpy(&bogus_n, payload, 4);
|
||||
bogus_n = abs(bogus_n) % 100000;
|
||||
}
|
||||
char sql[128];
|
||||
snprintf(sql, sizeof(sql),
|
||||
"UPDATE v_ivf_cells00 SET n_vectors = %d WHERE rowid = 1", bogus_n);
|
||||
sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 5: {
|
||||
// Delete rowid_map entries (orphan vectors)
|
||||
sqlite3_exec(db,
|
||||
"DELETE FROM v_ivf_rowid_map00 WHERE rowid IN (1, 2, 3)",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 6: {
|
||||
// Corrupt rowid_map slot values
|
||||
char sql[128];
|
||||
int bogus_slot = payload_size > 0 ? (int)payload[0] * 1000 : 99999;
|
||||
snprintf(sql, sizeof(sql),
|
||||
"UPDATE v_ivf_rowid_map00 SET slot = %d WHERE rowid = 1", bogus_slot);
|
||||
sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 7: {
|
||||
// Corrupt rowid_map cell_id values
|
||||
sqlite3_exec(db,
|
||||
"UPDATE v_ivf_rowid_map00 SET cell_id = 99999 WHERE rowid = 1",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 8: {
|
||||
// Delete all centroids (make trained but no centroids)
|
||||
sqlite3_exec(db,
|
||||
"DELETE FROM v_ivf_centroids00",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 9: {
|
||||
// Set validity to NULL
|
||||
sqlite3_exec(db,
|
||||
"UPDATE v_ivf_cells00 SET validity = NULL WHERE rowid = 1",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Exercise all read paths over corrupted state — must not crash
|
||||
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
|
||||
// KNN query
|
||||
{
|
||||
sqlite3_stmt *sk = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||
-1, &sk, NULL);
|
||||
if (sk) {
|
||||
sqlite3_bind_blob(sk, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(sk);
|
||||
}
|
||||
}
|
||||
|
||||
// Full scan
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
// Point query
|
||||
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 5", NULL, NULL, NULL);
|
||||
|
||||
// Delete
|
||||
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 3", NULL, NULL, NULL);
|
||||
|
||||
// Insert after corruption
|
||||
{
|
||||
float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
||||
sqlite3_stmt *si = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||
if (si) {
|
||||
sqlite3_bind_int64(si, 1, 100);
|
||||
sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC);
|
||||
sqlite3_step(si);
|
||||
sqlite3_finalize(si);
|
||||
}
|
||||
}
|
||||
|
||||
// compute-centroids over corrupted state
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
// clear-centroids
|
||||
sqlite3_exec(db,
|
||||
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||
NULL, NULL, NULL);
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
|
@ -8,9 +7,6 @@
|
|||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
extern int sqlite3_vec_numpy_init(sqlite3 *db, char **pzErrMsg,
|
||||
const sqlite3_api_routines *pApi);
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
int rc = SQLITE_OK;
|
||||
sqlite3 *db;
|
||||
|
|
@ -20,17 +16,20 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
|||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_numpy_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
rc = sqlite3_prepare_v2(db, "select * from vec_npy_each(?)", -1, &stmt, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
sqlite3_bind_blob(stmt, 1, data, size, SQLITE_STATIC);
|
||||
rc = sqlite3_step(stmt);
|
||||
while (rc == SQLITE_ROW) {
|
||||
rc = sqlite3_step(stmt);
|
||||
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||
assert(s);
|
||||
sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[128] indexed by rescore(");
|
||||
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||
sqlite3_str_appendall(s, "))");
|
||||
const char *zSql = sqlite3_str_finish(s);
|
||||
assert(zSql);
|
||||
|
||||
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||
sqlite3_free((void *)zSql);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_step(stmt);
|
||||
}
|
||||
|
||||
sqlite3_finalize(stmt);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
20
tests/fuzz/rescore-create.dict
Normal file
20
tests/fuzz/rescore-create.dict
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"rescore"
|
||||
"quantizer"
|
||||
"bit"
|
||||
"int8"
|
||||
"oversample"
|
||||
"indexed"
|
||||
"by"
|
||||
"float"
|
||||
"("
|
||||
")"
|
||||
","
|
||||
"="
|
||||
"["
|
||||
"]"
|
||||
"1"
|
||||
"8"
|
||||
"16"
|
||||
"128"
|
||||
"256"
|
||||
"1024"
|
||||
151
tests/fuzz/rescore-interleave.c
Normal file
151
tests/fuzz/rescore-interleave.c
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/**
|
||||
* Fuzz target: interleaved insert/update/delete/KNN operations on rescore
|
||||
* tables with BOTH quantizer types, exercising the int8 quantizer path
|
||||
* and the update code path that the existing rescore-operations.c misses.
|
||||
*
|
||||
* Key differences from rescore-operations.c:
|
||||
* - Tests BOTH bit and int8 quantizers (the existing target only tests bit)
|
||||
* - Fuzz-controlled query vectors (not fixed [1,0,0,...])
|
||||
* - Exercises the UPDATE path (line 9080+ in sqlite-vec.c)
|
||||
* - Tests with 16 dimensions (more realistic, exercises more of the
|
||||
* quantization loop)
|
||||
* - Interleaves KNN between mutations to stress the blob_reopen path
|
||||
* when _rescore_vectors rows have been deleted/modified
|
||||
*/
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 8) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_stmt *stmtUpdate = NULL;
|
||||
sqlite3_stmt *stmtDelete = NULL;
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* Use first byte to pick quantizer */
|
||||
int use_int8 = data[0] & 1;
|
||||
data++; size--;
|
||||
|
||||
const char *create_sql = use_int8
|
||||
? "CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=int8))"
|
||||
: "CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=bit))";
|
||||
|
||||
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"UPDATE v SET emb = ? WHERE rowid = ?", -1, &stmtUpdate, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 5", -1, &stmtKnn, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtUpdate || !stmtDelete || !stmtKnn)
|
||||
goto cleanup;
|
||||
|
||||
size_t i = 0;
|
||||
while (i + 2 <= size) {
|
||||
uint8_t op = data[i++] % 5; /* 5 operations now */
|
||||
uint8_t rowid_byte = data[i++];
|
||||
int64_t rowid = (int64_t)(rowid_byte % 24) + 1;
|
||||
|
||||
switch (op) {
|
||||
case 0: {
|
||||
/* INSERT: consume bytes for 16 floats */
|
||||
float vec[16] = {0};
|
||||
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||
vec[j] = (float)((int8_t)data[i]) / 8.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
/* DELETE */
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
/* KNN with fuzz-controlled query vector */
|
||||
float qvec[16] = {0};
|
||||
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||
qvec[j] = (float)((int8_t)data[i]) / 4.0f;
|
||||
}
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {
|
||||
(void)sqlite3_column_int64(stmtKnn, 0);
|
||||
(void)sqlite3_column_double(stmtKnn, 1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
/* UPDATE: modify an existing vector (exercises rescore update path) */
|
||||
float vec[16] = {0};
|
||||
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||
vec[j] = (float)((int8_t)data[i]) / 6.0f;
|
||||
}
|
||||
sqlite3_reset(stmtUpdate);
|
||||
sqlite3_bind_blob(stmtUpdate, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int64(stmtUpdate, 2, rowid);
|
||||
sqlite3_step(stmtUpdate);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
/* INSERT then immediately UPDATE same row (stresses blob lifecycle) */
|
||||
float vec1[16] = {0};
|
||||
float vec2[16] = {0};
|
||||
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||
vec1[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||
vec2[j] = -vec1[j]; /* opposite direction */
|
||||
}
|
||||
/* Insert */
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec1, sizeof(vec1), SQLITE_TRANSIENT);
|
||||
if (sqlite3_step(stmtInsert) == SQLITE_DONE) {
|
||||
/* Only update if insert succeeded (rowid might already exist) */
|
||||
sqlite3_reset(stmtUpdate);
|
||||
sqlite3_bind_blob(stmtUpdate, 1, vec2, sizeof(vec2), SQLITE_TRANSIENT);
|
||||
sqlite3_bind_int64(stmtUpdate, 2, rowid);
|
||||
sqlite3_step(stmtUpdate);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Final consistency check: full scan must not crash */
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtUpdate);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
178
tests/fuzz/rescore-knn-deep.c
Normal file
178
tests/fuzz/rescore-knn-deep.c
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/**
 * Fuzz target: deep exercise of rescore KNN with fuzz-controlled query vectors
 * and both quantizer types (bit + int8), multiple distance metrics.
 *
 * The existing rescore-operations.c only tests bit quantizer with a fixed
 * query vector. This target:
 * - Tests both bit and int8 quantizers
 * - Uses fuzz-controlled query vectors. Each component is a signed input byte
 *   divided by 3, so the values are always finite; NaN/Inf/denormal inputs
 *   are NOT reached by this target.
 * - Tests the default (L2), cosine and L1 distance metrics with either
 *   quantizer
 * - Exercises fuzz-chosen LIMIT values from 1 to 50 (oversample multiplication)
 * - Exercises the insert->query->update->query->delete->query cycle
 *
 * TODO: KNN with rowid IN constraints is not exercised yet.
 */
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 20) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* Use first 4 bytes for configuration */
|
||||
uint8_t config = data[0];
|
||||
uint8_t num_inserts = (data[1] % 20) + 3; /* 3..22 inserts */
|
||||
uint8_t limit_val = (data[2] % 50) + 1; /* 1..50 for LIMIT */
|
||||
uint8_t metric_choice = data[3] % 3;
|
||||
data += 4;
|
||||
size -= 4;
|
||||
|
||||
int use_int8 = config & 1;
|
||||
const char *metric_str;
|
||||
switch (metric_choice) {
|
||||
case 0: metric_str = ""; break; /* default L2 */
|
||||
case 1: metric_str = " distance_metric=cosine"; break;
|
||||
case 2: metric_str = " distance_metric=l1"; break;
|
||||
default: metric_str = ""; break;
|
||||
}
|
||||
|
||||
/* Build CREATE TABLE statement */
|
||||
char create_sql[256];
|
||||
if (use_int8) {
|
||||
snprintf(create_sql, sizeof(create_sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=int8)%s)", metric_str);
|
||||
} else {
|
||||
/* bit quantizer ignores distance_metric for the coarse pass (always hamming),
|
||||
but the float rescore phase uses the specified metric */
|
||||
snprintf(create_sql, sizeof(create_sql),
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=bit)%s)", metric_str);
|
||||
}
|
||||
|
||||
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
/* Insert vectors using fuzz data */
|
||||
{
|
||||
sqlite3_stmt *ins = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL);
|
||||
if (!ins) { sqlite3_close(db); return 0; }
|
||||
|
||||
size_t cursor = 0;
|
||||
for (int i = 0; i < num_inserts && cursor + 1 < size; i++) {
|
||||
float vec[16];
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (cursor < size) {
|
||||
/* Map fuzz byte to float -- includes potential for
|
||||
interesting float values via reinterpretation */
|
||||
int8_t sb = (int8_t)data[cursor++];
|
||||
vec[j] = (float)sb / 5.0f;
|
||||
} else {
|
||||
vec[j] = 0.0f;
|
||||
}
|
||||
}
|
||||
sqlite3_reset(ins);
|
||||
sqlite3_bind_int64(ins, 1, (sqlite3_int64)(i + 1));
|
||||
sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(ins);
|
||||
}
|
||||
sqlite3_finalize(ins);
|
||||
}
|
||||
|
||||
/* Build a fuzz-controlled query vector from remaining data */
|
||||
float qvec[16] = {0};
|
||||
{
|
||||
size_t cursor = 0;
|
||||
for (int j = 0; j < 16 && cursor < size; j++) {
|
||||
int8_t sb = (int8_t)data[cursor++];
|
||||
qvec[j] = (float)sb / 3.0f;
|
||||
}
|
||||
}
|
||||
|
||||
/* KNN query with fuzz-controlled vector and LIMIT */
|
||||
{
|
||||
char knn_sql[256];
|
||||
snprintf(knn_sql, sizeof(knn_sql),
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT %d", (int)limit_val);
|
||||
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db, knn_sql, -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {
|
||||
/* Read results to ensure distance computation didn't produce garbage
|
||||
that crashes the cursor iteration */
|
||||
(void)sqlite3_column_int64(knn, 0);
|
||||
(void)sqlite3_column_double(knn, 1);
|
||||
}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
/* Update some vectors, then query again */
|
||||
{
|
||||
float uvec[16];
|
||||
for (int j = 0; j < 16; j++) uvec[j] = qvec[15 - j]; /* reverse of query */
|
||||
sqlite3_stmt *upd = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL);
|
||||
if (upd) {
|
||||
sqlite3_bind_blob(upd, 1, uvec, sizeof(uvec), SQLITE_STATIC);
|
||||
sqlite3_step(upd);
|
||||
sqlite3_finalize(upd);
|
||||
}
|
||||
}
|
||||
|
||||
/* Second KNN after update */
|
||||
{
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 10", -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
/* Delete half the rows, then KNN again */
|
||||
for (int i = 1; i <= num_inserts; i += 2) {
|
||||
char del_sql[64];
|
||||
snprintf(del_sql, sizeof(del_sql),
|
||||
"DELETE FROM v WHERE rowid = %d", i);
|
||||
sqlite3_exec(db, del_sql, NULL, NULL, NULL);
|
||||
}
|
||||
|
||||
/* Third KNN after deletes -- exercises distance computation over
|
||||
zeroed-out slots in the quantized chunk */
|
||||
{
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 5", -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
96
tests/fuzz/rescore-operations.c
Normal file
96
tests/fuzz/rescore-operations.c
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 6) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
sqlite3_stmt *stmtInsert = NULL;
|
||||
sqlite3_stmt *stmtDelete = NULL;
|
||||
sqlite3_stmt *stmtKnn = NULL;
|
||||
sqlite3_stmt *stmtScan = NULL;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
rc = sqlite3_exec(db,
|
||||
"CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[8] indexed by rescore(quantizer=bit))",
|
||||
NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? ORDER BY distance LIMIT 3",
|
||||
-1, &stmtKnn, NULL);
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||
|
||||
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||
|
||||
size_t i = 0;
|
||||
while (i + 2 <= size) {
|
||||
uint8_t op = data[i++] % 4;
|
||||
uint8_t rowid_byte = data[i++];
|
||||
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||
|
||||
switch (op) {
|
||||
case 0: {
|
||||
// INSERT: consume 32 bytes for 8 floats, or use what's left
|
||||
float vec[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
for (int j = 0; j < 8 && i < size; j++, i++) {
|
||||
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||
}
|
||||
sqlite3_reset(stmtInsert);
|
||||
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(stmtInsert);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
// DELETE
|
||||
sqlite3_reset(stmtDelete);
|
||||
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||
sqlite3_step(stmtDelete);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
// KNN query with a fixed query vector
|
||||
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||
sqlite3_reset(stmtKnn);
|
||||
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
// Full scan
|
||||
sqlite3_reset(stmtScan);
|
||||
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final operations -- must not crash regardless of prior state
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
cleanup:
|
||||
sqlite3_finalize(stmtInsert);
|
||||
sqlite3_finalize(stmtDelete);
|
||||
sqlite3_finalize(stmtKnn);
|
||||
sqlite3_finalize(stmtScan);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
177
tests/fuzz/rescore-quantize-edge.c
Normal file
177
tests/fuzz/rescore-quantize-edge.c
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* Test wrappers from sqlite-vec-rescore.c (SQLITE_VEC_TEST build) */
|
||||
extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
|
||||
extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
|
||||
extern size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
|
||||
extern size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
|
||||
|
||||
/**
|
||||
* Fuzz target: edge cases in rescore quantization functions.
|
||||
*
|
||||
* The existing rescore-quantize.c only tests dimensions that are multiples of 8
|
||||
* and never passes special float values. This target:
|
||||
*
|
||||
* - Tests rescore_quantized_byte_size with arbitrary dimension values
|
||||
* (including 0, 1, 7, MAX values -- looking for integer division issues)
|
||||
* - Passes raw float reinterpretation of fuzz bytes (NaN, Inf, denormals,
|
||||
* negative zero -- these are the values that break min/max/range logic)
|
||||
* - Tests the int8 quantizer with all-identical values (range=0 branch)
|
||||
* - Tests the int8 quantizer with extreme ranges (overflow in scale calc)
|
||||
* - Tests bit quantizer with exact float threshold (0.0f boundary)
|
||||
*/
|
||||
/*
 * Copy `count` floats' worth of raw fuzz bytes into a freshly malloc'ed (and
 * therefore suitably aligned) float array. Returns NULL on allocation
 * failure; the caller owns the result and must free() it.
 */
static float *quantize_edge_copy_floats(const uint8_t *data, size_t count) {
  float *out = (float *)malloc(count * sizeof(float));
  if (out) memcpy(out, data, count * sizeof(float));
  return out;
}

/*
 * libFuzzer entry point for the rescore quantization helpers.
 *
 * The first input byte selects a mode (byte % 5); the remaining bytes are the
 * mode's payload:
 *   0  byte-size helpers called with a fuzz-controlled size_t dimension
 *   1  bit quantizer on raw reinterpreted floats, output checked bit by bit
 *   2  int8 quantizer on raw reinterpreted floats
 *   3  int8 quantizer on a vector whose components are all the same value
 *   4  int8 quantizer on a 16-element vector with two fuzz-chosen finite
 *      extremes plus coarse fuzz-derived filler
 *
 * Always returns 0; failures are reported through assert().
 */
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  if (size < 8) return 0;

  uint8_t mode = data[0] % 5;
  data++;
  size--;
  /* From here on `data` is offset by one byte from the start of the fuzz
     buffer, so it must never be dereferenced through a float pointer; floats
     are always copied out with memcpy instead. */

  switch (mode) {
  case 0: {
    /* Call the byte-size helpers with a fuzz-controlled dimension. They are
       only expected not to crash; the results are not checked. */
    if (size < sizeof(size_t)) return 0;
    size_t dim;
    memcpy(&dim, data, sizeof(dim));

    size_t bit_size = _test_rescore_quantized_byte_size_bit(dim);
    size_t int8_size = _test_rescore_quantized_byte_size_int8(dim);

    (void)bit_size;
    (void)int8_size;
    break;
  }

  case 1: {
    /* Bit quantize raw reinterpreted floats (may include NaN, Inf,
       denormals and -0.0f). */
    size_t num_floats = size / sizeof(float);
    if (num_floats == 0) return 0;
    /* Round down to a multiple of 8 for the bit quantizer */
    size_t dim = (num_floats / 8) * 8;
    if (dim == 0) return 0;

    size_t bit_bytes = dim / 8;
    float *src = quantize_edge_copy_floats(data, dim);
    uint8_t *dst = (uint8_t *)malloc(bit_bytes);
    if (!src || !dst) { free(src); free(dst); return 0; }

    _test_rescore_quantize_float_to_bit(src, dst, dim);

    /* Verify: bit i (LSB-first within each byte) is set iff src[i] >= 0.
       NaN satisfies neither comparison, so its bit is left unchecked. */
    for (size_t i = 0; i < dim; i++) {
      int bit_set = (dst[i / 8] >> (i % 8)) & 1;
      if (src[i] >= 0.0f) {
        assert(bit_set == 1);
      } else if (src[i] < 0.0f) {
        /* Definitely negative -- bit must be 0 */
        assert(bit_set == 0);
      }
    }

    free(src);
    free(dst);
    break;
  }

  case 2: {
    /* Int8 quantize raw reinterpreted floats: exercises NaN/Inf inputs,
       identical values and extreme ranges in a single call. */
    size_t num_floats = size / sizeof(float);
    if (num_floats == 0) return 0;

    float *src = quantize_edge_copy_floats(data, num_floats);
    int8_t *dst = (int8_t *)malloc(num_floats);
    if (!src || !dst) { free(src); free(dst); return 0; }

    _test_rescore_quantize_float_to_int8(src, dst, num_floats);

    /* Trivially true for int8_t; kept so every output byte is read (which
       lets sanitizers notice uninitialized or out-of-bounds output). */
    for (size_t i = 0; i < num_floats; i++) {
      assert(dst[i] >= -128 && dst[i] <= 127);
    }

    free(src);
    free(dst);
    break;
  }

  case 3: {
    /* Int8 quantize stress: all-same values (range=0 branch) */
    size_t dim = (size < 64) ? size : 64;
    if (dim == 0) return 0;

    float *src = (float *)malloc(dim * sizeof(float));
    int8_t *dst = (int8_t *)malloc(dim);
    if (!src || !dst) { free(src); free(dst); return 0; }

    /* Fill with a single value derived from fuzz data */
    float val = 0.0f;
    memcpy(&val, data, sizeof(float) < size ? sizeof(float) : size);
    for (size_t i = 0; i < dim; i++) src[i] = val;

    _test_rescore_quantize_float_to_int8(src, dst, dim);

    /* All outputs should be 0 when range == 0.
       NOTE(review): val is a raw bit pattern and may be NaN or +/-Inf, in
       which case max - min is NaN rather than 0 -- confirm the quantizer is
       expected to emit zeros for those inputs as well. */
    for (size_t i = 0; i < dim; i++) {
      assert(dst[i] == 0);
    }

    free(src);
    free(dst);
    break;
  }

  case 4: {
    /* Int8 quantize with an extreme range: two fuzz-chosen finite values
       plus filler, to stress the scale computation and the clamping. */
    if (size < 2 * sizeof(float)) return 0;

    float extreme[2];
    memcpy(extreme, data, 2 * sizeof(float));

    /* Only test if both are finite (NaN/Inf tested in case 2) */
    if (!isfinite(extreme[0]) || !isfinite(extreme[1])) return 0;

    /* Build a vector with these two extreme values plus some fuzz */
    size_t dim = 16;
    float src[16];
    src[0] = extreme[0];
    src[1] = extreme[1];
    for (size_t i = 2; i < dim; i++) {
      size_t offset = 2 * sizeof(float) + (i - 2);
      src[i] = (offset < size) ? (float)((int8_t)data[offset]) * 1000.0f
                               : 0.0f;
    }

    int8_t dst[16];
    _test_rescore_quantize_float_to_int8(src, dst, dim);

    for (size_t i = 0; i < dim; i++) {
      assert(dst[i] >= -128 && dst[i] <= 127);
    }
    break;
  }
  }

  return 0;
}
|
||||
54
tests/fuzz/rescore-quantize.c
Normal file
54
tests/fuzz/rescore-quantize.c
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/* These are SQLITE_VEC_TEST wrappers defined in sqlite-vec-rescore.c */
|
||||
extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
|
||||
extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
|
||||
|
||||
/*
 * libFuzzer entry point: interprets the input as an array of raw floats and
 * runs both rescore quantizers (bit and int8) over it.
 *
 * The dimension is the number of whole floats in the input rounded down to a
 * multiple of 8; inputs holding fewer than 8 floats are ignored.
 * Always returns 0.
 */
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  /* Need at least 4 bytes for one float */
  if (size < 4) return 0;

  size_t num_floats = size / sizeof(float);
  if (num_floats == 0) return 0;

  /* Round down to a multiple of 8 for bit quantizer compatibility. Fewer
     than 8 floats cannot form a vector, so such inputs are rejected. */
  size_t dim = (num_floats / 8) * 8;
  if (dim == 0) return 0;

  /* Allocate the input copy and the output buffers. The floats are copied
     out of the fuzz buffer because `data` carries no alignment guarantee and
     reading it through a float pointer would be undefined behaviour. */
  size_t bit_bytes = dim / 8;
  float *src = (float *)malloc(dim * sizeof(float));
  uint8_t *bit_dst = (uint8_t *)malloc(bit_bytes);
  int8_t *int8_dst = (int8_t *)malloc(dim);
  if (!src || !bit_dst || !int8_dst) {
    free(src);
    free(bit_dst);
    free(int8_dst);
    return 0;
  }
  memcpy(src, data, dim * sizeof(float));

  /* Test bit quantization */
  _test_rescore_quantize_float_to_bit(src, bit_dst, dim);

  /* Test int8 quantization */
  _test_rescore_quantize_float_to_int8(src, int8_dst, dim);

  /* Trivially true for int8_t; kept so every output byte is read. */
  for (size_t i = 0; i < dim; i++) {
    assert(int8_dst[i] >= -128 && int8_dst[i] <= 127);
  }

  free(src);
  free(bit_dst);
  free(int8_dst);
  return 0;
}
|
||||
230
tests/fuzz/rescore-shadow-corrupt.c
Normal file
230
tests/fuzz/rescore-shadow-corrupt.c
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include "sqlite-vec.h"
|
||||
#include "sqlite3.h"
|
||||
#include <assert.h>
|
||||
|
||||
/**
|
||||
* Fuzz target: corrupt rescore shadow tables then exercise KNN/read/write.
|
||||
*
|
||||
* This targets the dangerous code paths in rescore_knn (Phase 1 + 2):
|
||||
* - sqlite3_blob_read into baseVectors with potentially wrong-sized blobs
|
||||
* - distance computation on corrupted/partial quantized data
|
||||
* - blob_reopen on _rescore_vectors with missing/corrupted rows
|
||||
* - insert/delete after corruption (blob_write to wrong offsets)
|
||||
*
|
||||
* The existing shadow-corrupt.c only tests vec0 without rescore.
|
||||
*/
|
||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||
if (size < 4) return 0;
|
||||
|
||||
int rc;
|
||||
sqlite3 *db;
|
||||
|
||||
rc = sqlite3_open(":memory:", &db);
|
||||
assert(rc == SQLITE_OK);
|
||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||
assert(rc == SQLITE_OK);
|
||||
|
||||
/* Pick quantizer type from first byte */
|
||||
int use_int8 = data[0] & 1;
|
||||
int target = (data[1] % 8);
|
||||
const uint8_t *payload = data + 2;
|
||||
int payload_size = (int)(size - 2);
|
||||
|
||||
const char *create_sql = use_int8
|
||||
? "CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=int8))"
|
||||
: "CREATE VIRTUAL TABLE v USING vec0("
|
||||
"emb float[16] indexed by rescore(quantizer=bit))";
|
||||
|
||||
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||
|
||||
/* Insert several vectors so there's a full chunk to corrupt */
|
||||
{
|
||||
sqlite3_stmt *ins = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL);
|
||||
if (!ins) { sqlite3_close(db); return 0; }
|
||||
|
||||
for (int i = 1; i <= 8; i++) {
|
||||
float vec[16];
|
||||
for (int j = 0; j < 16; j++) vec[j] = (float)(i * 10 + j) / 100.0f;
|
||||
sqlite3_reset(ins);
|
||||
sqlite3_bind_int64(ins, 1, i);
|
||||
sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||
sqlite3_step(ins);
|
||||
}
|
||||
sqlite3_finalize(ins);
|
||||
}
|
||||
|
||||
/* Now corrupt rescore shadow tables based on fuzz input */
|
||||
sqlite3_stmt *stmt = NULL;
|
||||
|
||||
switch (target) {
|
||||
case 0: {
|
||||
/* Corrupt _rescore_chunks00 vectors blob with fuzz data */
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
/* Corrupt _rescore_vectors00 vector blob for a specific row */
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 3",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
/* Truncate the quantized chunk blob to wrong size */
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_chunks00 SET vectors = X'DEADBEEF' WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
/* Delete rows from _rescore_vectors (orphan the float vectors) */
|
||||
sqlite3_exec(db,
|
||||
"DELETE FROM v_rescore_vectors00 WHERE rowid IN (2, 4, 6)",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
/* Delete the chunk row entirely from _rescore_chunks */
|
||||
sqlite3_exec(db,
|
||||
"DELETE FROM v_rescore_chunks00 WHERE rowid = 1",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 5: {
|
||||
/* Set vectors to NULL in _rescore_chunks */
|
||||
sqlite3_exec(db,
|
||||
"UPDATE v_rescore_chunks00 SET vectors = NULL WHERE rowid = 1",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 6: {
|
||||
/* Set vector to NULL in _rescore_vectors */
|
||||
sqlite3_exec(db,
|
||||
"UPDATE v_rescore_vectors00 SET vector = NULL WHERE rowid = 3",
|
||||
NULL, NULL, NULL);
|
||||
break;
|
||||
}
|
||||
case 7: {
|
||||
/* Corrupt BOTH tables with fuzz data */
|
||||
int half = payload_size / 2;
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload, half, SQLITE_STATIC);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
rc = sqlite3_prepare_v2(db,
|
||||
"UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 1",
|
||||
-1, &stmt, NULL);
|
||||
if (rc == SQLITE_OK) {
|
||||
sqlite3_bind_blob(stmt, 1, payload + half,
|
||||
payload_size - half, SQLITE_STATIC);
|
||||
sqlite3_step(stmt);
|
||||
sqlite3_finalize(stmt);
|
||||
stmt = NULL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/* Exercise ALL read/write paths -- NONE should crash */
|
||||
|
||||
/* KNN query (triggers rescore_knn Phase 1 + Phase 2) */
|
||||
{
|
||||
float qvec[16] = {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1};
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 5", -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
/* Full scan (triggers reading from _rescore_vectors) */
|
||||
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||
|
||||
/* Point lookups */
|
||||
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 3", NULL, NULL, NULL);
|
||||
|
||||
/* Insert after corruption */
|
||||
{
|
||||
float vec[16] = {0};
|
||||
sqlite3_stmt *ins = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"INSERT INTO v(rowid, emb) VALUES (99, ?)", -1, &ins, NULL);
|
||||
if (ins) {
|
||||
sqlite3_bind_blob(ins, 1, vec, sizeof(vec), SQLITE_STATIC);
|
||||
sqlite3_step(ins);
|
||||
sqlite3_finalize(ins);
|
||||
}
|
||||
}
|
||||
|
||||
/* Delete after corruption */
|
||||
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 5", NULL, NULL, NULL);
|
||||
|
||||
/* Update after corruption */
|
||||
{
|
||||
float vec[16] = {1,1,1,1, 1,1,1,1, 1,1,1,1, 1,1,1,1};
|
||||
sqlite3_stmt *upd = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL);
|
||||
if (upd) {
|
||||
sqlite3_bind_blob(upd, 1, vec, sizeof(vec), SQLITE_STATIC);
|
||||
sqlite3_step(upd);
|
||||
sqlite3_finalize(upd);
|
||||
}
|
||||
}
|
||||
|
||||
/* KNN again after modifications to corrupted state */
|
||||
{
|
||||
float qvec[16] = {0,0,0,0, 0,0,0,0, 1,1,1,1, 1,1,1,1};
|
||||
sqlite3_stmt *knn = NULL;
|
||||
sqlite3_prepare_v2(db,
|
||||
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||
"ORDER BY distance LIMIT 3", -1, &knn, NULL);
|
||||
if (knn) {
|
||||
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||
sqlite3_finalize(knn);
|
||||
}
|
||||
}
|
||||
|
||||
sqlite3_exec(db, "DROP TABLE v", NULL, NULL, NULL);
|
||||
sqlite3_close(db);
|
||||
return 0;
|
||||
}
|
||||
81
tests/generate_legacy_db.py
Normal file
81
tests/generate_legacy_db.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = ["sqlite-vec==0.1.6"]
|
||||
# ///
|
||||
"""Generate a legacy sqlite-vec database for backwards-compat testing.
|
||||
|
||||
Usage:
|
||||
uv run --script generate_legacy_db.py
|
||||
|
||||
Creates tests/fixtures/legacy-v0.1.6.db with a vec0 table containing
|
||||
test data that can be read by the current version of sqlite-vec.
|
||||
"""
|
||||
import sqlite3
|
||||
import sqlite_vec
|
||||
import struct
|
||||
import os
|
||||
|
||||
FIXTURE_DIR = os.path.join(os.path.dirname(__file__), "fixtures")
|
||||
DB_PATH = os.path.join(FIXTURE_DIR, "legacy-v0.1.6.db")
|
||||
|
||||
DIMS = 4
|
||||
N_ROWS = 50
|
||||
|
||||
|
||||
def _f32(vals):
|
||||
return struct.pack(f"{len(vals)}f", *vals)
|
||||
|
||||
|
||||
def main():
    """Build the legacy fixture database at ``DB_PATH`` from scratch.

    Any existing fixture file is removed first. Two vec0 tables are written:

    * ``legacy_vectors`` -- ``N_ROWS`` rows of ``DIMS``-wide vectors whose
      first element equals the rowid (all other elements are 0.0).
    * ``emb`` -- 10 rows, in a table whose name equals its vector column
      name (the "name conflict" case).

    Progress is printed to stdout. The connection is always closed, even
    when one of the sanity checks fails.
    """

    def unit_vec(first):
        # DIMS-wide vector with `first` in slot 0 and zeros elsewhere.
        return [float(first)] + [0.0] * (DIMS - 1)

    os.makedirs(FIXTURE_DIR, exist_ok=True)
    if os.path.exists(DB_PATH):
        os.remove(DB_PATH)

    db = sqlite3.connect(DB_PATH)
    try:
        db.enable_load_extension(True)
        sqlite_vec.load(db)

        # Print version for verification
        version = db.execute("SELECT vec_version()").fetchone()[0]
        print(f"sqlite-vec version: {version}")

        # Create a basic vec0 table — flat index, no fancy features
        db.execute(f"CREATE VIRTUAL TABLE legacy_vectors USING vec0(emb float[{DIMS}])")

        # Insert test data: vectors where element[0] == rowid for easy verification
        for i in range(1, N_ROWS + 1):
            db.execute(
                "INSERT INTO legacy_vectors(rowid, emb) VALUES (?, ?)",
                [i, _f32(unit_vec(i))],
            )
        db.commit()

        # Verify
        count = db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
        print(f"Inserted {count} rows")

        # Test KNN works
        query = _f32(unit_vec(1))
        rows = db.execute(
            "SELECT rowid, distance FROM legacy_vectors WHERE emb MATCH ? AND k = 5",
            [query],
        ).fetchall()
        print(f"KNN top 5: {[(r[0], round(r[1], 4)) for r in rows]}")
        # Check the length first so an empty result fails with a clear
        # assertion rather than an IndexError on rows[0].
        assert len(rows) == 5
        assert rows[0][0] == 1  # closest to [1,0,0,0]

        # Also create a table with name == column name (the conflict case)
        # This was allowed in old versions — new code must not break on reconnect
        db.execute(f"CREATE VIRTUAL TABLE emb USING vec0(emb float[{DIMS}])")
        for i in range(1, 11):
            db.execute("INSERT INTO emb(rowid, emb) VALUES (?, ?)", [i, _f32(unit_vec(i))])
        db.commit()

        count2 = db.execute("SELECT count(*) FROM emb").fetchone()[0]
        print(f"Table 'emb' with column 'emb': {count2} rows (name conflict case)")
    finally:
        db.close()

    print(f"\nGenerated: {DB_PATH}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -3,6 +3,11 @@
|
|||
|
||||
#include <stdlib.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifndef SQLITE_VEC_ENABLE_IVF
|
||||
#define SQLITE_VEC_ENABLE_IVF 1
|
||||
#endif
|
||||
|
||||
int min_idx(
|
||||
const float *distances,
|
||||
|
|
@ -62,12 +67,81 @@ enum Vec0DistanceMetrics {
|
|||
VEC0_DISTANCE_METRIC_L1 = 3,
|
||||
};
|
||||
|
||||
enum Vec0IndexType {
|
||||
VEC0_INDEX_TYPE_FLAT = 1,
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
VEC0_INDEX_TYPE_RESCORE = 2,
|
||||
#endif
|
||||
VEC0_INDEX_TYPE_IVF = 3,
|
||||
VEC0_INDEX_TYPE_DISKANN = 4,
|
||||
};
|
||||
|
||||
/*
 * Rescore index configuration.
 *
 * These two types are declared exactly once and unconditionally. Declaring
 * them a second time inside `#ifdef SQLITE_VEC_ENABLE_RESCORE` redefines the
 * same enum/struct tags in one translation unit, which is an error for
 * pre-C23 compilers whenever that macro is defined.
 */
enum Vec0RescoreQuantizerType {
  VEC0_RESCORE_QUANTIZER_BIT = 1,
  VEC0_RESCORE_QUANTIZER_INT8 = 2,
};

struct Vec0RescoreConfig {
  enum Vec0RescoreQuantizerType quantizer_type;
  /* NOTE(review): presumably the candidate oversampling factor -- confirm
   * against the rescore option parser. */
  int oversample;
};

/*
 * IVF index configuration. When IVF support is compiled out, a one-byte
 * placeholder keeps `struct Vec0IvfConfig` a complete type so it can still
 * be embedded in other structs.
 */
#if SQLITE_VEC_ENABLE_IVF
enum Vec0IvfQuantizer {
  VEC0_IVF_QUANTIZER_NONE = 0,
  VEC0_IVF_QUANTIZER_INT8 = 1,
  VEC0_IVF_QUANTIZER_BINARY = 2,
};

struct Vec0IvfConfig {
  int nlist;
  int nprobe;
  /* NOTE(review): looks like this holds an enum Vec0IvfQuantizer value -- confirm. */
  int quantizer;
  int oversample;
};
#else
struct Vec0IvfConfig { char _unused; };
#endif
|
||||
|
||||
enum Vec0DiskannQuantizerType {
|
||||
VEC0_DISKANN_QUANTIZER_BINARY = 1,
|
||||
VEC0_DISKANN_QUANTIZER_INT8 = 2,
|
||||
};
|
||||
|
||||
struct Vec0DiskannConfig {
|
||||
enum Vec0DiskannQuantizerType quantizer_type;
|
||||
int n_neighbors;
|
||||
int search_list_size;
|
||||
int search_list_size_search;
|
||||
int search_list_size_insert;
|
||||
float alpha;
|
||||
int buffer_threshold;
|
||||
};
|
||||
|
||||
struct VectorColumnDefinition {
|
||||
char *name;
|
||||
int name_length;
|
||||
size_t dimensions;
|
||||
enum VectorElementType element_type;
|
||||
enum Vec0DistanceMetrics distance_metric;
|
||||
enum Vec0IndexType index_type;
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
struct Vec0RescoreConfig rescore;
|
||||
#endif
|
||||
struct Vec0IvfConfig ivf;
|
||||
struct Vec0DiskannConfig diskann;
|
||||
};
|
||||
|
||||
int vec0_parse_vector_column(const char *source, int source_length,
|
||||
|
|
@ -78,10 +152,90 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
|
|||
int *out_column_name_length,
|
||||
int *out_column_type);
|
||||
|
||||
size_t diskann_quantized_vector_byte_size(
|
||||
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||
|
||||
int diskann_validity_byte_size(int n_neighbors);
|
||||
size_t diskann_neighbor_ids_byte_size(int n_neighbors);
|
||||
size_t diskann_neighbor_qvecs_byte_size(
|
||||
int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
|
||||
size_t dimensions);
|
||||
int diskann_node_init(
|
||||
int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
|
||||
size_t dimensions,
|
||||
unsigned char **outValidity, int *outValiditySize,
|
||||
unsigned char **outNeighborIds, int *outNeighborIdsSize,
|
||||
unsigned char **outNeighborQvecs, int *outNeighborQvecsSize);
|
||||
int diskann_validity_get(const unsigned char *validity, int i);
|
||||
void diskann_validity_set(unsigned char *validity, int i, int value);
|
||||
int diskann_validity_count(const unsigned char *validity, int n_neighbors);
|
||||
long long diskann_neighbor_id_get(const unsigned char *neighbor_ids, int i);
|
||||
void diskann_neighbor_id_set(unsigned char *neighbor_ids, int i, long long rowid);
|
||||
const unsigned char *diskann_neighbor_qvec_get(
|
||||
const unsigned char *qvecs, int i,
|
||||
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||
void diskann_neighbor_qvec_set(
|
||||
unsigned char *qvecs, int i, const unsigned char *src_qvec,
|
||||
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||
void diskann_node_set_neighbor(
|
||||
unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i,
|
||||
long long neighbor_rowid, const unsigned char *neighbor_qvec,
|
||||
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||
void diskann_node_clear_neighbor(
|
||||
unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i,
|
||||
enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
|
||||
int diskann_quantize_vector(
|
||||
const float *src, size_t dimensions,
|
||||
enum Vec0DiskannQuantizerType quantizer_type,
|
||||
unsigned char *out);
|
||||
|
||||
int diskann_prune_select(
|
||||
const float *inter_distances, const float *p_distances,
|
||||
int num_candidates, float alpha, int max_neighbors,
|
||||
int *outSelected, int *outCount);
|
||||
|
||||
#ifdef SQLITE_VEC_TEST
|
||||
float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims);
|
||||
float _test_distance_cosine_float(const float *a, const float *b, size_t dims);
|
||||
float _test_distance_hamming(const unsigned char *a, const unsigned char *b, size_t dims);
|
||||
|
||||
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
|
||||
void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
|
||||
size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
|
||||
size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
|
||||
#endif
|
||||
#if SQLITE_VEC_ENABLE_IVF
|
||||
void ivf_quantize_int8(const float *src, int8_t *dst, int D);
|
||||
void ivf_quantize_binary(const float *src, uint8_t *dst, int D);
|
||||
#endif
|
||||
// DiskANN candidate list (opaque struct, use accessors)
|
||||
struct DiskannCandidateList {
|
||||
void *items; // opaque
|
||||
int count;
|
||||
int capacity;
|
||||
};
|
||||
|
||||
int _test_diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity);
|
||||
void _test_diskann_candidate_list_free(struct DiskannCandidateList *list);
|
||||
int _test_diskann_candidate_list_insert(struct DiskannCandidateList *list, long long rowid, float distance);
|
||||
int _test_diskann_candidate_list_next_unvisited(const struct DiskannCandidateList *list);
|
||||
int _test_diskann_candidate_list_count(const struct DiskannCandidateList *list);
|
||||
long long _test_diskann_candidate_list_rowid(const struct DiskannCandidateList *list, int i);
|
||||
float _test_diskann_candidate_list_distance(const struct DiskannCandidateList *list, int i);
|
||||
void _test_diskann_candidate_list_set_visited(struct DiskannCandidateList *list, int i);
|
||||
|
||||
// DiskANN visited set (opaque struct, use accessors)
|
||||
struct DiskannVisitedSet {
|
||||
void *slots; // opaque
|
||||
int capacity;
|
||||
int count;
|
||||
};
|
||||
|
||||
int _test_diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity);
|
||||
void _test_diskann_visited_set_free(struct DiskannVisitedSet *set);
|
||||
int _test_diskann_visited_set_contains(const struct DiskannVisitedSet *set, long long rowid);
|
||||
int _test_diskann_visited_set_insert(struct DiskannVisitedSet *set, long long rowid);
|
||||
#endif
|
||||
|
||||
#endif /* SQLITE_VEC_INTERNAL_H */
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import sqlite3
|
||||
from helpers import exec, vec0_shadow_table_contents
|
||||
import struct
|
||||
import pytest
|
||||
from helpers import exec, vec0_shadow_table_contents, _f32
|
||||
|
||||
|
||||
def test_constructor_limit(db, snapshot):
|
||||
|
|
@ -126,3 +128,198 @@ def test_knn(db, snapshot):
|
|||
) == snapshot(name="illegal KNN w/ aux")
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Auxiliary columns with non-flat indexes
|
||||
# ======================================================================
|
||||
|
||||
|
||||
def test_rescore_aux_shadow_tables(db, snapshot):
    """A rescore-indexed table with two aux columns: snapshot its shadow tables."""
    create_sql = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text,"
        " +score float"
        ")"
    )
    db.execute(create_sql)

    shadow_tables = exec(
        db,
        "SELECT name, sql FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name",
    )
    assert shadow_tables == snapshot(name="rescore aux shadow tables")
|
||||
|
||||
|
||||
def test_rescore_aux_insert_knn(db, snapshot):
    """Rows inserted with an aux label come back, label included, from SELECT and KNN."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )
    import random

    random.seed(77)
    # Vectors are drawn in label order so the seeded sequence is stable.
    labelled = [
        (name, [random.gauss(0, 1) for _ in range(128)])
        for name in ("alpha", "beta", "gamma")
    ]
    insert_sql = "INSERT INTO t(emb, label) VALUES (?, ?)"
    for name, vector in labelled:
        db.execute(insert_sql, [_f32(vector), name])

    all_rows = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert all_rows == snapshot(name="rescore aux select all")
    shadow = vec0_shadow_table_contents(db, "t", skip_info=True)
    assert shadow == snapshot(name="rescore aux shadow contents")

    # Querying with alpha's own vector must rank "alpha" first.
    alpha_vector = labelled[0][1]
    nearest = db.execute(
        "SELECT label, distance FROM t WHERE emb MATCH ? ORDER BY distance LIMIT 3",
        [_f32(alpha_vector)],
    ).fetchall()
    assert len(nearest) == 3
    assert nearest[0][0] == "alpha"
|
||||
|
||||
|
||||
def test_rescore_aux_update(db):
    """Updating only the aux column leaves the row findable by KNN with the new label."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )
    import random

    random.seed(88)
    embedding = [random.gauss(0, 1) for _ in range(128)]
    db.execute(
        "INSERT INTO t(rowid, emb, label) VALUES (1, ?, 'original')",
        [_f32(embedding)],
    )
    db.execute("UPDATE t SET label = 'updated' WHERE rowid = 1")

    stored = db.execute("SELECT label FROM t WHERE rowid = 1").fetchone()
    assert stored[0] == "updated"

    # The nearest neighbour of the row's own vector is still row 1.
    hits = db.execute(
        "SELECT rowid, label FROM t WHERE emb MATCH ? ORDER BY distance LIMIT 1",
        [_f32(embedding)],
    ).fetchall()
    top = hits[0]
    assert top[0] == 1
    assert top[1] == "updated"
|
||||
|
||||
|
||||
def test_rescore_aux_delete(db, snapshot):
    """Deleting a row removes it from the table and from the aux shadow table."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )
    import random

    random.seed(99)
    insert_sql = "INSERT INTO t(rowid, emb, label) VALUES (?, ?, ?)"
    for rowid in range(1, 6):
        blob = _f32([random.gauss(0, 1) for _ in range(128)])
        db.execute(insert_sql, [rowid, blob, f"item-{rowid}"])

    db.execute("DELETE FROM t WHERE rowid = 3")

    remaining = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert remaining == snapshot(name="rescore aux after delete")
    aux_rows = exec(db, "SELECT rowid, value00 FROM t_auxiliary ORDER BY rowid")
    assert aux_rows == snapshot(name="rescore aux shadow after delete")
|
||||
|
||||
|
||||
def test_diskann_aux_shadow_tables(db, snapshot):
    """A DiskANN-indexed table with two aux columns: snapshot its shadow tables."""
    create_sql = (
        "\n"
        "CREATE VIRTUAL TABLE t USING vec0(\n"
        "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),\n"
        "+label text,\n"
        "+score float\n"
        ")\n"
    )
    db.execute(create_sql)

    shadow_tables = exec(
        db,
        "SELECT name, sql FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name",
    )
    assert shadow_tables == snapshot(name="diskann aux shadow tables")
|
||||
|
||||
|
||||
def test_diskann_aux_insert_knn(db, snapshot):
    """Insert three labelled unit vectors into a DiskANN table and KNN for the first."""
    create_sql = (
        "\n"
        "CREATE VIRTUAL TABLE t USING vec0(\n"
        "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),\n"
        "+label text\n"
        ")\n"
    )
    db.execute(create_sql)

    labelled = [
        ("red", [1, 0, 0, 0, 0, 0, 0, 0]),
        ("green", [0, 1, 0, 0, 0, 0, 0, 0]),
        ("blue", [0, 0, 1, 0, 0, 0, 0, 0]),
    ]
    insert_sql = "INSERT INTO t(emb, label) VALUES (?, ?)"
    for name, vector in labelled:
        db.execute(insert_sql, [_f32(vector), name])

    all_rows = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert all_rows == snapshot(name="diskann aux select all")
    shadow = vec0_shadow_table_contents(db, "t", skip_info=True)
    assert shadow == snapshot(name="diskann aux shadow contents")

    # Query with the "red" vector: at least one hit, and "red" ranks first.
    hits = db.execute(
        "SELECT label, distance FROM t WHERE emb MATCH ? AND k = 3",
        [_f32([1, 0, 0, 0, 0, 0, 0, 0])],
    ).fetchall()
    assert len(hits) >= 1
    assert hits[0][0] == "red"
|
||||
|
||||
|
||||
def test_diskann_aux_update_and_delete(db, snapshot):
    """Update one aux value and delete one row, then snapshot the table and aux shadow."""
    create_sql = (
        "\n"
        "CREATE VIRTUAL TABLE t USING vec0(\n"
        "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),\n"
        "+label text\n"
        ")\n"
    )
    db.execute(create_sql)

    insert_sql = "INSERT INTO t(rowid, emb, label) VALUES (?, ?, ?)"
    for rowid in range(1, 6):
        # Row N gets the unit vector with a 1.0 in slot (N - 1) mod 8.
        hot = (rowid - 1) % 8
        vector = [1.0 if slot == hot else 0.0 for slot in range(8)]
        db.execute(insert_sql, [rowid, _f32(vector), f"item-{rowid}"])

    db.execute("UPDATE t SET label = 'UPDATED' WHERE rowid = 2")
    db.execute("DELETE FROM t WHERE rowid = 3")

    visible = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert visible == snapshot(name="diskann aux after update+delete")
    aux_rows = exec(db, "SELECT rowid, value00 FROM t_auxiliary ORDER BY rowid")
    assert aux_rows == snapshot(name="diskann aux shadow after update+delete")
|
||||
|
||||
|
||||
def test_diskann_aux_drop_cleans_all(db):
    """After DROP TABLE, the aux shadow table must no longer be in sqlite_master."""
    create_sql = (
        "\n"
        "CREATE VIRTUAL TABLE t USING vec0(\n"
        "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),\n"
        "+label text\n"
        ")\n"
    )
    db.execute(create_sql)
    db.execute("INSERT INTO t(emb, label) VALUES (?, 'test')", [_f32([1] * 8)])
    db.execute("DROP TABLE t")

    leftovers = db.execute(
        "SELECT name FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchall()
    remaining_names = [row[0] for row in leftovers]
    assert "t_auxiliary" not in remaining_names
|
||||
|
||||
|
|
|
|||
1331
tests/test-diskann.py
Normal file
1331
tests/test-diskann.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -27,3 +27,15 @@ def test_info(db, snapshot):
|
|||
assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot()
|
||||
|
||||
|
||||
def test_command_column_name_conflict(db):
    """A vec0 table whose name equals one of its column names must be rejected."""
    # Table "embeddings" with a vector column also called "embeddings" clashes
    # with the hidden command column.
    clashing_ddl = "create virtual table embeddings using vec0(embeddings float[4])"
    with pytest.raises(sqlite3.OperationalError, match="conflicts with table name"):
        db.execute(clashing_ddl)

    # The same column under a different table name is accepted.
    db.execute("create virtual table t using vec0(embeddings float[4])")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -483,3 +483,171 @@ def test_delete_one_chunk_of_two_shrinks_pages(tmp_path):
|
|||
row = db.execute("select emb from v where rowid = ?", [i]).fetchone()
|
||||
assert row[0] == _f32([float(i)] * dims)
|
||||
db.close()
|
||||
|
||||
|
||||
def test_wal_concurrent_reader_during_write(tmp_path):
    """In WAL mode, a reader should see a consistent snapshot while a writer inserts.

    A writer connection creates the table and commits 10 rows. A second
    connection then reads while the writer holds an open transaction with 10
    more rows: the reader must keep seeing only the first 10 (for count and
    for KNN) until the writer commits, and all 20 afterwards.

    Both connections are closed in a ``finally`` block so a failing assertion
    does not leave handles open on the temporary database file.
    """
    dims = 4
    db_path = str(tmp_path / "test.db")
    insert_sql = "INSERT INTO v(rowid, emb) VALUES (?, ?)"

    writer = sqlite3.connect(db_path)
    reader = None
    try:
        # Writer: create table, insert initial rows, enable WAL
        writer.enable_load_extension(True)
        writer.load_extension("dist/vec0")
        # The pragma reports the journal mode actually in effect; the rest of
        # the test is meaningless unless the switch to WAL succeeded.
        journal_mode = writer.execute("PRAGMA journal_mode=WAL").fetchone()[0]
        assert journal_mode.lower() == "wal"
        writer.execute(
            f"CREATE VIRTUAL TABLE v USING vec0(emb float[{dims}])"
        )
        for i in range(1, 11):
            writer.execute(insert_sql, [i, _f32([float(i)] * dims)])
        writer.commit()

        # Reader: open separate connection, start read
        reader = sqlite3.connect(db_path)
        reader.enable_load_extension(True)
        reader.load_extension("dist/vec0")

        # Reader sees 10 rows
        count_before = reader.execute("SELECT count(*) FROM v").fetchone()[0]
        assert count_before == 10

        # Writer inserts more rows (not yet committed)
        writer.execute("BEGIN")
        for i in range(11, 21):
            writer.execute(insert_sql, [i, _f32([float(i)] * dims)])

        # Reader still sees 10 (WAL snapshot isolation)
        count_during = reader.execute("SELECT count(*) FROM v").fetchone()[0]
        assert count_during == 10

        # KNN during writer's transaction should work on reader's snapshot
        rows = reader.execute(
            "SELECT rowid FROM v WHERE emb MATCH ? AND k = 5",
            [_f32([1.0] * dims)],
        ).fetchall()
        assert len(rows) == 5
        assert all(r[0] <= 10 for r in rows)  # only original rows

        # Writer commits
        writer.commit()

        # Reader sees new rows after re-query (new snapshot)
        count_after = reader.execute("SELECT count(*) FROM v").fetchone()[0]
        assert count_after == 20
    finally:
        # Close the reader first, mirroring the order it was opened after the writer.
        if reader is not None:
            reader.close()
        writer.close()
|
||||
|
||||
|
||||
def test_insert_or_replace_integer_pk(db):
    """INSERT OR REPLACE on an existing rowid swaps the vector and keeps one row."""
    db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")

    original = _f32([1.0, 2.0, 3.0, 4.0])
    db.execute("insert into v(rowid, emb) values (1, ?)", [original])

    # Overwrite rowid 1 with a different vector.
    db.execute(
        "insert or replace into v(rowid, emb) values (1, ?)",
        [_f32([10.0, 20.0, 30.0, 40.0])],
    )

    # Still exactly one row...
    n_rows = db.execute("select count(*) from v").fetchone()[0]
    assert n_rows == 1

    # ...and it holds the replacement vector.
    stored = db.execute("select emb from v where rowid = 1").fetchone()
    assert stored[0] == _f32([10.0, 20.0, 30.0, 40.0])
|
||||
|
||||
|
||||
def test_insert_or_replace_new_row(db):
    """INSERT OR REPLACE with an unused rowid behaves like a plain insert."""
    db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")

    params = [_f32([1.0, 2.0, 3.0, 4.0])]
    db.execute("insert or replace into v(rowid, emb) values (1, ?)", params)

    n_rows = db.execute("select count(*) from v").fetchone()[0]
    assert n_rows == 1

    stored = db.execute("select emb from v where rowid = 1").fetchone()
    assert stored[0] == _f32([1.0, 2.0, 3.0, 4.0])
|
||||
|
||||
|
||||
def test_insert_or_replace_text_pk(db):
    """INSERT OR REPLACE keyed on a text primary key replaces the existing row."""
    create_sql = (
        "create virtual table v using vec0("
        "id text primary key, emb float[4], chunk_size=8"
        ")"
    )
    db.execute(create_sql)

    first = [_f32([1.0, 2.0, 3.0, 4.0])]
    db.execute("insert into v(id, emb) values ('doc_a', ?)", first)
    second = [_f32([10.0, 20.0, 30.0, 40.0])]
    db.execute("insert or replace into v(id, emb) values ('doc_a', ?)", second)

    n_rows = db.execute("select count(*) from v").fetchone()[0]
    assert n_rows == 1

    stored = db.execute("select emb from v where id = 'doc_a'").fetchone()
    assert stored[0] == _f32([10.0, 20.0, 30.0, 40.0])
|
||||
|
||||
|
||||
def test_insert_or_replace_with_auxiliary(db):
    """INSERT OR REPLACE overwrites both the vector and the auxiliary value."""
    create_sql = (
        "create virtual table v using vec0("
        "emb float[4], +label text, chunk_size=8"
        ")"
    )
    db.execute(create_sql)

    old_params = [_f32([1.0, 2.0, 3.0, 4.0])]
    db.execute("insert into v(rowid, emb, label) values (1, ?, 'old')", old_params)
    new_params = [_f32([10.0, 20.0, 30.0, 40.0])]
    db.execute(
        "insert or replace into v(rowid, emb, label) values (1, ?, 'new')",
        new_params,
    )

    n_rows = db.execute("select count(*) from v").fetchone()[0]
    assert n_rows == 1

    stored = db.execute("select emb, label from v where rowid = 1").fetchone()
    assert stored[0] == _f32([10.0, 20.0, 30.0, 40.0])
    assert stored[1] == "new"
|
||||
|
||||
|
||||
def test_insert_or_replace_knn_uses_new_vector(db):
    """KNN after INSERT OR REPLACE ranks rows by the replacement vector."""
    db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")

    plain_insert = "insert into v(rowid, emb) values (%d, ?)"
    db.execute(plain_insert % 1, [_f32([1.0, 0.0, 0.0, 0.0])])
    db.execute(plain_insert % 2, [_f32([0.0, 1.0, 0.0, 0.0])])

    # Move row 1 right next to row 2.
    db.execute(
        "insert or replace into v(rowid, emb) values (1, ?)",
        [_f32([0.0, 0.9, 0.0, 0.0])],
    )

    # Query at row 2's position: row 2 is the exact match, row 1 the runner-up.
    hits = db.execute(
        "select rowid, distance from v where emb match ? and k = 2",
        [_f32([0.0, 1.0, 0.0, 0.0])],
    ).fetchall()
    best = hits[0]
    runner_up = hits[1]
    assert best[0] == 2
    assert runner_up[0] == 1
    assert runner_up[1] < 0.11  # should be close (L2 distance ≈ 0.1)
|
||||
|
|
|
|||
589
tests/test-ivf-mutations.py
Normal file
589
tests/test-ivf-mutations.py
Normal file
|
|
@ -0,0 +1,589 @@
|
|||
"""
|
||||
Thorough IVF mutation tests: insert, delete, update, KNN correctness,
|
||||
error cases, edge cases, and cell overflow scenarios.
|
||||
"""
|
||||
import pytest
|
||||
import sqlite3
|
||||
import struct
|
||||
import math
|
||||
from helpers import _f32, exec
|
||||
|
||||
|
||||
@pytest.fixture()
def db():
    """Yield an in-memory connection with the vec0 extension loaded.

    Rows use ``sqlite3.Row``. Extension loading is switched off again once
    ``dist/vec0`` is loaded. The connection is closed at teardown so each
    test releases its handle instead of leaving it to garbage collection.
    """
    conn = sqlite3.connect(":memory:")
    conn.row_factory = sqlite3.Row
    conn.enable_load_extension(True)
    conn.load_extension("dist/vec0")
    conn.enable_load_extension(False)
    yield conn
    conn.close()
|
||||
|
||||
|
||||
def ivf_total_vectors(db, table="t", col=0):
    """Sum ``n_vectors`` over every row of ``{table}_ivf_cells{col:02d}`` (0 if empty)."""
    cells_table = "{}_ivf_cells{:02d}".format(table, col)
    cursor = db.execute("SELECT COALESCE(SUM(n_vectors), 0) FROM " + cells_table)
    row = cursor.fetchone()
    return row[0]
|
||||
|
||||
|
||||
def ivf_unassigned_count(db, table="t", col=0):
    """Sum ``n_vectors`` over IVF cell rows whose ``centroid_id`` is -1 (0 if none)."""
    cells_table = "{}_ivf_cells{:02d}".format(table, col)
    sql = "SELECT COALESCE(SUM(n_vectors), 0) FROM " + cells_table + " WHERE centroid_id = -1"
    row = db.execute(sql).fetchone()
    return row[0]
|
||||
|
||||
|
||||
def ivf_assigned_count(db, table="t", col=0):
    """Sum ``n_vectors`` over IVF cell rows whose ``centroid_id`` is >= 0 (0 if none)."""
    cells_table = "{}_ivf_cells{:02d}".format(table, col)
    sql = "SELECT COALESCE(SUM(n_vectors), 0) FROM " + cells_table + " WHERE centroid_id >= 0"
    row = db.execute(sql).fetchone()
    return row[0]
|
||||
|
||||
|
||||
def knn(db, query, k, table="t", col="v"):
    """Run a ``MATCH ? AND k = ?`` query and return ``(rowid, distance)`` tuples.

    ``query`` is packed with ``_f32`` before it is bound.
    """
    sql = "SELECT rowid, distance FROM {} WHERE {} MATCH ? AND k = ?".format(table, col)
    cursor = db.execute(sql, [_f32(query), k])
    return [(row[0], row[1]) for row in cursor.fetchall()]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Single row insert + KNN
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_insert_single_row_knn(db):
    """One inserted row is the only KNN hit, at near-zero distance from itself."""
    unit_x = [1, 0, 0, 0]
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
    db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32(unit_x)])

    hits = knn(db, [1, 0, 0, 0], 5)
    assert len(hits) == 1
    only_hit = hits[0]
    assert only_hit[0] == 1
    assert only_hit[1] < 0.001
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Batch insert + KNN recall
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_batch_insert_knn_recall(db):
    """200 vectors, centroids computed, nprobe == nlist: KNN near 100 returns rowids 95..105."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(200):
        db.execute(insert_sql, [rowid, _f32([float(rowid), 0, 0, 0])])
    assert ivf_total_vectors(db) == 200

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    assert ivf_assigned_count(db) == 200

    # The vector for rowid 100 is exactly the query, so it must rank first.
    hits = knn(db, [100.0, 0, 0, 0], 10)
    assert len(hits) == 10
    nearest = hits[0]
    assert nearest[0] == 100
    assert nearest[1] < 0.01

    # Every returned rowid lies within 5 of the query position.
    returned_ids = {hit[0] for hit in hits}
    assert all(95 <= rid <= 105 for rid in returned_ids)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Delete rows, verify they're gone from KNN
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_delete_rows_gone_from_knn(db):
    """A deleted rowid must not appear in later KNN results."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(20):
        db.execute(insert_sql, [rowid, _f32([float(rowid), 0, 0, 0])])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # Remove rowid 10, then query right at its old position.
    db.execute("DELETE FROM t WHERE rowid = 10")

    hits = knn(db, [10.0, 0, 0, 0], 20)
    returned_ids = {hit[0] for hit in hits}
    assert 10 not in returned_ids
|
||||
|
||||
|
||||
def test_delete_all_rows_empty_results(db):
    """Deleting every row empties the IVF cells and leaves KNN with no results."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    all_rowids = range(10)
    for rowid in all_rowids:
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [rowid, _f32([float(rowid), 0, 0, 0])],
        )

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    for rowid in all_rowids:
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])

    assert ivf_total_vectors(db) == 0
    hits = knn(db, [5.0, 0, 0, 0], 10)
    assert len(hits) == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert after delete (reuse rowids)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_insert_after_delete_reuse_rowid(db):
    """A rowid deleted and re-inserted with a new vector is found at the new position."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    for rowid in range(10):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [rowid, _f32([float(rowid), 0, 0, 0])],
        )

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # Drop rowid 5, then bring it back far away from every other vector.
    db.execute("DELETE FROM t WHERE rowid = 5")
    far_away = _f32([999.0, 0, 0, 0])
    db.execute("INSERT INTO t(rowid, v) VALUES (5, ?)", [far_away])

    # The nearest neighbour of the far-away point is the re-inserted row.
    hits = knn(db, [999.0, 0, 0, 0], 1)
    assert len(hits) >= 1
    assert hits[0][0] == 5
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Update vectors (DELETE followed by re-INSERT), verify KNN reflects new values
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_update_vector_via_delete_insert(db):
    """Emulate an IVF vector update as DELETE + INSERT and check KNN sees the new vector."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    for rowid in range(10):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [rowid, _f32([float(rowid), 0, 0, 0])],
        )

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # "Update" rowid 3 by removing it and inserting a replacement vector.
    db.execute("DELETE FROM t WHERE rowid = 3")
    replacement = _f32([100.0, 0, 0, 0])
    db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [replacement])

    # Querying at the replacement position must return rowid 3 first.
    hits = knn(db, [100.0, 0, 0, 0], 1)
    assert hits[0][0] == 3
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# IVF + auxiliary/metadata/partition key columns (auxiliary is supported; metadata and partition keys are errors), plus flat-index regression guards
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_with_auxiliary_column(db):
    """Creating an IVF table with an aux column succeeds and creates ``t_auxiliary``."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), +extra text)"
    )
    found = db.execute(
        "SELECT name FROM sqlite_master WHERE name LIKE 't_%' ORDER BY 1"
    ).fetchall()
    shadow_names = [row[0] for row in found]
    assert "t_auxiliary" in shadow_names
|
||||
|
||||
|
||||
def test_error_ivf_with_metadata_column(db):
    """Declaring a metadata column next to an IVF index must fail with a message mentioning 'metadata'."""
    ddl = "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), genre text)"
    outcome = exec(db, ddl)
    assert "error" in outcome
    message = outcome.get("message", "")
    assert "metadata" in message.lower()
|
||||
|
||||
|
||||
def test_error_ivf_with_partition_key(db):
    """Declaring a partition key next to an IVF index must fail with a message mentioning 'partition'."""
    ddl = "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), user_id integer partition key)"
    outcome = exec(db, ddl)
    assert "error" in outcome
    message = outcome.get("message", "")
    assert "partition" in message.lower()
|
||||
|
||||
|
||||
def test_flat_with_auxiliary_still_works(db):
    """Regression guard: a table without an IVF index still stores and returns an auxiliary column."""
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4], +extra text)")

    vector = _f32([1, 0, 0, 0])
    db.execute("INSERT INTO t(rowid, v, extra) VALUES (1, ?, 'hello')", [vector])

    fetched = db.execute("SELECT extra FROM t WHERE rowid = 1").fetchone()
    assert fetched[0] == "hello"
|
||||
|
||||
|
||||
def test_flat_with_metadata_still_works(db):
    """Regression guard: a table without an IVF index still stores and returns a metadata column."""
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4], genre text)")

    vector = _f32([1, 0, 0, 0])
    db.execute("INSERT INTO t(rowid, v, genre) VALUES (1, ?, 'rock')", [vector])

    fetched = db.execute("SELECT genre FROM t WHERE rowid = 1").fetchone()
    assert fetched[0] == "rock"
|
||||
|
||||
|
||||
def test_flat_with_partition_key_still_works(db):
    """Regression guard: a table without an IVF index still stores and returns a partition key column."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4], user_id integer partition key)"
    )

    vector = _f32([1, 0, 0, 0])
    db.execute("INSERT INTO t(rowid, v, user_id) VALUES (1, ?, 42)", [vector])

    fetched = db.execute("SELECT user_id FROM t WHERE rowid = 1").fetchone()
    assert fetched[0] == 42
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge cases
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_zero_vectors(db):
    """Five all-zero vectors must all be returned for an all-zero query, each at distance ~0."""
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")

    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(5):
        db.execute(insert_sql, [rowid, _f32([0, 0, 0, 0])])

    hits = knn(db, [0, 0, 0, 0], 5)
    assert len(hits) == 5
    # Every stored vector equals the query, so each distance should be (near) zero.
    assert all(dist < 0.001 for _, dist in hits)
|
||||
|
||||
|
||||
def test_large_values(db):
    """Store vectors of magnitude 1e6, 1e-6 and -1e6; querying at 1e6 must rank rowid 1 first."""
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")

    # Each statement carries its rowid literally; only the vector is bound.
    fixtures = (
        ("INSERT INTO t(rowid, v) VALUES (1, ?)", [1e6, 1e6, 1e6, 1e6]),
        ("INSERT INTO t(rowid, v) VALUES (2, ?)", [1e-6, 1e-6, 1e-6, 1e-6]),
        ("INSERT INTO t(rowid, v) VALUES (3, ?)", [-1e6, -1e6, -1e6, -1e6]),
    )
    for statement, components in fixtures:
        db.execute(statement, [_f32(components)])

    ranked = knn(db, [1e6, 1e6, 1e6, 1e6], 3)
    assert ranked[0][0] == 1
|
||||
|
||||
|
||||
def test_single_row_compute_centroids(db):
    """With nlist=1 and a single row, 'compute-centroids' must assign that row and KNN must find it."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=1))"
    )
    only_vector = _f32([1, 2, 3, 4])
    db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [only_vector])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    assert ivf_assigned_count(db) == 1

    hits = knn(db, [1, 2, 3, 4], 1)
    assert len(hits) == 1
    assert hits[0][0] == 1
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cell overflow (many vectors in one cell)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_cell_overflow_many_vectors(db):
    """Assign 100 near-identical vectors to one manually set centroid.

    Expects all 100 to be counted as assigned, at least two cell rows for
    centroid 0, and all 100 to come back from a k=100 query.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
    )

    total = 100
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(total):
        # Vectors differ only by a tiny offset on the first component.
        db.execute(insert_sql, [rowid, _f32([1.0 + rowid * 0.001, 0, 0, 0])])

    # Define centroid 0 by hand, then assign the stored vectors.
    db.execute(
        "INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)",
        [_f32([1.0, 0, 0, 0])],
    )
    db.execute("INSERT INTO t(t) VALUES ('assign-vectors')")

    assert ivf_assigned_count(db) == 100

    # 100 vectors are expected to need at least two cell rows (original note: 64 max per cell).
    cells_for_centroid = db.execute(
        "SELECT count(*) FROM t_ivf_cells00 WHERE centroid_id = 0"
    ).fetchone()
    assert cells_for_centroid[0] >= 2

    hits = knn(db, [1.0, 0, 0, 0], 100)
    assert len(hits) == 100
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Large batch with training
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_large_batch_with_training(db):
    """Insert 500 rows, run 'compute-centroids', insert 500 more; expect 1000 stored and rowid 750 nearest to x=750."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=16, nprobe=16))"
    )

    def _insert_range(start, stop):
        # Rowid i gets the vector (i, 0, 0, 0).
        for rowid in range(start, stop):
            db.execute(
                "INSERT INTO t(rowid, v) VALUES (?, ?)",
                [rowid, _f32([float(rowid), 0, 0, 0])],
            )

    _insert_range(0, 500)
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    _insert_range(500, 1000)

    assert ivf_total_vectors(db) == 1000

    hits = knn(db, [750.0, 0, 0, 0], 5)
    assert len(hits) == 5
    assert hits[0][0] == 750
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# KNN after interleaved insert/delete
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_knn_after_interleaved_insert_delete(db):
    """Insert 20 rows, run 'compute-centroids', delete rowids 0..9, then query near x=5.

    The ten results must all be surviving rowids (>= 10).
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(20):
        db.execute(insert_sql, [rowid, _f32([float(rowid), 0, 0, 0])])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # Remove the ten rows lying closest to the query point.
    for rowid in range(10):
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])

    hits = knn(db, [5.0, 0, 0, 0], 10)
    returned = {hit[0] for hit in hits}
    assert all(rowid >= 10 for rowid in returned)
    assert len(hits) == 10
|
||||
|
||||
|
||||
def test_knn_empty_centroids_after_deletes(db):
    """KNN must keep working after most rows are deleted, even if some centroids end up empty.

    40 rows are placed at ten distinct x positions (0, 10, ..., 90), four
    rows per position. After 'compute-centroids', rowids 0..29 are deleted,
    leaving ten rows, so a k=20 query can return at most 10 results.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=2))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(40):
        db.execute(insert_sql, [rowid, _f32([float(rowid % 10) * 10, 0, 0, 0])])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    for rowid in range(30):
        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])

    # The query itself must succeed; only an upper bound on the row count is checked.
    hits = knn(db, [50.0, 0, 0, 0], 20)
    assert len(hits) <= 10
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# KNN returns correct distances
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_knn_correct_distances(db):
    """Distances reported for a query at the origin must be 0, 3 and 4 for the three stored vectors."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    seed_rows = (
        ("INSERT INTO t(rowid, v) VALUES (1, ?)", [0, 0, 0, 0]),
        ("INSERT INTO t(rowid, v) VALUES (2, ?)", [3, 0, 0, 0]),
        ("INSERT INTO t(rowid, v) VALUES (3, ?)", [0, 4, 0, 0]),
    )
    for statement, components in seed_rows:
        db.execute(statement, [_f32(components)])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    hits = knn(db, [0, 0, 0, 0], 3)
    distance_by_rowid = {hit[0]: hit[1] for hit in hits}

    # Expected values are Euclidean (L2) distances from the origin.
    assert abs(distance_by_rowid[1] - 0.0) < 0.01
    assert abs(distance_by_rowid[2] - 3.0) < 0.01  # |(3, 0, 0, 0)| = 3
    assert abs(distance_by_rowid[3] - 4.0) < 0.01  # |(0, 4, 0, 0)| = 4
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Delete in flat mode leaves no orphan rowid_map entries
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_delete_flat_mode_rowid_map_count(db):
    """Without running 'compute-centroids', deleting 3 of 5 rows must leave 2 rowid_map entries and 2 unassigned vectors."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(5):
        db.execute(insert_sql, [rowid, _f32([float(rowid), 0, 0, 0])])

    for statement in (
        "DELETE FROM t WHERE rowid = 0",
        "DELETE FROM t WHERE rowid = 2",
        "DELETE FROM t WHERE rowid = 4",
    ):
        db.execute(statement)

    map_rows = db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()
    assert map_rows[0] == 2
    assert ivf_unassigned_count(db) == 2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Duplicate rowid insert
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_delete_reinsert_as_update(db):
    """Delete rowid 1 and insert it again with a new vector; KNN must return it at distance ~0 from the new vector."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    insert_one = "INSERT INTO t(rowid, v) VALUES (1, ?)"
    db.execute(insert_one, [_f32([1, 0, 0, 0])])

    # The "update": remove the row, then store a different vector under the same rowid.
    db.execute("DELETE FROM t WHERE rowid = 1")
    db.execute(insert_one, [_f32([99, 0, 0, 0])])

    hits = knn(db, [99, 0, 0, 0], 1)
    assert len(hits) == 1
    top_rowid, top_distance = hits[0][0], hits[0][1]
    assert top_rowid == 1
    assert top_distance < 0.01
|
||||
|
||||
|
||||
def test_duplicate_rowid_insert_fails(db):
    """A second insert using an existing rowid must be reported as an error."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    insert_one = "INSERT INTO t(rowid, v) VALUES (1, ?)"
    db.execute(insert_one, [_f32([1, 0, 0, 0])])

    # Same rowid again, issued through the exec helper so the failure is captured.
    outcome = exec(db, insert_one, [_f32([99, 0, 0, 0])])
    assert "error" in outcome
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Interleaved insert/delete with KNN correctness
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_interleaved_ops_correctness(db):
    """Run a mixed sequence of inserts and deletes and check KNN sees exactly the surviving rows.

    Phases: insert rowids 0..49 and run 'compute-centroids'; delete the even
    rowids; insert rowids 50..74; delete rowids 60..69. A k=50 query must
    then return no deleted rowid and every one of the 40 survivors.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
    )

    # Phase 1: Insert 50 vectors
    for i in range(50):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [i, _f32([float(i), 0, 0, 0])],
        )

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # Phase 2: Delete even-numbered rowids
    for i in range(0, 50, 2):
        db.execute("DELETE FROM t WHERE rowid = ?", [i])

    # Phase 3: Insert new vectors with higher rowids
    for i in range(50, 75):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [i, _f32([float(i), 0, 0, 0])],
        )

    # Phase 4: Delete some of the new ones
    for i in range(60, 70):
        db.execute("DELETE FROM t WHERE rowid = ?", [i])

    # KNN query: should only find existing vectors
    results = knn(db, [25.0, 0, 0, 0], 50)
    rowids = {r[0] for r in results}

    # Verify no deleted rowids appear
    deleted = set(range(0, 50, 2)) | set(range(60, 70))
    assert len(rowids & deleted) == 0

    # Survivors: 25 odd rowids below 50, plus 25 inserted (50..74),
    # minus the 10 deleted again (60..69) = 40 rows.
    expected_alive = set(range(1, 50, 2)) | set(range(50, 60)) | set(range(70, 75))
    assert rowids.issubset(expected_alive)

    # k=50 exceeds the 40 survivors and nprobe equals nlist, so the query is
    # expected to reach every row (NOTE(review): assumes nprobe == nlist
    # probes all cells — confirm against the IVF implementation).
    assert len(results) == len(expected_alive)
    assert rowids == expected_alive
|
||||
|
||||
|
||||
def test_ivf_update_vector_blocked(db):
    """An UPDATE of an IVF-indexed vector column must raise OperationalError with the expected message."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(emb float[4] indexed by ivf(nlist=2))"
    )
    seed_rows = (
        ("INSERT INTO t(rowid, emb) VALUES (1, ?)", [1, 0, 0, 0]),
        ("INSERT INTO t(rowid, emb) VALUES (2, ?)", [0, 1, 0, 0]),
    )
    for statement, components in seed_rows:
        db.execute(statement, [_f32(components)])

    expected_message = "UPDATE on vector column.*not supported for IVF"
    with pytest.raises(sqlite3.OperationalError, match=expected_message):
        db.execute("UPDATE t SET emb = ? WHERE rowid = 1", [_f32([0, 0, 1, 0])])
|
||||
272
tests/test-ivf-quantization.py
Normal file
272
tests/test-ivf-quantization.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
import pytest
|
||||
import sqlite3
|
||||
from helpers import _f32, exec
|
||||
|
||||
|
||||
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows are returned as sqlite3.Row. Extension loading is enabled only for
    the load_extension call and disabled again afterwards. The connection
    is closed when the requesting test finishes (or if setup fails), so
    connections are not left open across the test session.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        conn.close()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Parser tests — quantizer and oversample options
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_quantizer_binary(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
"v float[768] indexed by ivf(nlist=64, quantizer=binary, oversample=10))"
|
||||
)
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||
).fetchall()
|
||||
]
|
||||
assert "t_ivf_centroids00" in tables
|
||||
assert "t_ivf_cells00" in tables
|
||||
|
||||
|
||||
def test_ivf_quantizer_int8(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
"v float[768] indexed by ivf(nlist=64, quantizer=int8))"
|
||||
)
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||
).fetchall()
|
||||
]
|
||||
assert "t_ivf_centroids00" in tables
|
||||
|
||||
|
||||
def test_ivf_quantizer_none_explicit(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
"v float[768] indexed by ivf(quantizer=none))"
|
||||
)
|
||||
# Should work — same as no quantizer
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||
).fetchall()
|
||||
]
|
||||
assert "t_ivf_centroids00" in tables
|
||||
|
||||
|
||||
def test_ivf_quantizer_all_params(db):
|
||||
"""All params together: nlist, nprobe, quantizer, oversample."""
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
"v float[768] distance_metric=cosine "
|
||||
"indexed by ivf(nlist=128, nprobe=16, quantizer=int8, oversample=4))"
|
||||
)
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||
).fetchall()
|
||||
]
|
||||
assert "t_ivf_centroids00" in tables
|
||||
|
||||
|
||||
def test_ivf_error_oversample_without_quantizer(db):
    """Setting oversample=10 while no quantizer is given must be reported as an error."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        "v float[768] indexed by ivf(oversample=10))"
    )
    outcome = exec(db, ddl)
    assert "error" in outcome
|
||||
|
||||
|
||||
def test_ivf_error_unknown_quantizer(db):
    """An unrecognised quantizer name (here 'pq') must be reported as an error."""
    ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        "v float[768] indexed by ivf(quantizer=pq))"
    )
    outcome = exec(db, ddl)
    assert "error" in outcome
|
||||
|
||||
|
||||
def test_ivf_oversample_1_without_quantizer_ok(db):
|
||||
"""oversample=1 (default) is fine without quantizer."""
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0("
|
||||
"v float[768] indexed by ivf(nlist=64))"
|
||||
)
|
||||
# Should succeed — oversample defaults to 1
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Functional tests — insert, train, query with quantized IVF
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_int8_insert_and_query(db):
    """int8-quantized IVF: insert 20 rows, run 'compute-centroids', query near x=10.

    Expects five results whose top rowid lies in 8..12, and 20 rows in the
    t_ivf_vectors00 table.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        "v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=4))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(20):
        db.execute(insert_sql, [rowid, _f32([rowid, 0, 0, 0])])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    matches = db.execute(
        "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
        [_f32([10.0, 0, 0, 0])],
    ).fetchall()
    assert len(matches) == 5
    # A small tolerance window around rowid 10 is accepted for the best match.
    best_rowid = matches[0][0]
    assert best_rowid in range(8, 13)

    # One row per inserted vector is expected in the vectors table.
    stored = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0]
    assert stored == 20
|
||||
|
||||
|
||||
def test_ivf_binary_insert_and_query(db):
    """Binary-quantized IVF: insert 20 rows, run 'compute-centroids', then query with k=5.

    Expects five results and 20 rows in the t_ivf_vectors00 table.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        "v float[8] indexed by ivf(nlist=2, quantizer=binary, oversample=4))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(20):
        # Component j is (rowid * 0.1 - 1.0) + j * 0.3, so sign patterns vary across rows.
        components = [(rowid * 0.1 - 1.0) + j * 0.3 for j in range(8)]
        db.execute(insert_sql, [rowid, _f32(components)])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    matches = db.execute(
        "SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
        [_f32([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])],
    ).fetchall()
    assert len(matches) == 5

    stored = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0]
    assert stored == 20
|
||||
|
||||
|
||||
def test_ivf_int8_cell_sizes_smaller(db):
    """With int8 quantization a cell's vectors blob must be 64 * 768 bytes long."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        "v float[768] indexed by ivf(nlist=2, quantizer=int8, oversample=1))"
    )
    components = [float(x) / 768 for x in range(768)]
    for rowid in range(10):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [rowid, _f32(components)],
        )

    # Expected size follows the original sizing note: 64 slots per cell at
    # one byte per dimension (versus 64 * 768 * 4 bytes for float32).
    measured = db.execute(
        "SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1"
    ).fetchone()[0]
    assert measured == 64 * 768
|
||||
|
||||
|
||||
def test_ivf_binary_cell_sizes_smaller(db):
    """With binary quantization a cell's vectors blob must be 64 * (768 // 8) bytes long."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        "v float[768] indexed by ivf(nlist=2, quantizer=binary, oversample=1))"
    )
    components = [float(x) / 768 for x in range(768)]
    for rowid in range(10):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [rowid, _f32(components)],
        )

    # Expected size follows the original sizing note: 64 slots at one bit per dimension.
    measured = db.execute(
        "SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1"
    ).fetchone()[0]
    assert measured == 64 * (768 // 8)
|
||||
|
||||
|
||||
def test_ivf_int8_oversample_improves_recall(db):
    """Query two identical int8 IVF tables that differ only in oversample (1 vs 10).

    Only the result counts are asserted: each table must return 10 rows
    for k=10. The recall of the two result sets is not compared here.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t1 USING vec0("
        "v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=1))"
    )
    db.execute(
        "CREATE VIRTUAL TABLE t2 USING vec0("
        "v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=10))"
    )

    # Both tables receive the same 100 vectors, row by row.
    for rowid in range(100):
        blob = _f32([rowid * 0.1, (rowid * 0.1) % 3, (rowid * 0.3) % 5, rowid * 0.01])
        db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [rowid, blob])
        db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [rowid, blob])

    db.execute("INSERT INTO t1(t1) VALUES ('compute-centroids')")
    db.execute("INSERT INTO t2(t2) VALUES ('compute-centroids')")
    db.execute("INSERT INTO t1(t1) VALUES ('nprobe=4')")
    db.execute("INSERT INTO t2(t2) VALUES ('nprobe=4')")

    probe = _f32([5.0, 1.5, 2.5, 0.5])
    low_oversample = db.execute(
        "SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [probe]
    ).fetchall()
    high_oversample = db.execute(
        "SELECT rowid FROM t2 WHERE v MATCH ? AND k=10", [probe]
    ).fetchall()

    assert len(low_oversample) == 10
    assert len(high_oversample) == 10
    # Intent per the test name: oversample=10 should order results at least as well.
|
||||
|
||||
|
||||
def test_ivf_quantized_delete(db):
    """Deleting one row from a quantized IVF table must drop the t_ivf_vectors00 row count from 10 to 9."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        "v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=1))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(10):
        db.execute(insert_sql, [rowid, _f32([rowid, 0, 0, 0])])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    count_sql = "SELECT count(*) FROM t_ivf_vectors00"
    assert db.execute(count_sql).fetchone()[0] == 10

    db.execute("DELETE FROM t WHERE rowid = 5")
    assert db.execute(count_sql).fetchone()[0] == 9
|
||||
|
||||
|
||||
def test_ivf_binary_rejects_non_multiple_of_8_dims(db):
    """quantizer=binary must raise OperationalError for 12 dimensions and be accepted for 16."""
    rejected_ddl = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " v float[12] indexed by ivf(quantizer=binary)"
        ")"
    )
    accepted_ddl = (
        "CREATE VIRTUAL TABLE t2 USING vec0("
        " v float[16] indexed by ivf(quantizer=binary)"
        ")"
    )

    with pytest.raises(sqlite3.OperationalError):
        db.execute(rejected_ddl)

    # 16 is divisible by 8, so this statement must not raise.
    db.execute(accepted_ddl)
|
||||
426
tests/test-ivf.py
Normal file
426
tests/test-ivf.py
Normal file
|
|
@ -0,0 +1,426 @@
|
|||
import pytest
|
||||
import sqlite3
|
||||
import struct
|
||||
import math
|
||||
from helpers import _f32, exec
|
||||
|
||||
|
||||
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows are returned as sqlite3.Row. Extension loading is enabled only for
    the load_extension call and disabled again afterwards. The connection
    is closed when the requesting test finishes (or if setup fails), so
    connections are not left open across the test session.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        conn.close()
|
||||
|
||||
|
||||
def ivf_total_vectors(db, table="t", col=0):
    """Return the sum of n_vectors over all rows of ``{table}_ivf_cells{col:02d}``.

    An empty cells table yields 0 (via COALESCE).
    """
    cells_table = f"{table}_ivf_cells{col:02d}"
    query = f"SELECT COALESCE(SUM(n_vectors), 0) FROM {cells_table}"
    first_row = db.execute(query).fetchone()
    return first_row[0]
|
||||
|
||||
|
||||
def ivf_unassigned_count(db, table="t", col=0):
    """Return the sum of n_vectors over cells whose centroid_id is -1 in ``{table}_ivf_cells{col:02d}``.

    Yields 0 when no such cell exists (via COALESCE).
    """
    cells_table = f"{table}_ivf_cells{col:02d}"
    query = f"SELECT COALESCE(SUM(n_vectors), 0) FROM {cells_table} WHERE centroid_id = -1"
    first_row = db.execute(query).fetchone()
    return first_row[0]
|
||||
|
||||
|
||||
def ivf_assigned_count(db, table="t", col=0):
    """Return the sum of n_vectors over cells whose centroid_id is >= 0 in ``{table}_ivf_cells{col:02d}``.

    Yields 0 when no such cell exists (via COALESCE).
    """
    cells_table = f"{table}_ivf_cells{col:02d}"
    query = f"SELECT COALESCE(SUM(n_vectors), 0) FROM {cells_table} WHERE centroid_id >= 0"
    first_row = db.execute(query).fetchone()
    return first_row[0]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Parser tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_create_defaults(db):
|
||||
"""ivf() with no args uses defaults."""
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())"
|
||||
)
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||
).fetchall()
|
||||
]
|
||||
assert "t_ivf_centroids00" in tables
|
||||
assert "t_ivf_cells00" in tables
|
||||
assert "t_ivf_rowid_map00" in tables
|
||||
|
||||
|
||||
def test_ivf_create_custom_params(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=64, nprobe=8))"
|
||||
)
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||
).fetchall()
|
||||
]
|
||||
assert "t_ivf_centroids00" in tables
|
||||
assert "t_ivf_cells00" in tables
|
||||
|
||||
|
||||
def test_ivf_create_with_distance_metric(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0(v float[4] distance_metric=cosine indexed by ivf(nlist=16))"
|
||||
)
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||
).fetchall()
|
||||
]
|
||||
assert "t_ivf_centroids00" in tables
|
||||
|
||||
|
||||
def test_ivf_create_error_unknown_key(db):
    """An unknown option name inside ivf(...) must be reported as an error."""
    ddl = "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(bogus=1))"
    outcome = exec(db, ddl)
    assert "error" in outcome
|
||||
|
||||
|
||||
def test_ivf_create_error_nprobe_gt_nlist(db):
    """Requesting nprobe=10 with nlist=4 must be reported as an error."""
    ddl = "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=10))"
    outcome = exec(db, ddl)
    assert "error" in outcome
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Shadow table tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_shadow_tables_created(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8))"
|
||||
)
|
||||
trained = db.execute(
|
||||
"SELECT value FROM t_info WHERE key = 'ivf_trained_0'"
|
||||
).fetchone()[0]
|
||||
assert str(trained) == "0"
|
||||
|
||||
# No cells yet (created lazily on first insert)
|
||||
count = db.execute(
|
||||
"SELECT count(*) FROM t_ivf_cells00"
|
||||
).fetchone()[0]
|
||||
assert count == 0
|
||||
|
||||
|
||||
def test_ivf_drop_cleans_up(db):
|
||||
db.execute(
|
||||
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||
)
|
||||
db.execute("DROP TABLE t")
|
||||
tables = [
|
||||
r[0]
|
||||
for r in db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
).fetchall()
|
||||
]
|
||||
assert not any("ivf" in t for t in tables)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert tests (flat mode)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_insert_flat_mode(db):
    """Rows inserted before 'compute-centroids' is run must be counted as unassigned.

    Two inserts must give 2 unassigned vectors, 0 assigned vectors and
    2 rowid_map entries.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
    )
    seed_rows = (
        ("INSERT INTO t(rowid, v) VALUES (1, ?)", [1, 2, 3, 4]),
        ("INSERT INTO t(rowid, v) VALUES (2, ?)", [5, 6, 7, 8]),
    )
    for statement, components in seed_rows:
        db.execute(statement, [_f32(components)])

    assert ivf_unassigned_count(db) == 2
    assert ivf_assigned_count(db) == 0

    map_rows = db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()
    assert map_rows[0] == 2
|
||||
|
||||
|
||||
def test_ivf_delete_flat_mode(db):
    """Deleting one of two untrained rows must leave 1 unassigned vector and 1 rowid_map entry."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
    )
    seed_rows = (
        ("INSERT INTO t(rowid, v) VALUES (1, ?)", [1, 2, 3, 4]),
        ("INSERT INTO t(rowid, v) VALUES (2, ?)", [5, 6, 7, 8]),
    )
    for statement, components in seed_rows:
        db.execute(statement, [_f32(components)])
    db.execute("DELETE FROM t WHERE rowid = 1")

    assert ivf_unassigned_count(db) == 1
    map_rows = db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()
    assert map_rows[0] == 1
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# KNN flat mode tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_knn_flat_mode(db):
    """Without 'compute-centroids', a k=2 query at x=1.5 must return rowids 1 and 2 (x=1 and x=2), not 3 (x=9)."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
    )
    seed_rows = (
        ("INSERT INTO t(rowid, v) VALUES (1, ?)", [1, 0, 0, 0]),
        ("INSERT INTO t(rowid, v) VALUES (2, ?)", [2, 0, 0, 0]),
        ("INSERT INTO t(rowid, v) VALUES (3, ?)", [9, 0, 0, 0]),
    )
    for statement, components in seed_rows:
        db.execute(statement, [_f32(components)])

    matches = db.execute(
        "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 2",
        [_f32([1.5, 0, 0, 0])],
    ).fetchall()
    assert len(matches) == 2
    found = {match[0] for match in matches}
    assert found == {1, 2}
|
||||
|
||||
|
||||
def test_ivf_knn_flat_empty(db):
    """A KNN query against an empty IVF table must return no rows."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
    )
    probe = _f32([1, 0, 0, 0])
    matches = db.execute(
        "SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
        [probe],
    ).fetchall()
    assert len(matches) == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# compute-centroids tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_compute_centroids(db):
    """'compute-centroids' must move all 40 vectors out of the unassigned state.

    Afterwards: 0 unassigned, 40 assigned, 4 centroid rows (nlist=4) and
    ivf_trained_0 recorded as 1.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(40):
        # Lays the rows out on a 10 x 4 grid in the first two components.
        db.execute(insert_sql, [rowid, _f32([rowid % 10, rowid // 10, 0, 0])])

    assert ivf_unassigned_count(db) == 40

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    assert ivf_unassigned_count(db) == 0
    assert ivf_assigned_count(db) == 40

    centroid_rows = db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()
    assert centroid_rows[0] == 4

    trained_row = db.execute(
        "SELECT value FROM t_info WHERE key='ivf_trained_0'"
    ).fetchone()
    assert str(trained_row[0]) == "1"
|
||||
|
||||
|
||||
def test_compute_centroids_recompute(db):
    """Running 'compute-centroids' twice must keep 2 centroid rows (nlist=2) and all 20 vectors assigned."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
    )
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"
    for rowid in range(20):
        db.execute(insert_sql, [rowid, _f32([rowid, 0, 0, 0])])

    train_sql = "INSERT INTO t(t) VALUES ('compute-centroids')"
    count_sql = "SELECT count(*) FROM t_ivf_centroids00"

    db.execute(train_sql)
    assert db.execute(count_sql).fetchone()[0] == 2

    # Second run: the centroid count must not grow.
    db.execute(train_sql)
    assert db.execute(count_sql).fetchone()[0] == 2
    assert ivf_assigned_count(db) == 20
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Insert after training (assigned mode)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_insert_after_training(db):
    """A vector inserted after training is mapped to a cell with a non-negative centroid id."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
    )
    for rowid in range(20):
        vector = _f32([rowid, 0, 0, 0])
        db.execute("INSERT INTO t(rowid, v) VALUES (?, ?)", [rowid, vector])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    late_vector = _f32([5, 0, 0, 0])
    db.execute("INSERT INTO t(rowid, v) VALUES (100, ?)", [late_vector])

    # Look up which cell rowid 100 was mapped to, and that cell's centroid id.
    lookup_sql = (
        "SELECT m.cell_id, c.centroid_id FROM t_ivf_rowid_map00 m "
        "JOIN t_ivf_cells00 c ON c.rowid = m.cell_id "
        "WHERE m.rowid = 100"
    )
    mapping = db.execute(lookup_sql).fetchone()
    assert mapping is not None

    centroid_id = mapping[1]
    assert centroid_id >= 0  # centroid_id >= 0 means trained cell
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# KNN after training (IVF probe mode)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_knn_after_training(db):
    """KNN on a trained table (nlist=4, nprobe=4) returns the exact match first."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
    )
    for rowid in range(100):
        vector = _f32([rowid, 0, 0, 0])
        db.execute("INSERT INTO t(rowid, v) VALUES (?, ?)", [rowid, vector])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    query = _f32([50.0, 0, 0, 0])
    results = db.execute(
        "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
        [query],
    ).fetchall()

    assert len(results) == 5
    # The query equals the vector stored at rowid 50, so it must rank first.
    nearest = results[0]
    assert nearest[0] == 50
    assert nearest[1] < 0.01
|
||||
|
||||
|
||||
def test_ivf_knn_k_larger_than_n(db):
    """Asking for k=100 on a 5-row trained table returns exactly the 5 rows."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    row_total = 5
    for rowid in range(row_total):
        vector = _f32([rowid, 0, 0, 0])
        db.execute("INSERT INTO t(rowid, v) VALUES (?, ?)", [rowid, vector])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    origin = _f32([0, 0, 0, 0])
    results = db.execute(
        "SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
        [origin],
    ).fetchall()
    assert len(results) == row_total
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Manual centroid import (set-centroid, assign-vectors)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_set_centroid_and_assign(db):
    """Manually set centroids (`set-centroid:N`) can be used by `assign-vectors`.

    The table is created with nlist=0, two centroids are set by hand, and
    `assign-vectors` is then expected to leave no vector unassigned.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
    )
    for rowid in range(20):
        vector = _f32([rowid, 0, 0, 0])
        db.execute("INSERT INTO t(rowid, v) VALUES (?, ?)", [rowid, vector])

    # Centroid 0 at x=5 and centroid 1 at x=15.
    manual_centroids = ([5, 0, 0, 0], [15, 0, 0, 0])
    for centroid_id, center in enumerate(manual_centroids):
        db.execute(
            f"INSERT INTO t(t, v) VALUES ('set-centroid:{centroid_id}', ?)",
            [_f32(center)],
        )

    stored = db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0]
    assert stored == 2

    db.execute("INSERT INTO t(t) VALUES ('assign-vectors')")

    assert ivf_unassigned_count(db) == 0
    assert ivf_assigned_count(db) == 20
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# clear-centroids
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_clear_centroids(db):
    """`clear-centroids` drops the centroids, un-assigns every vector and resets the trained flag."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
    )
    for rowid in range(20):
        vector = _f32([rowid, 0, 0, 0])
        db.execute("INSERT INTO t(rowid, v) VALUES (?, ?)", [rowid, vector])

    def centroid_count():
        # Number of rows currently stored in the centroid shadow table.
        return db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0]

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    assert centroid_count() == 2

    db.execute("INSERT INTO t(t) VALUES ('clear-centroids')")
    assert centroid_count() == 0
    assert ivf_unassigned_count(db) == 20

    flag_row = db.execute(
        "SELECT value FROM t_info WHERE key='ivf_trained_0'"
    ).fetchone()
    assert str(flag_row[0]) == "0"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Delete after training
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_delete_after_training(db):
    """Deleting a row from a trained IVF table removes exactly that row.

    Besides the aggregate counts, the test checks that rowid 5 itself is
    gone from the rowid map. A count-only check would still pass if the
    wrong row had been removed.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
    )
    for i in range(10):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
        )

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    assert ivf_assigned_count(db) == 10

    db.execute("DELETE FROM t WHERE rowid = 5")
    assert ivf_assigned_count(db) == 9
    assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 9

    # The deleted rowid must no longer be mapped to any cell.
    assert (
        db.execute(
            "SELECT count(*) FROM t_ivf_rowid_map00 WHERE rowid = 5"
        ).fetchone()[0]
        == 0
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Recall test
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_ivf_recall_nprobe_equals_nlist(db):
    """With nprobe == nlist the 10 nearest neighbours of 50 all lie in 45..55."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
    )
    for rowid in range(100):
        vector = _f32([rowid, 0, 0, 0])
        db.execute("INSERT INTO t(rowid, v) VALUES (?, ?)", [rowid, vector])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    results = db.execute(
        "SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
        [_f32([50.0, 0, 0, 0])],
    ).fetchall()
    found = set(result[0] for result in results)

    # 45 and 55 are equidistant from 50, so either may appear in top 10
    assert 50 in found
    assert len(found) == 10
    assert found <= set(range(45, 56))
|
||||
138
tests/test-legacy-compat.py
Normal file
138
tests/test-legacy-compat.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""Backwards compatibility tests: current sqlite-vec reading legacy databases.
|
||||
|
||||
The fixture file tests/fixtures/legacy-v0.1.6.db was generated by
|
||||
tests/generate_legacy_db.py using sqlite-vec v0.1.6. These tests verify
|
||||
that the current version can fully read, query, insert into, and delete
|
||||
from tables created by older versions.
|
||||
"""
|
||||
import sqlite3
|
||||
import struct
|
||||
import os
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures", "legacy-v0.1.6.db")
|
||||
|
||||
|
||||
def _f32(vals):
|
||||
return struct.pack(f"{len(vals)}f", *vals)
|
||||
|
||||
|
||||
@pytest.fixture()
def legacy_db(tmp_path):
    """Yield a connection to a scratch copy of the legacy fixture database.

    The fixture file is copied into pytest's per-test `tmp_path` so tests can
    modify it freely. The vec0 extension is loaded from `dist/vec0`, and rows
    are returned as `sqlite3.Row`. The test is skipped when the fixture file
    has not been generated.

    The connection is closed on teardown so the copied database file is not
    left open after the test finishes.
    """
    if not os.path.exists(FIXTURE_PATH):
        pytest.skip("Legacy fixture not found — run: uv run --script tests/generate_legacy_db.py")
    db_path = str(tmp_path / "legacy.db")
    shutil.copy2(FIXTURE_PATH, db_path)
    db = sqlite3.connect(db_path)
    try:
        db.row_factory = sqlite3.Row
        db.enable_load_extension(True)
        db.load_extension("dist/vec0")
        yield db
    finally:
        # Runs after the test body, and also if loading the extension fails.
        db.close()
|
||||
|
||||
|
||||
def test_legacy_select_count(legacy_db):
    """A plain count(*) over the legacy table sees all 50 rows."""
    (total,) = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()
    assert total == 50
|
||||
|
||||
|
||||
def test_legacy_point_query(legacy_db):
    """Looking up rowid 1 returns that row, and its vector's first component is 1.0."""
    cursor = legacy_db.execute(
        "SELECT rowid, emb FROM legacy_vectors WHERE rowid = 1"
    )
    hit = cursor.fetchone()
    assert hit["rowid"] == 1

    # The embedding is decoded as four packed floats.
    components = struct.unpack("4f", hit["emb"])
    assert components[0] == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_legacy_knn(legacy_db):
    """KNN on the legacy table returns 5 rows, nearest first, in non-decreasing distance."""
    knn_sql = (
        "SELECT rowid, distance FROM legacy_vectors "
        "WHERE emb MATCH ? AND k = 5"
    )
    neighbours = legacy_db.execute(knn_sql, [_f32([1.0, 0.0, 0.0, 0.0])]).fetchall()

    assert len(neighbours) == 5

    closest = neighbours[0]
    assert closest["rowid"] == 1
    assert closest["distance"] == pytest.approx(0.0)

    # Every result must be at least as far away as the one before it.
    for nearer, farther in zip(neighbours, neighbours[1:]):
        assert nearer["distance"] <= farther["distance"]
|
||||
|
||||
|
||||
def test_legacy_insert(legacy_db):
    """A row inserted into the legacy table is counted and found by KNN."""
    new_vector = _f32([100.0, 0.0, 0.0, 0.0])
    legacy_db.execute(
        "INSERT INTO legacy_vectors(rowid, emb) VALUES (100, ?)",
        [new_vector],
    )

    (total,) = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()
    assert total == 51

    # Querying with the identical vector must return the new row first.
    matches = legacy_db.execute(
        "SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 1",
        [_f32([100.0, 0.0, 0.0, 0.0])],
    ).fetchall()
    assert matches[0]["rowid"] == 100
|
||||
|
||||
|
||||
def test_legacy_delete(legacy_db):
    """A row deleted from the legacy table is no longer counted or returned by KNN."""
    legacy_db.execute("DELETE FROM legacy_vectors WHERE rowid = 1")

    (total,) = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()
    assert total == 49

    neighbours = legacy_db.execute(
        "SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 5",
        [_f32([1.0, 0.0, 0.0, 0.0])],
    ).fetchall()
    returned_rowids = [neighbour["rowid"] for neighbour in neighbours]
    assert 1 not in returned_rowids
|
||||
|
||||
|
||||
def test_legacy_fullscan(legacy_db):
    """An ordered scan without a MATCH clause returns rowids 1..5 first."""
    cursor = legacy_db.execute(
        "SELECT rowid FROM legacy_vectors ORDER BY rowid LIMIT 5"
    )
    first_rowids = [record["rowid"] for record in cursor.fetchall()]
    assert first_rowids == [1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_legacy_name_conflict_table(legacy_db):
    """A legacy table whose column shares the table's name stays readable.

    The v0.1.6 DB has: CREATE VIRTUAL TABLE emb USING vec0(emb float[4])
    Current code should NOT add the command column for this table
    (detected via _info version check), avoiding the name conflict.
    """
    (total,) = legacy_db.execute("SELECT count(*) FROM emb").fetchone()
    assert total == 10

    neighbours = legacy_db.execute(
        "SELECT rowid, distance FROM emb WHERE emb MATCH ? AND k = 3",
        [_f32([1.0, 0.0, 0.0, 0.0])],
    ).fetchall()
    assert len(neighbours) == 3

    closest = neighbours[0]
    assert closest["rowid"] == 1
|
||||
|
||||
|
||||
def test_legacy_name_conflict_insert_delete(legacy_db):
    """INSERT then DELETE on the legacy name-conflict table changes the row count as expected."""

    def row_count():
        # Current number of rows visible in the `emb` table.
        return legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0]

    legacy_db.execute(
        "INSERT INTO emb(rowid, emb) VALUES (100, ?)",
        [_f32([100.0, 0.0, 0.0, 0.0])],
    )
    assert row_count() == 11

    legacy_db.execute("DELETE FROM emb WHERE rowid = 5")
    assert row_count() == 10
|
||||
|
||||
|
||||
def test_legacy_no_command_column(legacy_db):
    """Inserting through a column named after the table fails on a legacy table."""
    command_insert = (
        "INSERT INTO legacy_vectors(legacy_vectors) VALUES ('some_command')"
    )
    with pytest.raises(sqlite3.OperationalError):
        legacy_db.execute(command_insert)
|
||||
|
|
@ -119,151 +119,9 @@ FUNCTIONS = [
|
|||
# Module names listed for this test file. The two static-blob module
# names stay commented out, as in the original list.
MODULES = ["vec0", "vec_each"]
# "vec_static_blob_entries",
# "vec_static_blobs",
|
||||
|
||||
|
||||
def register_numpy(db, name: str, array):
    """Register a 2-D little-endian float32 array under `name` as a static-blob table.

    Reads the data pointer, shape and type string from the array interface,
    inserts them into `temp.vec_static_blobs`, then creates a
    `vec_static_blob_entries` virtual table named after `name`.

    Raises AssertionError when the array's typestr is not "<f4".
    """
    interface = array.__array_interface__
    data_pointer = interface["data"][0]
    row_count, column_count = interface["shape"]
    typestr = interface["typestr"]

    assert typestr == "<f4"

    # printf('%w') is applied to the name before it is spliced into the DDL below.
    escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]

    blob_params = [name, data_pointer, typestr, column_count, row_count]
    db.execute(
        """
        insert into temp.vec_static_blobs(name, data)
        select ?, vec_static_blob_from_raw(?, ?, ?, ?)
        """,
        blob_params,
    )

    create_sql = (
        f'create virtual table "{escaped}" using vec_static_blob_entries({escaped})'
    )
    db.execute(create_sql)
|
||||
|
||||
|
||||
def test_vec_static_blob_entries():
    """Static-blob tables backed by numpy arrays can be listed, scanned and KNN-queried.

    Registers three float32 arrays (`x`, `y`, `z`), then checks the
    `temp.vec_static_blobs` listing, the raw and JSON vector values, and two
    k=3 nearest-neighbour queries against `z`.
    """
    db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_static_blobs_init")

    arrays = {
        "x": np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32),
        "y": np.array([[0.2, 0.3], [0.9, 0.8], [0.6, 0.5]], dtype=np.float32),
        "z": np.array(
            [[value] * 4 for value in (0.1, 0.2, 0.3, 0.4, 0.5)],
            dtype=np.float32,
        ),
    }
    for table_name, table_data in arrays.items():
        register_numpy(db, table_name, table_data)

    # One listing row per registered array, with its dimensions and row count.
    expected_listing = [
        {"name": table_name, "data": None, "dimensions": dimensions, "count": count}
        for table_name, dimensions, count in (("x", 4, 2), ("y", 2, 3), ("z", 4, 5))
    ]
    listing = execute_all(db, "select *, dimensions, count from temp.vec_static_blobs;")
    assert listing == expected_listing

    x_as_json = execute_all(db, "select vec_to_json(vector) from x;")
    assert x_as_json == [
        {"vec_to_json(vector)": "[0.100000,0.200000,0.300000,0.400000]"},
        {"vec_to_json(vector)": "[0.900000,0.800000,0.700000,0.600000]"},
    ]

    y_raw = execute_all(db, "select (vector) from y limit 2;")
    assert y_raw == [
        {"vector": b"\xcd\xccL>\x9a\x99\x99>"},
        {"vector": b"fff?\xcd\xccL?"},
    ]

    z_raw = execute_all(db, "select rowid, (vector) from z")
    assert z_raw == [
        {"rowid": 0, "vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc="},
        {"rowid": 1, "vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>"},
        {"rowid": 2, "vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>"},
        {"rowid": 3, "vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>"},
        {"rowid": 4, "vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?"},
    ]

    knn_sql = "select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;"

    near_point_three = execute_all(
        db,
        knn_sql,
        [np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)],
    )
    assert near_point_three == [
        {"rowid": 2, "v": "[0.300000,0.300000,0.300000,0.300000]"},
        {"rowid": 3, "v": "[0.400000,0.400000,0.400000,0.400000]"},
        {"rowid": 1, "v": "[0.200000,0.200000,0.200000,0.200000]"},
    ]

    near_point_six = execute_all(
        db,
        knn_sql,
        [np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)],
    )
    assert near_point_six == [
        {"rowid": 4, "v": "[0.500000,0.500000,0.500000,0.500000]"},
        {"rowid": 3, "v": "[0.400000,0.400000,0.400000,0.400000]"},
        {"rowid": 2, "v": "[0.300000,0.300000,0.300000,0.300000]"},
    ]
|
||||
|
||||
|
||||
def test_limits():
|
||||
db = connect(EXT_PATH)
|
||||
with _raises(
|
||||
|
|
@ -507,6 +365,34 @@ def test_vec_distance_l1():
|
|||
)
|
||||
|
||||
|
||||
def test_vec_reject_nan_inf():
    """NaN and Inf in float32 vectors should be rejected.

    Each rejected case packs one non-finite value into a 4-float blob and
    expects `vec_length` to raise `sqlite3.OperationalError` whose message
    matches the given pattern. A fully finite blob must still be accepted.
    """
    import struct

    # (non-finite element, pattern expected in the error message)
    rejected_cases = [
        (float("nan"), "NaN"),
        (float("inf"), "Inf"),
        (float("-inf"), "Inf"),
    ]
    for bad_value, expected_message in rejected_cases:
        blob = struct.pack("4f", 1.0, bad_value, 3.0, 4.0)
        with pytest.raises(sqlite3.OperationalError, match=expected_message):
            db.execute("SELECT vec_length(?)", [blob])

    # JSON input is not covered here: JSON has no NaN literal (though strtod
    # may parse "NaN"). This test exercises only the blob path, which is the
    # primary input method.

    # Valid vectors still work
    ok_blob = struct.pack("4f", 1.0, 2.0, 3.0, 4.0)
    assert db.execute("SELECT vec_length(?)", [ok_blob]).fetchone()[0] == 4
|
||||
|
||||
|
||||
def test_vec_distance_l2():
|
||||
vec_distance_l2 = lambda *args, a="?", b="?": db.execute(
|
||||
f"select vec_distance_l2({a}, {b})", args
|
||||
|
|
@ -523,11 +409,17 @@ def test_vec_distance_l2():
|
|||
|
||||
x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform)
|
||||
y = npy_l2(np.array(a), np.array(b))
|
||||
assert isclose(x, y, abs_tol=1e-6)
|
||||
assert isclose(x, y, rel_tol=1e-5, abs_tol=1e-6)
|
||||
|
||||
check([1.2, 0.1], [0.4, -0.4])
|
||||
check([-1.2, -0.1], [-0.4, 0.4])
|
||||
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
|
||||
# Extreme int8 values: diff=255, squared=65025 which overflows i16
|
||||
# This tests the NEON widening multiply fix (slight float rounding expected)
|
||||
check([-128] * 8, [127] * 8, dtype=np.int8)
|
||||
check([-128] * 16, [127] * 16, dtype=np.int8)
|
||||
check([-128, 127, -128, 127, -128, 127, -128, 127],
|
||||
[127, -128, 127, -128, 127, -128, 127, -128], dtype=np.int8)
|
||||
|
||||
|
||||
def test_vec_length():
|
||||
|
|
@ -1618,231 +1510,6 @@ def test_vec_each():
|
|||
vec_each_f32(None)
|
||||
|
||||
|
||||
import io
|
||||
|
||||
|
||||
def to_npy(arr):
    """Serialize `arr` with `np.save` into memory and return the resulting bytes."""
    with io.BytesIO() as sink:
        np.save(sink, arr)
        return sink.getvalue()
|
||||
|
||||
|
||||
def test_vec_npy_each():
    """`vec_npy_each` yields one row per vector of a serialized float32 array.

    Covers a 1-D array, a single-row 2-D array, a two-row 2-D array and an
    empty array.
    """
    db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init")

    def rows_for(values):
        # Serialize `values` as float32 .npy bytes and run them through the table function.
        payload = to_npy(np.array(values, dtype=np.float32))
        return execute_all(db, "select rowid, * from vec_npy_each(?)", (payload,))

    first = {"rowid": 0, "vector": _f32([1.1, 2.2, 3.3])}
    second = {"rowid": 1, "vector": _f32([9.9, 8.8, 7.7])}

    assert rows_for([1.1, 2.2, 3.3]) == [first]
    assert rows_for([[1.1, 2.2, 3.3]]) == [first]
    assert rows_for([[1.1, 2.2, 3.3], [9.9, 8.8, 7.7]]) == [first, second]

    assert rows_for([]) == []
|
||||
|
||||
|
||||
def test_vec_npy_each_errors():
    """Each malformed .npy payload makes `vec_npy_each` fail with a specific message.

    The cases run in the listed order. Each one feeds a byte string to
    `vec_npy_each(?)` and expects `_raises` to see the paired message.
    """
    db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init")

    def vec_npy_each(*args):
        return execute_all(db, "select rowid, * from vec_npy_each(?)", args)

    # Payload with header and data intact, kept for reference. The failing
    # payloads below appear to be variations of it.
    full = b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"

    # (expected error message, payload)
    failing_cases = [
        # EVIDENCE-OF: V03312_20150 numpy validation too short
        ("numpy array too short", b""),
        # EVIDENCE-OF: V11954_28792 numpy validate magic
        (
            "numpy array does not contain the 'magic' header",
            b"\x93NUMPX\x01\x00v\x00",
        ),
        (
            "numpy array header length is invalid",
            b"\x93NUMPY\x01\x00v\x00",
        ),
        (
            "numpy header did not start with '{'",
            b"\x93NUMPY\x01\x00v\x00c'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "expected key in numpy header",
            b"\x93NUMPY\x01\x00v\x00{ \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "expected a string as key in numpy header",
            b"\x93NUMPY\x01\x00v\x00{False: '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "expected a ':' after key in numpy header",
            b"\x93NUMPY\x01\x00v\x00{'descr' \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "expected a ':' after key in numpy header",
            b"\x93NUMPY\x01\x00v\x00{'descr' False \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "expected a string value after 'descr' key",
            b"\x93NUMPY\x01\x00v\x00{'descr': \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Only '<f4' values are supported in sqlite-vec numpy functions",
            b"\x93NUMPY\x01\x00v\x00{'descr': '=f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Only fortran_order = False is supported in sqlite-vec numpy functions",
            b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': True, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Error parsing numpy array: Expected left parenthesis '(' after shape key",
            b"\x93NUMPY\x01\x00v\x00{'shape': 2, 'descr': '<f4', 'fortran_order': False, } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Error parsing numpy array: Expected an initial number in shape value",
            b"\x93NUMPY\x01\x00v\x00{'shape': (, 'descr': '<f4', 'fortran_order': False, } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Error parsing numpy array: Expected comma after first shape value",
            b"\x93NUMPY\x01\x00v\x00{'shape': (2), 'descr': '<f4', 'fortran_order': False, } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Error parsing numpy array: unexpected header EOF while parsing shape",
            b"\x93NUMPY\x01\x00v\x00{'shape': (2, \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Error parsing numpy array: unknown type in shape value",
            b"\x93NUMPY\x01\x00v\x00{'shape': (2, 'nope' \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Error parsing numpy array: expected right parenthesis after shape value",
            b"\x93NUMPY\x01\x00v\x00{'shape': (2,4 ( \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Error parsing numpy array: unknown key in numpy header",
            b"\x93NUMPY\x01\x00v\x00{'no': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "Error parsing numpy array: unknown extra token after value",
            b"\x93NUMPY\x01\x00v\x00{'descr': '<f4' 'asdf', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@",
        ),
        (
            "numpy array error: Expected a data size of 32, found 31",
            b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3",
        ),
    ]

    for expected_message, payload in failing_cases:
        with _raises(expected_message):
            vec_npy_each(payload)
|
||||
|
||||
# with _raises("XXX"):
|
||||
# vec_npy_each(b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@")
|
||||
|
||||
|
||||
import tempfile
|
||||
|
||||
|
||||
def test_vec_npy_each_errors_files():
    """`vec_npy_each(vec_npy_file(path))` reports file-level errors and reads valid files.

    Error cases: a missing file, a truncated header, a bad magic string, an
    invalid header length, an unsupported `fortran_order`, and a short data
    section. Success cases: 1-D, 2-D, empty and 1025-row arrays.
    """
    import os

    db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init")

    def vec_npy_each(data):
        # Write `data` to a file in a private temp directory and query it by
        # path. This avoids NamedTemporaryFile(delete_on_close=False), whose
        # keyword only exists on Python 3.12+. The file is closed exactly
        # once (by the inner `with`) before the extension opens it.
        with tempfile.TemporaryDirectory() as tmp_dir:
            npy_path = os.path.join(tmp_dir, "input.npy")
            with open(npy_path, "wb") as f:
                f.write(data)
            return execute_all(
                db, "select rowid, * from vec_npy_each(vec_npy_file(?))", [npy_path]
            )

    with _raises("Could not open numpy file"):
        db.execute('select * from vec_npy_each(vec_npy_file("not exist"))')

    with _raises("numpy array file too short"):
        vec_npy_each(b"\x93NUMPY\x01\x00v")

    with _raises("numpy array file does not contain the 'magic' header"):
        vec_npy_each(b"\x93XUMPY\x01\x00v\x00")

    with _raises("numpy array file header length is invalid"):
        vec_npy_each(b"\x93NUMPY\x01\x00v\x00")

    with _raises(
        "Error parsing numpy array: Only fortran_order = False is supported in sqlite-vec numpy functions"
    ):
        vec_npy_each(
            b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': True, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
        )

    with _raises("numpy array file error: Expected a data size of 32, found 31"):
        vec_npy_each(
            b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3"
        )

    assert vec_npy_each(to_npy(np.array([1.1, 2.2, 3.3], dtype=np.float32))) == [
        {
            "rowid": 0,
            "vector": _f32([1.1, 2.2, 3.3]),
        },
    ]
    assert vec_npy_each(
        to_npy(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32))
    ) == [
        {
            "rowid": 0,
            "vector": _f32([1.1, 2.2, 3.3]),
        },
        {
            "rowid": 1,
            "vector": _f32([4.4, 5.5, 6.6]),
        },
    ]
    assert vec_npy_each(to_npy(np.array([], dtype=np.float32))) == []

    # 1025 rows: every row of a larger file must come back.
    x1025 = vec_npy_each(to_npy(np.array([[0.1, 0.2, 0.3]] * 1025, dtype=np.float32)))
    assert len(x1025) == 1025
|
||||
|
||||
# np.array([[.1, .2, 3]] * 99, dtype=np.float32).shape
|
||||
|
||||
|
||||
def test_vec0_constructor():
|
||||
vec_constructor_error_prefix = "vec0 constructor error: {}"
|
||||
vec_col_error_prefix = "vec0 constructor error: could not parse vector column '{}'"
|
||||
|
|
@ -1923,6 +1590,54 @@ def test_vec0_constructor():
|
|||
db.execute("create virtual table v using vec0(4)")
|
||||
|
||||
|
||||
def test_vec0_indexed_by_flat():
    """`indexed by flat()` is accepted on float, int8 and bit columns; bad index clauses fail.

    Also checks that a flat-indexed table answers a k=1 KNN query, and that an
    unknown index type or a missing index type raises a constructor error.
    """
    for leftover in ("t_ibf", "t_ibf2", "t_ibf3", "t_ibf4"):
        db.execute(f"drop table if exists {leftover}")

    # indexed by flat() should succeed and behave identically to no index clause
    db.execute("create virtual table t_ibf using vec0(emb float[4] indexed by flat())")
    db.execute(
        "insert into t_ibf(rowid, emb) values (1, X'00000000000000000000000000000000')"
    )
    matches = db.execute("select rowid from t_ibf where emb match X'00000000000000000000000000000000' and k = 1").fetchall()
    assert len(matches) == 1
    assert matches[0][0] == 1
    db.execute("drop table t_ibf")

    # Create-then-drop checks, in order: with distance_metric, on int8, on bit.
    accepted_definitions = (
        (
            "create virtual table t_ibf2 using vec0(emb float[4] distance_metric=cosine indexed by flat())",
            "drop table t_ibf2",
        ),
        (
            "create virtual table t_ibf3 using vec0(emb int8[4] indexed by flat())",
            "drop table t_ibf3",
        ),
        (
            "create virtual table t_ibf4 using vec0(emb bit[8] indexed by flat())",
            "drop table t_ibf4",
        ),
    )
    for create_sql, drop_sql in accepted_definitions:
        db.execute(create_sql)
        db.execute(drop_sql)

    # Rejected definitions, in order: unknown index type, then missing index type.
    rejected_definitions = (
        (
            "vec0 constructor error: could not parse vector column 'emb float[4] indexed by unknown()'",
            "create virtual table v using vec0(emb float[4] indexed by unknown())",
        ),
        (
            "vec0 constructor error: could not parse vector column 'emb float[4] indexed by'",
            "create virtual table v using vec0(emb float[4] indexed by)",
        ),
    )
    for expected_message, create_sql in rejected_definitions:
        with _raises(expected_message, sqlite3.DatabaseError):
            db.execute(create_sql)

    # Roll back any transaction left open by the statements above.
    if db.in_transaction:
        db.rollback()
|
||||
|
||||
|
||||
def test_vec0_create_errors():
|
||||
# EVIDENCE-OF: V17740_01811 vec0 create _chunks error handling
|
||||
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_CREATE_TABLE, "t1_chunks"))
|
||||
|
|
|
|||
|
|
@ -265,6 +265,35 @@ def test_deletes(db, snapshot):
|
|||
assert vec0_shadow_table_contents(db, "v") == snapshot()
|
||||
|
||||
|
||||
def test_delete_by_metadata_with_long_text(db):
    """Regression for https://github.com/asg017/sqlite-vec/issues/274.

    ClearMetadata left rc=SQLITE_DONE after the long-text DELETE, which
    propagated as an error and silently aborted the DELETE scan.

    Ten rows are inserted, five of them tagged 'long_text_value_0': one from
    the first loop and four from the second. A DELETE filtered on that tag
    must remove all five and leave the other five rows in place.
    """
    db.execute(
        "create virtual table v using vec0("
        " tag text, embedding float[4], chunk_size=8"
        ")"
    )
    for i in range(6):
        db.execute(
            "insert into v(tag, embedding) values (?, zeroblob(16))",
            [f"long_text_value_{i}"],
        )
    for _ in range(4):
        # Plain literal: the tag is constant, so no f-string is needed.
        db.execute(
            "insert into v(tag, embedding) values (?, zeroblob(16))",
            ["long_text_value_0"],
        )
    assert db.execute("select count(*) from v").fetchone()[0] == 10

    # DELETE by metadata WHERE — the pattern from the issue
    db.execute("delete from v where tag = 'long_text_value_0'")
    assert db.execute("select count(*) from v where tag = 'long_text_value_0'").fetchone()[0] == 0
    assert db.execute("select count(*) from v").fetchone()[0] == 5
|
||||
|
||||
|
||||
def test_knn(db, snapshot):
|
||||
db.execute(
|
||||
"create virtual table v using vec0(vector float[1], name text, chunk_size=8)"
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue