mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 16:56:27 +02:00
Compare commits
47 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5778fecfeb | ||
|
|
4e2dfcb79d | ||
|
|
b95c05b3aa | ||
|
|
89f6203536 | ||
|
|
c5607100ab | ||
|
|
6e2c4c6bab | ||
|
|
b7fc459be4 | ||
|
|
01b4b2a965 | ||
|
|
c36a995f1e | ||
|
|
c4c23bd8ba | ||
|
|
5522e86cd2 | ||
|
|
f2c9fb8f08 | ||
|
|
d684178a12 | ||
|
|
d033bf5728 | ||
|
|
b00865429b | ||
|
|
2f4c2e4bdb | ||
|
|
7de925be70 | ||
|
|
4bee88384b | ||
|
|
5e4c557f93 | ||
|
|
82f4eb08bf | ||
|
|
9df59b4c03 | ||
|
|
07f56e3cbe | ||
|
|
3cfc2e0c1f | ||
|
|
85cf415397 | ||
|
|
85973b3814 | ||
|
|
8544081a67 | ||
|
|
a248ecd061 | ||
|
|
1e3bb3e5e3 | ||
|
|
fb81c011ff | ||
|
|
575371d751 | ||
|
|
e2c38f387c | ||
|
|
bb3ef78f75 | ||
|
|
3e26925ce0 | ||
|
|
3358e127f6 | ||
|
|
43982c144b | ||
|
|
45d1375602 | ||
|
|
69ccb2405a | ||
|
|
0de765f457 | ||
|
|
e9f598abfa | ||
|
|
6c3bf3669f | ||
|
|
69f7b658e9 | ||
|
|
ee9bd2ba4d | ||
|
|
ba0db0b6d6 | ||
|
|
bf2455f2ba | ||
|
|
dfd8dc5290 | ||
|
|
e7ae41b761 | ||
|
|
a8d81cb235 |
105 changed files with 21725 additions and 2248 deletions
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
|
|
@ -252,7 +252,7 @@ jobs:
|
||||||
name: sqlite-vec-iossimulator-x86_64-extension
|
name: sqlite-vec-iossimulator-x86_64-extension
|
||||||
path: dist/iossimulator-x86_64
|
path: dist/iossimulator-x86_64
|
||||||
- run: make sqlite-vec.h
|
- run: make sqlite-vec.h
|
||||||
- uses: asg017/setup-sqlite-dist@73e37b2ffb0b51e64a64eb035da38c958b9ff6c6
|
- uses: asg017/setup-sqlite-dist@fadb0183a6ec70c3f1942de7d232b087ff2bacd1
|
||||||
- run: sqlite-dist build --set-version $(cat VERSION)
|
- run: sqlite-dist build --set-version $(cat VERSION)
|
||||||
- run: |
|
- run: |
|
||||||
gh release upload ${{ github.ref_name }} \
|
gh release upload ${{ github.ref_name }} \
|
||||||
|
|
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -31,3 +31,6 @@ poetry.lock
|
||||||
|
|
||||||
memstat.c
|
memstat.c
|
||||||
memstat.*
|
memstat.*
|
||||||
|
|
||||||
|
|
||||||
|
.DS_Store
|
||||||
35
Makefile
35
Makefile
|
|
@ -37,11 +37,18 @@ endif
|
||||||
|
|
||||||
ifndef OMIT_SIMD
|
ifndef OMIT_SIMD
|
||||||
ifeq ($(shell uname -sm),Darwin x86_64)
|
ifeq ($(shell uname -sm),Darwin x86_64)
|
||||||
CFLAGS += -mavx -DSQLITE_VEC_ENABLE_AVX
|
CFLAGS += -mavx -mavx2 -DSQLITE_VEC_ENABLE_AVX
|
||||||
endif
|
endif
|
||||||
ifeq ($(shell uname -sm),Darwin arm64)
|
ifeq ($(shell uname -sm),Darwin arm64)
|
||||||
CFLAGS += -mcpu=apple-m1 -DSQLITE_VEC_ENABLE_NEON
|
CFLAGS += -mcpu=apple-m1 -DSQLITE_VEC_ENABLE_NEON
|
||||||
endif
|
endif
|
||||||
|
# Linux (non-Android toolchains): enable the AVX code path only when the build
# host advertises AVX2. The flags added below include -mavx2, so testing for a
# bare "avx" token would also match CPUs that have AVX but not AVX2.
ifeq ($(shell uname -s),Linux)
ifeq ($(findstring android,$(CC)),)
ifneq ($(shell grep -ow avx2 /proc/cpuinfo 2>/dev/null | head -1),)
CFLAGS += -mavx -mavx2 -DSQLITE_VEC_ENABLE_AVX
endif
endif
endif
|
||||||
endif
|
endif
|
||||||
|
|
||||||
ifdef USE_BREW_SQLITE
|
ifdef USE_BREW_SQLITE
|
||||||
|
|
@ -155,6 +162,13 @@ clean:
|
||||||
rm -rf dist
|
rm -rf dist
|
||||||
|
|
||||||
|
|
||||||
|
TARGET_AMALGAMATION=$(prefix)/sqlite-vec.c

amalgamation: $(TARGET_AMALGAMATION)

# Generate the single-file amalgamation from sqlite-vec.c.
# - $(prefix) is order-only: the output directory must exist, but its mtime
#   (which changes whenever anything is written into it) must not force a rebuild.
# - The script output goes to a temp file first, so a failing run never leaves
#   a truncated target that looks up to date.
$(TARGET_AMALGAMATION): sqlite-vec.c $(wildcard sqlite-vec-*.c) scripts/amalgamate.py | $(prefix)
	python3 scripts/amalgamate.py sqlite-vec.c > $@.tmp || { rm -f $@.tmp; exit 1; }
	mv -f $@.tmp $@
|
||||||
|
|
||||||
FORMAT_FILES=sqlite-vec.h sqlite-vec.c
|
FORMAT_FILES=sqlite-vec.h sqlite-vec.c
|
||||||
format: $(FORMAT_FILES)
|
format: $(FORMAT_FILES)
|
||||||
clang-format -i $(FORMAT_FILES)
|
clang-format -i $(FORMAT_FILES)
|
||||||
|
|
@ -174,7 +188,7 @@ evidence-of:
|
||||||
test:
|
test:
|
||||||
sqlite3 :memory: '.read test.sql'
|
sqlite3 :memory: '.read test.sql'
|
||||||
|
|
||||||
.PHONY: version loadable static test clean gh-release evidence-of install uninstall
|
.PHONY: version loadable static test clean gh-release evidence-of install uninstall amalgamation
|
||||||
|
|
||||||
publish-release:
|
publish-release:
|
||||||
./scripts/publish-release.sh
|
./scripts/publish-release.sh
|
||||||
|
|
@ -190,7 +204,22 @@ test-loadable-watch:
|
||||||
watchexec --exts c,py,Makefile --clear -- make test-loadable
|
watchexec --exts c,py,Makefile --clear -- make test-loadable
|
||||||
|
|
||||||
test-unit:
|
test-unit:
|
||||||
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
|
$(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_ENABLE_DISKANN=1 tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor $(CFLAGS) -o $(prefix)/test-unit && $(prefix)/test-unit
|
||||||
|
|
||||||
|
# Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking,
# profiling (has debug symbols), and scripting without .load_extension.
# make cli
# dist/sqlite3 :memory: "SELECT vec_version()"
# dist/sqlite3 < script.sql
#
# `cli` is a command name, not a file this recipe creates (the binary is
# written to $(prefix)/sqlite3), so declare it phony to keep a stray file
# named `cli` from masking the target.
.PHONY: cli
cli: sqlite-vec.h $(prefix)
	$(CC) -O2 -g \
		-DSQLITE_CORE \
		-DSQLITE_EXTRA_INIT=core_init \
		-DSQLITE_THREADSAFE=0 \
		-Ivendor/ -I./ \
		$(CFLAGS) \
		vendor/sqlite3.c vendor/shell.c sqlite-vec.c examples/sqlite3-cli/core_init.c \
		-ldl -lm -o $(prefix)/sqlite3
|
||||||
|
|
||||||
fuzz-build:
|
fuzz-build:
|
||||||
$(MAKE) -C tests/fuzz all
|
$(MAKE) -C tests/fuzz all
|
||||||
|
|
|
||||||
73
TODO.md
Normal file
73
TODO.md
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
# TODO: `ann` base branch + consolidated benchmarks
|
||||||
|
|
||||||
|
## 1. Create `ann` branch with shared code
|
||||||
|
|
||||||
|
### 1.1 Branch setup
|
||||||
|
- [x] `git checkout -B ann origin/main`
|
||||||
|
- [x] Cherry-pick `624f998` (vec0_distance_full shared distance dispatch)
|
||||||
|
- [x] Cherry-pick stdint.h fix for test header
|
||||||
|
- [ ] Pull NEON cosine optimization from ivf-yolo3 into shared code
|
||||||
|
- Currently only in ivf branch but is general-purpose (benefits all distance calcs)
|
||||||
|
- Lives in `distance_cosine_float()` — ~57 lines of ARM NEON vectorized cosine
|
||||||
|
|
||||||
|
### 1.2 Benchmark infrastructure (`benchmarks-ann/`)
|
||||||
|
- [x] Seed data pipeline (`seed/Makefile`, `seed/build_base_db.py`)
|
||||||
|
- [x] Ground truth generator (`ground_truth.py`)
|
||||||
|
- [x] Results schema (`schema.sql`)
|
||||||
|
- [x] Benchmark runner with `INDEX_REGISTRY` extension point (`bench.py`)
|
||||||
|
- Baseline configs (float, int8-rescore, bit-rescore) implemented
|
||||||
|
- Index branches register their types via `INDEX_REGISTRY` dict
|
||||||
|
- [x] Makefile with baseline targets
|
||||||
|
- [x] README
|
||||||
|
|
||||||
|
### 1.3 Rebase feature branches onto `ann`
|
||||||
|
- [x] Rebase `diskann-yolo2` onto `ann` (1 commit: DiskANN implementation)
|
||||||
|
- [x] Rebase `ivf-yolo3` onto `ann` (1 commit: IVF implementation)
|
||||||
|
- [x] Rebase `annoy-yolo2` onto `ann` (2 commits: Annoy implementation + schema fix)
|
||||||
|
- [x] Verify each branch has only its index-specific commits remaining
|
||||||
|
- [ ] Force-push all 4 branches to origin
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Per-branch: register index type in benchmarks
|
||||||
|
|
||||||
|
Each index branch should add to `benchmarks-ann/` when rebased onto `ann`:
|
||||||
|
|
||||||
|
### 2.1 Register in `bench.py`
|
||||||
|
|
||||||
|
Add an `INDEX_REGISTRY` entry. Each entry provides:
|
||||||
|
- `defaults` — default param values
|
||||||
|
- `create_table_sql(params)` — CREATE VIRTUAL TABLE with INDEXED BY clause
|
||||||
|
- `insert_sql(params)` — custom insert SQL, or None for default
|
||||||
|
- `post_insert_hook(conn, params)` — training/building step, returns time
|
||||||
|
- `run_query(conn, params, query, k)` — custom query, or None for default MATCH
|
||||||
|
- `describe(params)` — one-line description for report output
|
||||||
|
|
||||||
|
### 2.2 Add configs to `Makefile`
|
||||||
|
|
||||||
|
Append index-specific config variables and targets. Example pattern:
|
||||||
|
|
||||||
|
```makefile
|
||||||
|
DISKANN_CONFIGS = \
|
||||||
|
"diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
|
||||||
|
...
|
||||||
|
|
||||||
|
ALL_CONFIGS += $(DISKANN_CONFIGS)
|
||||||
|
|
||||||
|
bench-diskann: seed
|
||||||
|
$(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.3 Migrate existing benchmark results/docs
|
||||||
|
|
||||||
|
- Move useful results docs (RESULTS.md, etc.) into `benchmarks-ann/results/`
|
||||||
|
- Delete redundant per-branch benchmark directories once consolidated infra is proven
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Future improvements
|
||||||
|
|
||||||
|
- [ ] Reporting script (`report.py`) — query results.db, produce markdown comparison tables
|
||||||
|
- [ ] Profiling targets in Makefile (lift from ivf-yolo3's Instruments/perf wrappers)
|
||||||
|
- [ ] Pre-computed ground truth integration (use GT DB files instead of on-the-fly brute-force)
|
||||||
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
||||||
0.1.7
|
0.1.10-alpha.3
|
||||||
8
benchmarks-ann/.gitignore
vendored
Normal file
8
benchmarks-ann/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
*.db
|
||||||
|
*.db-shm
|
||||||
|
*.db-wal
|
||||||
|
*.parquet
|
||||||
|
runs/
|
||||||
|
|
||||||
|
viewer/
|
||||||
|
searcher/
|
||||||
85
benchmarks-ann/Makefile
Normal file
85
benchmarks-ann/Makefile
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
# Benchmark driver and shared paths.
BENCH := python bench.py
BASE_DB := cohere1m/base.db
EXT := ../dist/vec0

# Arguments common to every full-size benchmark invocation.
RUN_ARGS := -k 10 --dataset cohere1m -o runs

# --- Baseline (brute-force) configs ---
BASELINES = \
  "brute-float:type=vec0-flat,variant=float" \
  "brute-int8:type=vec0-flat,variant=int8" \
  "brute-bit:type=vec0-flat,variant=bit"

# --- IVF configs ---
IVF_CONFIGS = \
  "ivf-n32-p8:type=ivf,nlist=32,nprobe=8" \
  "ivf-n128-p16:type=ivf,nlist=128,nprobe=16" \
  "ivf-n512-p32:type=ivf,nlist=512,nprobe=32"

# --- Rescore configs ---
RESCORE_CONFIGS = \
  "rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" \
  "rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \
  "rescore-int8-os8:type=rescore,quantizer=int8,oversample=8"

# --- DiskANN configs ---
DISKANN_CONFIGS = \
  "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
  "diskann-R72-binary:type=diskann,R=72,L=128,quantizer=binary" \
  "diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8" \
  "diskann-R72-L256:type=diskann,R=72,L=256,quantizer=binary"

# Kept recursive (=) so later `ALL_CONFIGS += ...` lines are picked up by recipes.
ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) $(DISKANN_CONFIGS)

.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-diskann bench-10k bench-50k bench-100k bench-all \
  report clean

# --- Data preparation ---
seed:
	$(MAKE) -C cohere1m

ground-truth: seed
	python ground_truth.py --subset-size 10000
	python ground_truth.py --subset-size 50000
	python ground_truth.py --subset-size 100000

# --- Quick smoke test ---
bench-smoke: seed
	$(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \
		"brute-float:type=vec0-flat,variant=float" \
		"ivf-quick:type=ivf,nlist=16,nprobe=4" \
		"diskann-quick:type=diskann,R=48,L=64,quantizer=binary"

bench-rescore: seed
	$(BENCH) --subset-size 10000 $(RUN_ARGS) $(RESCORE_CONFIGS)

# --- Standard sizes ---
# One static pattern rule covers all sizes: the stem of bench-<N>k is <N>,
# so "$*000" expands to the subset size (10 -> 10000, 50 -> 50000, ...).
bench-10k bench-50k bench-100k: bench-%k: seed
	$(BENCH) --subset-size $*000 $(RUN_ARGS) $(ALL_CONFIGS)

bench-all: bench-10k bench-50k bench-100k

# --- IVF across sizes ---
bench-ivf: seed
	$(BENCH) --subset-size 10000 $(RUN_ARGS) $(BASELINES) $(IVF_CONFIGS)
	$(BENCH) --subset-size 50000 $(RUN_ARGS) $(BASELINES) $(IVF_CONFIGS)
	$(BENCH) --subset-size 100000 $(RUN_ARGS) $(BASELINES) $(IVF_CONFIGS)

# --- DiskANN across sizes ---
bench-diskann: seed
	$(BENCH) --subset-size 10000 $(RUN_ARGS) $(BASELINES) $(DISKANN_CONFIGS)
	$(BENCH) --subset-size 50000 $(RUN_ARGS) $(BASELINES) $(DISKANN_CONFIGS)
	$(BENCH) --subset-size 100000 $(RUN_ARGS) $(BASELINES) $(DISKANN_CONFIGS)

# --- Report ---
# Only prints a suggested query; it does not read any results itself.
report:
	@echo "Use: sqlite3 runs/cohere1m/<size>/results.db 'SELECT run_id, config_name, status, recall FROM runs JOIN run_results USING(run_id)'"

# --- Cleanup ---
clean:
	rm -rf runs/
|
||||||
111
benchmarks-ann/README.md
Normal file
111
benchmarks-ann/README.md
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
# KNN Benchmarks for sqlite-vec
|
||||||
|
|
||||||
|
Benchmarking infrastructure for vec0 KNN configurations. Includes brute-force
|
||||||
|
baselines (float, int8, bit), rescore, IVF, and DiskANN index types.
|
||||||
|
|
||||||
|
## Datasets
|
||||||
|
|
||||||
|
Each dataset is a subdirectory containing a `Makefile` and `build_base_db.py`
|
||||||
|
that produce a `base.db`. The benchmark runner auto-discovers any subdirectory
|
||||||
|
with a `base.db` file.
|
||||||
|
|
||||||
|
```
|
||||||
|
cohere1m/ # Cohere 768d cosine, 1M vectors
|
||||||
|
Makefile # downloads parquets from Zilliz, builds base.db
|
||||||
|
build_base_db.py
|
||||||
|
base.db # (generated)
|
||||||
|
|
||||||
|
cohere10m/ # Cohere 768d cosine, 10M vectors (10 train shards)
|
||||||
|
Makefile # make -j12 download to fetch all shards in parallel
|
||||||
|
build_base_db.py
|
||||||
|
base.db # (generated)
|
||||||
|
```
|
||||||
|
|
||||||
|
Every `base.db` has the same schema:
|
||||||
|
|
||||||
|
| Table | Columns | Description |
|
||||||
|
|-------|---------|-------------|
|
||||||
|
| `train` | `id INTEGER PRIMARY KEY, vector BLOB` | Indexed vectors (f32 blobs) |
|
||||||
|
| `query_vectors` | `id INTEGER PRIMARY KEY, vector BLOB` | Query vectors for KNN evaluation |
|
||||||
|
| `neighbors` | `query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT` | Ground-truth nearest neighbors |
|
||||||
|
|
||||||
|
To add a new dataset, create a directory with a `Makefile` that builds `base.db`
|
||||||
|
with the tables above. It will be available via `--dataset <dirname>` automatically.
|
||||||
|
|
||||||
|
### Building datasets
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Cohere 1M
|
||||||
|
cd cohere1m && make download && make && cd ..
|
||||||
|
|
||||||
|
# Cohere 10M (parallel download recommended — 10 train shards + test + neighbors)
|
||||||
|
cd cohere10m && make -j12 download && make && cd ..
|
||||||
|
```
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Built `dist/vec0` extension (run `make loadable` from repo root)
|
||||||
|
- Python 3.10+
|
||||||
|
- `uv`
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Build a dataset
|
||||||
|
cd cohere1m && make && cd ..
|
||||||
|
|
||||||
|
# 2. Quick smoke test (5k vectors)
|
||||||
|
make bench-smoke
|
||||||
|
|
||||||
|
# 3. Full benchmark at 10k
|
||||||
|
make bench-10k
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python bench.py --subset-size 10000 -k 10 -n 50 --dataset cohere1m \
|
||||||
|
"brute-float:type=vec0-flat,variant=float" \
|
||||||
|
"rescore-bit-os8:type=rescore,quantizer=bit,oversample=8"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Config format
|
||||||
|
|
||||||
|
`name:type=<index_type>,key=val,key=val`
|
||||||
|
|
||||||
|
| Index type | Keys |
|
||||||
|
|-----------|------|
|
||||||
|
| `vec0-flat` | `variant` (float/int8/bit), `oversample` |
|
||||||
|
| `rescore` | `quantizer` (bit/int8), `oversample` |
|
||||||
|
| `ivf` | `nlist`, `nprobe` |
|
||||||
|
| `diskann` | `R`, `L`, `quantizer` (binary/int8), `buffer_threshold` |
|
||||||
|
|
||||||
|
### Make targets
|
||||||
|
|
||||||
|
| Target | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| `make seed` | Download and build default dataset |
|
||||||
|
| `make bench-smoke` | Quick 5k test (3 configs) |
|
||||||
|
| `make bench-10k` | All configs at 10k vectors |
|
||||||
|
| `make bench-50k` | All configs at 50k vectors |
|
||||||
|
| `make bench-100k` | All configs at 100k vectors |
|
||||||
|
| `make bench-all` | 10k + 50k + 100k |
|
||||||
|
| `make bench-ivf` | Baselines + IVF across 10k/50k/100k |
|
||||||
|
| `make bench-diskann` | Baselines + DiskANN across 10k/50k/100k |
|
||||||
|
|
||||||
|
## Results DB
|
||||||
|
|
||||||
|
Each run writes to `runs/<dataset>/<subset_size>/results.db` (SQLite, WAL mode).
|
||||||
|
Progress is written continuously — query from another terminal to monitor:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sqlite3 runs/cohere1m/10000/results.db "SELECT run_id, config_name, status FROM runs"
|
||||||
|
```
|
||||||
|
|
||||||
|
See `results_schema.sql` for the full schema (tables: `runs`, `run_results`,
|
||||||
|
`insert_batches`, `queries`).
|
||||||
|
|
||||||
|
## Adding an index type
|
||||||
|
|
||||||
|
Add an entry to `INDEX_REGISTRY` in `bench.py` and append configs to
|
||||||
|
`ALL_CONFIGS` in the `Makefile`. See existing entries for the pattern.
|
||||||
3
benchmarks-ann/bench-delete/.gitignore
vendored
Normal file
3
benchmarks-ann/bench-delete/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
runs/
|
||||||
|
*.db
|
||||||
|
__pycache__/
|
||||||
41
benchmarks-ann/bench-delete/Makefile
Normal file
41
benchmarks-ann/bench-delete/Makefile
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Delete-benchmark driver and the extension it loads.
BENCH := python bench_delete.py
EXT := ../../dist/vec0

# Arguments shared by the standard (non-smoke) runs.
RUN_ARGS = --delete-pct $(DELETE_PCTS) -k 10 -n 50 --dataset cohere1m --ext $(EXT)

# --- Configs to test ---
FLAT = "flat:type=vec0-flat,variant=float"
RESCORE_BIT = "rescore-bit:type=rescore,quantizer=bit,oversample=8"
RESCORE_INT8 = "rescore-int8:type=rescore,quantizer=int8,oversample=8"
DISKANN_R48 = "diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
DISKANN_R72 = "diskann-R72:type=diskann,R=72,L=128,quantizer=binary"

ALL_CONFIGS = $(FLAT) $(RESCORE_BIT) $(RESCORE_INT8) $(DISKANN_R48) $(DISKANN_R72)

# Percentages of rows deleted in the standard runs (comma-separated).
DELETE_PCTS = 5,10,25,50,75,90

.PHONY: smoke bench-10k bench-50k bench-all report clean

# Quick smoke test (small dataset, few queries)
smoke:
	$(BENCH) --subset-size 5000 --delete-pct 10,50 -k 10 -n 20 \
		--dataset cohere1m --ext $(EXT) \
		$(FLAT) $(DISKANN_R48)

# Standard benchmarks. The stem of bench-<N>k is <N>, so "$*000" is the
# subset size (10 -> 10000, 50 -> 50000).
bench-10k bench-50k: bench-%k:
	$(BENCH) --subset-size $*000 $(RUN_ARGS) $(ALL_CONFIGS)

bench-all: bench-10k bench-50k

# Query saved results (prints a suggested sqlite3 command only)
report:
	@echo "Query results:"
	@echo " sqlite3 runs/cohere1m/10000/delete_results.db \\"
	@echo " \"SELECT config_name, delete_pct, recall, query_mean_ms, vacuum_size_mb FROM delete_runs ORDER BY config_name, delete_pct\""

clean:
	rm -rf runs/
|
||||||
69
benchmarks-ann/bench-delete/README.md
Normal file
69
benchmarks-ann/bench-delete/README.md
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
# bench-delete: Recall degradation after random deletion
|
||||||
|
|
||||||
|
Measures how KNN recall changes after deleting a random percentage of rows
|
||||||
|
from different index types (flat, rescore, DiskANN).
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure dataset exists
|
||||||
|
make -C ../datasets/cohere1m
|
||||||
|
|
||||||
|
# Ensure extension is built
|
||||||
|
make -C ../.. loadable
|
||||||
|
|
||||||
|
# Quick smoke test
|
||||||
|
make smoke
|
||||||
|
|
||||||
|
# Full benchmark at 10k vectors
|
||||||
|
make bench-10k
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python bench_delete.py --subset-size 10000 --delete-pct 10,25,50,75 \
|
||||||
|
"flat:type=vec0-flat,variant=float" \
|
||||||
|
"diskann-R72:type=diskann,R=72,L=128,quantizer=binary" \
|
||||||
|
"rescore-bit:type=rescore,quantizer=bit,oversample=8"
|
||||||
|
```
|
||||||
|
|
||||||
|
## What it measures
|
||||||
|
|
||||||
|
For each config and delete percentage:
|
||||||
|
|
||||||
|
| Metric | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| **recall** | KNN recall@k after deletion (ground truth recomputed over surviving rows) |
|
||||||
|
| **delta** | Recall change vs 0% baseline |
|
||||||
|
| **query latency** | Mean/median query time after deletion |
|
||||||
|
| **db_size_mb** | DB file size before VACUUM |
|
||||||
|
| **vacuum_size_mb** | DB file size after VACUUM (space reclaimed) |
|
||||||
|
| **delete_time_s** | Wall time for the DELETE operations |
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
1. Build index with N vectors (one copy per config)
|
||||||
|
2. Measure recall at k=10 (pre-delete baseline)
|
||||||
|
3. For each delete %:
|
||||||
|
- Copy the master DB
|
||||||
|
- Delete a random selection of rows (deterministic seed)
|
||||||
|
- Measure recall (ground truth recomputed over surviving rows only)
|
||||||
|
- VACUUM and measure size savings
|
||||||
|
4. Print comparison table
|
||||||
|
|
||||||
|
## Expected behavior
|
||||||
|
|
||||||
|
- **Flat index**: Recall should be 1.0 at all delete percentages (brute-force is always exact)
|
||||||
|
- **Rescore**: Recall should stay close to baseline (quantized scan + rescore is robust)
|
||||||
|
- **DiskANN**: Recall may degrade at high delete % due to graph fragmentation (dangling edges, broken connectivity)
|
||||||
|
|
||||||
|
## Results DB
|
||||||
|
|
||||||
|
Results are stored in `runs/<dataset>/<subset_size>/delete_results.db`:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
SELECT config_name, delete_pct, recall, vacuum_size_mb
|
||||||
|
FROM delete_runs
|
||||||
|
ORDER BY config_name, delete_pct;
|
||||||
|
```
|
||||||
593
benchmarks-ann/bench-delete/bench_delete.py
Normal file
593
benchmarks-ann/bench-delete/bench_delete.py
Normal file
|
|
@ -0,0 +1,593 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Benchmark: measure recall degradation after random row deletion.
|
||||||
|
|
||||||
|
Given a dataset and index config, this script:
|
||||||
|
1. Builds the index (flat + ANN)
|
||||||
|
2. Measures recall at k=10 (pre-delete baseline)
|
||||||
|
3. Deletes a random % of rows
|
||||||
|
4. Measures recall again (post-delete)
|
||||||
|
5. Records DB size before/after deletion, recall delta, timings
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python bench_delete.py --subset-size 10000 --delete-pct 25 \
|
||||||
|
"diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
|
||||||
|
|
||||||
|
# Multiple delete percentages in one run:
|
||||||
|
python bench_delete.py --subset-size 10000 --delete-pct 10,25,50,75 \
|
||||||
|
"diskann-R48:type=diskann,R=48,L=128,quantizer=binary"
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
import sqlite3
|
||||||
|
import statistics
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Directory layout, resolved relative to this script's location.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_BENCH_DIR = os.path.join(_SCRIPT_DIR, "..")
_ROOT_DIR = os.path.join(_BENCH_DIR, "..")

# Default extension path and the directory holding one subdirectory per dataset.
EXT_PATH = os.path.join(_ROOT_DIR, "dist", "vec0")
DATASETS_DIR = os.path.join(_BENCH_DIR, "datasets")

# (dataset name, vector dimensions) in registry order.
_DATASET_DIMENSIONS = (
    ("cohere1m", 768),
    ("cohere10m", 768),
    ("nyt", 256),
    ("nyt-768", 768),
    ("nyt-1024", 1024),
    ("nyt-384", 384),
)

# Known datasets: each maps to its base.db path and vector dimensionality.
DATASETS = {
    dataset: {
        "base_db": os.path.join(DATASETS_DIR, dataset, "base.db"),
        "dimensions": dimensions,
    }
    for dataset, dimensions in _DATASET_DIMENSIONS
}

# Width of the id range covered by each INSERT batch.
INSERT_BATCH_SIZE = 1000
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Timing helpers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def now_ns():
    """Return the current time from ``time.time_ns()`` (integer nanoseconds)."""
    stamp = time.time_ns()
    return stamp
|
||||||
|
|
||||||
|
def ns_to_s(ns):
    """Convert a nanosecond count to seconds using true division."""
    return ns / 10**9
|
||||||
|
|
||||||
|
def ns_to_ms(ns):
    """Convert a nanosecond count to milliseconds using true division."""
    return ns / 10**6
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Index registry (subset of bench.py — only types relevant to deletion)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def _vec0_flat_create(p):
|
||||||
|
dims = p["dimensions"]
|
||||||
|
variant = p.get("variant", "float")
|
||||||
|
col = f"embedding float[{dims}]"
|
||||||
|
if variant == "int8":
|
||||||
|
col = f"embedding int8[{dims}]"
|
||||||
|
elif variant == "bit":
|
||||||
|
col = f"embedding bit[{dims}]"
|
||||||
|
return f"CREATE VIRTUAL TABLE vec_items USING vec0(id INTEGER PRIMARY KEY, {col})"
|
||||||
|
|
||||||
|
def _rescore_create(p):
|
||||||
|
dims = p["dimensions"]
|
||||||
|
q = p.get("quantizer", "bit")
|
||||||
|
os_val = p.get("oversample", 8)
|
||||||
|
return (
|
||||||
|
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||||
|
f"id INTEGER PRIMARY KEY, "
|
||||||
|
f"embedding float[{dims}] indexed by rescore(quantizer={q}, oversample={os_val}))"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _diskann_create(p):
|
||||||
|
dims = p["dimensions"]
|
||||||
|
R = p.get("R", 72)
|
||||||
|
L = p.get("L", 128)
|
||||||
|
q = p.get("quantizer", "binary")
|
||||||
|
bt = p.get("buffer_threshold", 0)
|
||||||
|
sl_insert = p.get("search_list_size_insert", 0)
|
||||||
|
sl_search = p.get("search_list_size_search", 0)
|
||||||
|
parts = [
|
||||||
|
f"neighbor_quantizer={q}",
|
||||||
|
f"n_neighbors={R}",
|
||||||
|
f"buffer_threshold={bt}",
|
||||||
|
]
|
||||||
|
if sl_insert or sl_search:
|
||||||
|
# Per-path overrides — don't also set search_list_size
|
||||||
|
if sl_insert:
|
||||||
|
parts.append(f"search_list_size_insert={sl_insert}")
|
||||||
|
if sl_search:
|
||||||
|
parts.append(f"search_list_size_search={sl_search}")
|
||||||
|
else:
|
||||||
|
parts.append(f"search_list_size={L}")
|
||||||
|
opts = ", ".join(parts)
|
||||||
|
return (
|
||||||
|
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||||
|
f"id INTEGER PRIMARY KEY, "
|
||||||
|
f"embedding float[{dims}] indexed by diskann({opts}))"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ivf_create(p):
|
||||||
|
dims = p["dimensions"]
|
||||||
|
nlist = p.get("nlist", 128)
|
||||||
|
nprobe = p.get("nprobe", 16)
|
||||||
|
q = p.get("quantizer", "none")
|
||||||
|
os_val = p.get("oversample", 1)
|
||||||
|
parts = [f"nlist={nlist}", f"nprobe={nprobe}"]
|
||||||
|
if q != "none":
|
||||||
|
parts.append(f"quantizer={q}")
|
||||||
|
if os_val > 1:
|
||||||
|
parts.append(f"oversample={os_val}")
|
||||||
|
opts = ", ".join(parts)
|
||||||
|
return (
|
||||||
|
f"CREATE VIRTUAL TABLE vec_items USING vec0("
|
||||||
|
f"id INTEGER PRIMARY KEY, "
|
||||||
|
f"embedding float[{dims}] indexed by ivf({opts}))"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Index type name -> how to build it. Each entry provides:
#   defaults          default parameter values merged under user-supplied ones
#   create_table_sql  callable(params) returning the CREATE VIRTUAL TABLE SQL
#   post_insert_hook  callable(conn, params) run after inserts, or None
INDEX_REGISTRY = {
    "vec0-flat": {
        "defaults": {"variant": "float"},
        "create_table_sql": _vec0_flat_create,
        "post_insert_hook": None,
    },
    "rescore": {
        "defaults": {"quantizer": "bit", "oversample": 8},
        "create_table_sql": _rescore_create,
        "post_insert_hook": None,
    },
    "ivf": {
        "defaults": {
            "nlist": 128,
            "nprobe": 16,
            "quantizer": "none",
            "oversample": 1,
        },
        "create_table_sql": _ivf_create,
        # The lambda defers the lookup of _ivf_train until the hook is called.
        "post_insert_hook": lambda conn, params: _ivf_train(conn),
    },
    "diskann": {
        "defaults": {
            "R": 72,
            "L": 128,
            "quantizer": "binary",
            "buffer_threshold": 0,
        },
        "create_table_sql": _diskann_create,
        "post_insert_hook": None,
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
def _ivf_train(conn):
    """Run the 'compute-centroids' command on vec_items and commit.

    Returns the elapsed time in seconds, covering both the statement and
    the commit.
    """
    started = now_ns()
    conn.execute("INSERT INTO vec_items(vec_items) VALUES ('compute-centroids')")
    conn.commit()
    elapsed_ns = now_ns() - started
    return ns_to_s(elapsed_ns)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Config parsing (same format as bench.py)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Parameter names whose values are converted to int while parsing a config.
INT_KEYS = {"R", "L", "oversample", "nlist", "nprobe", "buffer_threshold",
            "search_list_size_insert", "search_list_size_search"}


def parse_config(spec):
    """Parse a ``name:type=<index_type>,key=val,...`` config string.

    Args:
        spec: Config string. Everything before the first ":" is the config
            name; the rest is a comma-separated list of ``key=val`` entries,
            one of which must be ``type``.

    Returns:
        ``(name, params)`` where ``params`` is the registry defaults for the
        index type overlaid with the parsed values, plus an ``"index_type"``
        key. Values for keys in ``INT_KEYS`` are ints; all others are strings.

    Raises:
        ValueError: if the spec has no ":", an entry is not ``key=val``, an
            integer key has a non-integer value, or the index type is missing
            or not in ``INDEX_REGISTRY``.
    """
    if ":" not in spec:
        raise ValueError(f"Config must be 'name:key=val,...': {spec}")
    name, rest = spec.split(":", 1)

    params = {}
    for kv in rest.split(","):
        if not kv.strip():
            # Tolerate a trailing or doubled comma instead of failing on it.
            continue
        if "=" not in kv:
            raise ValueError(f"Config entry must be 'key=val', got {kv!r} in: {spec}")
        k, v = kv.split("=", 1)
        k = k.strip()
        v = v.strip()
        if k in INT_KEYS:
            try:
                v = int(v)
            except ValueError:
                raise ValueError(
                    f"Config key {k!r} expects an integer, got {v!r} in: {spec}"
                ) from None
        params[k] = v

    index_type = params.pop("type", None)
    if not index_type or index_type not in INDEX_REGISTRY:
        raise ValueError(f"Unknown index type: {index_type}")
    params["index_type"] = index_type

    # Registry defaults first, so explicitly supplied values win.
    merged = dict(INDEX_REGISTRY[index_type]["defaults"])
    merged.update(params)
    return name, merged
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# DB helpers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def create_bench_db(db_path, ext_path, base_db, page_size=4096):
    """Create a fresh benchmark database and return an open connection to it.

    Any existing file at ``db_path`` is removed first, together with its
    ``-wal`` and ``-shm`` siblings. The new database gets the requested page
    size, WAL journaling, the extension at ``ext_path`` loaded, and
    ``base_db`` attached under the schema name ``base``.

    Args:
        db_path: Path of the database file to (re)create.
        ext_path: Path of the loadable extension to load.
        base_db: Path of the dataset database to attach as ``base``.
        page_size: Value for ``PRAGMA page_size``.

    Returns:
        The open ``sqlite3.Connection``. The caller owns it and must close it.

    Raises:
        Whatever ``sqlite3`` raises during setup; the connection is closed
        before the exception propagates.
    """
    # Remove WAL/SHM leftovers too: db_path is opened in WAL mode below, and a
    # stale -wal file from an earlier run must not sit next to the new file.
    for suffix in ("", "-wal", "-shm"):
        stale = db_path + suffix
        if os.path.exists(stale):
            os.remove(stale)

    conn = sqlite3.connect(db_path)
    try:
        # page_size must be set before the first write / journal-mode switch.
        conn.execute(f"PRAGMA page_size={page_size}")
        conn.execute("PRAGMA journal_mode=WAL")
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)
        # Bind the path instead of interpolating it, so quotes in the path
        # cannot break (or alter) the statement.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))
    except Exception:
        conn.close()
        raise
    return conn
|
||||||
|
|
||||||
|
|
||||||
|
def load_query_vectors(base_db, n):
    """Load up to ``n`` query vectors from the dataset database.

    Args:
        base_db: Path of the database containing a ``query_vectors`` table
            with ``id`` and ``vector`` columns.
        n: Maximum number of rows to return.

    Returns:
        A list of ``(id, vector)`` rows ordered by ``id``.
    """
    conn = sqlite3.connect(base_db)
    try:
        # ORDER BY makes the selected subset deterministic; LIMIT alone
        # leaves the row order up to the query planner.
        return conn.execute(
            "SELECT id, vector FROM query_vectors ORDER BY id LIMIT ?", (n,)
        ).fetchall()
    finally:
        # Always release the connection, even if the query fails.
        conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def insert_loop(conn, subset_size, label, start_from=0):
    """Copy rows ``start_from <= id < subset_size`` from ``base.train`` into ``vec_items``.

    Rows are inserted in id ranges of ``INSERT_BATCH_SIZE`` with a commit after
    every range. Progress is printed whenever the running count is a multiple
    of 5000 and once more when the final row has been inserted.
    """
    statement = (
        "INSERT INTO vec_items(id, embedding) "
        "SELECT id, vector FROM base.train "
        "WHERE id >= :lo AND id < :hi"
    )
    expected = subset_size - start_from
    inserted = 0
    for lower in range(start_from, subset_size, INSERT_BATCH_SIZE):
        upper = min(lower + INSERT_BATCH_SIZE, subset_size)
        conn.execute(statement, {"lo": lower, "hi": upper})
        conn.commit()
        inserted += upper - lower
        at_milestone = inserted % 5000 == 0
        if at_milestone or inserted == expected:
            print(f" [{label}] inserted {inserted + start_from}/{subset_size}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Recall measurement
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def measure_recall(conn, base_db, query_vectors, subset_size, k, alive_ids=None):
    """Measure KNN recall and query latency of ``vec_items``.

    For every query the KNN search on ``vec_items`` is timed, then compared
    against a brute-force L2 ground truth computed over ``base.train`` rows
    with ``id < subset_size``.

    Args:
        conn: Connection with the extension loaded and the dataset attached as ``base``.
        base_db: Unused here; kept for interface compatibility.
        query_vectors: Iterable of ``(id, vector)`` pairs.
        subset_size: Only ``base.train`` rows with ``id < subset_size`` count as ground truth.
        k: Number of neighbours requested.
        alive_ids: Optional set of surviving IDs. When given, the ground truth
            is restricted to those IDs (to match the post-delete state).

    Returns:
        Dict with ``recall`` (mean, 4 decimals), ``mean_ms`` and ``median_ms``
        (2 decimals). All three are 0.0 when there are no queries.
    """
    recalls = []
    times_ms = []

    for _qid, query in query_vectors:
        # Only the index query itself is timed, not the ground-truth scan.
        t0 = now_ns()
        results = conn.execute(
            "SELECT id, distance FROM vec_items "
            "WHERE embedding MATCH :query AND k = :k",
            {"query": query, "k": k},
        ).fetchall()
        t1 = now_ns()
        times_ms.append(ns_to_ms(t1 - t0))

        result_ids = set(r[0] for r in results)
        gt_ids = _ground_truth_ids(conn, query, subset_size, k, alive_ids)

        if gt_ids:
            recalls.append(len(result_ids & gt_ids) / len(gt_ids))
        else:
            recalls.append(0.0)

    return {
        "recall": round(statistics.mean(recalls), 4) if recalls else 0.0,
        "mean_ms": round(statistics.mean(times_ms), 2) if times_ms else 0.0,
        "median_ms": round(statistics.median(times_ms), 2) if times_ms else 0.0,
    }


def _ground_truth_ids(conn, query, subset_size, k, alive_ids):
    """Return the IDs of the ``k`` exact L2 nearest ``base.train`` rows, optionally limited to ``alive_ids``."""
    gt_sql = (
        "SELECT id FROM ("
        " SELECT id, vec_distance_l2(vector, :query) as dist "
        " FROM base.train WHERE id < :n ORDER BY dist LIMIT :lim"
        ")"
    )
    if alive_ids is None:
        rows = conn.execute(
            gt_sql, {"query": query, "lim": k, "n": subset_size}
        ).fetchall()
        return set(r[0] for r in rows)

    # Start with k*5 candidates and widen the window until k survivors are
    # found or the candidate list is exhausted; a fixed window can come up
    # short when a large share of rows was deleted.
    limit = k * 5
    while True:
        rows = conn.execute(
            gt_sql, {"query": query, "lim": limit, "n": subset_size}
        ).fetchall()
        alive = [r[0] for r in rows if r[0] in alive_ids]
        if len(alive) >= k or len(rows) < limit:
            return set(alive[:k])
        limit *= 4
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Delete benchmark core
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def run_delete_benchmark(name, params, base_db, ext_path, subset_size, dims,
                         delete_pcts, k, n_queries, out_dir, seed_val):
    """Build one index, then measure recall/size after deleting random rows.

    The index is built once into a master database. A baseline (0% deleted)
    is measured on it, and for every percentage in ``delete_pcts`` the master
    is copied, random rows are deleted from the copy, recall is re-measured
    and the copy is VACUUMed to record the reclaimed size.

    Args:
        name: Config name; used in file names and result rows.
        params: Index parameters (must contain ``"index_type"``). Not modified.
        base_db: Path of the dataset database.
        ext_path: Path of the loadable SQLite extension.
        subset_size: Number of vectors to index (IDs ``0..subset_size-1``).
        dims: Vector dimensionality, stored into the params as ``"dimensions"``.
        delete_pcts: Iterable of integer percentages to delete.
        k: Number of neighbours per query.
        n_queries: Number of query vectors to load.
        out_dir: Directory for the generated databases (created if missing).
        seed_val: Base seed; each percentage uses ``seed_val + pct``.

    Returns:
        List of result dicts, the baseline first, then one per percentage in
        ascending order.
    """
    # Work on a private copy so the caller's dict is left untouched.
    params = dict(params)
    params["dimensions"] = dims
    reg = INDEX_REGISTRY[params["index_type"]]
    create_sql = reg["create_table_sql"](params)

    results = []

    # Build once, copy for each delete %
    print(f"\n{'='*60}")
    print(f"Config: {name} (type={params['index_type']})")
    print(f"{'='*60}")

    os.makedirs(out_dir, exist_ok=True)
    master_db_path = os.path.join(out_dir, f"{name}.{subset_size}.db")
    print(f" Building index ({subset_size} vectors)...")
    build_t0 = now_ns()
    conn = create_bench_db(master_db_path, ext_path, base_db)
    try:
        conn.execute(create_sql)
        insert_loop(conn, subset_size, name)
        hook = reg.get("post_insert_hook")
        if hook:
            print(" Training...")
            hook(conn, params)
    finally:
        conn.close()
    build_time_s = ns_to_s(now_ns() - build_t0)
    master_size = os.path.getsize(master_db_path)
    print(f" Built in {build_time_s:.1f}s ({master_size / (1024*1024):.1f} MB)")

    # Load query vectors once
    query_vectors = load_query_vectors(base_db, n_queries)

    # Measure pre-delete baseline on the master copy
    print(f"\n --- 0% deleted (baseline) ---")
    conn = _open_with_base(master_db_path, ext_path, base_db)
    try:
        baseline = measure_recall(conn, base_db, query_vectors, subset_size, k)
    finally:
        conn.close()
    print(f" recall={baseline['recall']:.4f} "
          f"query={baseline['mean_ms']:.2f}ms")

    results.append({
        "name": name,
        "index_type": params["index_type"],
        "subset_size": subset_size,
        "delete_pct": 0,
        "n_deleted": 0,
        "n_remaining": subset_size,
        "recall": baseline["recall"],
        "query_mean_ms": baseline["mean_ms"],
        "query_median_ms": baseline["median_ms"],
        "db_size_mb": round(master_size / (1024 * 1024), 2),
        "build_time_s": round(build_time_s, 1),
        "delete_time_s": 0.0,
        "vacuum_size_mb": round(master_size / (1024 * 1024), 2),
    })

    # All IDs in the dataset
    all_ids = list(range(subset_size))

    for pct in sorted(delete_pcts):
        n_delete = int(subset_size * pct / 100)
        print(f"\n --- {pct}% deleted ({n_delete} rows) ---")

        # Copy master DB and work on the copy
        copy_path = os.path.join(out_dir, f"{name}.{subset_size}.del{pct}.db")
        shutil.copy2(master_db_path, copy_path)
        # Also copy WAL/SHM if they exist
        for suffix in ["-wal", "-shm"]:
            src = master_db_path + suffix
            if os.path.exists(src):
                shutil.copy2(src, copy_path + suffix)

        # Pick random IDs to delete (deterministic per pct)
        rng = random.Random(seed_val + pct)
        to_delete = set(rng.sample(all_ids, n_delete))
        alive_ids = set(all_ids) - to_delete

        conn = _open_with_base(copy_path, ext_path, base_db)
        try:
            # Delete in batches of 500 IDs, committing after each batch.
            delete_t0 = now_ns()
            doomed = list(to_delete)
            for start in range(0, len(doomed), 500):
                batch = doomed[start:start + 500]
                placeholders = ",".join("?" for _ in batch)
                conn.execute(
                    f"DELETE FROM vec_items WHERE id IN ({placeholders})",
                    batch,
                )
                conn.commit()
            delete_time_s = ns_to_s(now_ns() - delete_t0)

            remaining = conn.execute("SELECT count(*) FROM vec_items").fetchone()[0]
            pre_vacuum_size = os.path.getsize(copy_path)
            print(f" deleted {n_delete} rows in {delete_time_s:.2f}s "
                  f"({remaining} remaining)")

            # Measure post-delete recall
            post = measure_recall(conn, base_db, query_vectors, subset_size, k,
                                  alive_ids=alive_ids)
            print(f" recall={post['recall']:.4f} "
                  f"(delta={post['recall'] - baseline['recall']:+.4f}) "
                  f"query={post['mean_ms']:.2f}ms")
        finally:
            conn.close()

        # VACUUM and measure size savings — reopen without the attached base
        vconn = sqlite3.connect(copy_path)
        try:
            vconn.execute("VACUUM")
        finally:
            vconn.close()
        post_vacuum_size = os.path.getsize(copy_path)
        saved_mb = (pre_vacuum_size - post_vacuum_size) / (1024 * 1024)
        print(f" size: {pre_vacuum_size/(1024*1024):.1f} MB -> "
              f"{post_vacuum_size/(1024*1024):.1f} MB after VACUUM "
              f"(saved {saved_mb:.1f} MB)")

        results.append({
            "name": name,
            "index_type": params["index_type"],
            "subset_size": subset_size,
            "delete_pct": pct,
            "n_deleted": n_delete,
            "n_remaining": remaining,
            "recall": post["recall"],
            "query_mean_ms": post["mean_ms"],
            "query_median_ms": post["median_ms"],
            "db_size_mb": round(pre_vacuum_size / (1024 * 1024), 2),
            "build_time_s": round(build_time_s, 1),
            "delete_time_s": round(delete_time_s, 2),
            "vacuum_size_mb": round(post_vacuum_size / (1024 * 1024), 2),
        })

    return results


def _open_with_base(db_path, ext_path, base_db):
    """Open ``db_path`` with the extension loaded and ``base_db`` attached as ``base``."""
    conn = sqlite3.connect(db_path)
    try:
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)
        # Bound parameter: a quote in the path cannot break the statement.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))
    except Exception:
        conn.close()
        raise
    return conn
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Results DB
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Schema of the results database; one row per (config, delete percentage) run.
RESULTS_SCHEMA = """\
CREATE TABLE IF NOT EXISTS delete_runs (
run_id INTEGER PRIMARY KEY,
config_name TEXT NOT NULL,
index_type TEXT NOT NULL,
params TEXT,
dataset TEXT NOT NULL,
subset_size INTEGER NOT NULL,
delete_pct INTEGER NOT NULL,
n_deleted INTEGER NOT NULL,
n_remaining INTEGER NOT NULL,
k INTEGER NOT NULL,
n_queries INTEGER NOT NULL,
seed INTEGER NOT NULL,
recall REAL,
query_mean_ms REAL,
query_median_ms REAL,
db_size_mb REAL,
vacuum_size_mb REAL,
build_time_s REAL,
delete_time_s REAL,
created_at TEXT DEFAULT (datetime('now'))
);
"""


def save_results(results, out_dir, dataset, subset_size, params_json, k, n_queries, seed_val):
    """Append benchmark result rows to ``<out_dir>/delete_results.db``.

    Args:
        results: Iterable of result dicts as produced by the delete benchmark.
        out_dir: Directory of the results database (created if missing).
        dataset: Dataset name stored with every row.
        subset_size: Unused; each row carries its own ``subset_size``.
        params_json: JSON text of the config parameters, stored verbatim.
        k: KNN ``k`` stored with every row.
        n_queries: Number of queries stored with every row.
        seed_val: Seed stored with every row.

    Returns:
        Path of the results database.
    """
    os.makedirs(out_dir, exist_ok=True)
    db_path = os.path.join(out_dir, "delete_results.db")
    rows = [
        (
            r["name"], r["index_type"], params_json, dataset, r["subset_size"],
            r["delete_pct"], r["n_deleted"], r["n_remaining"], k, n_queries, seed_val,
            r["recall"], r["query_mean_ms"], r["query_median_ms"],
            r["db_size_mb"], r["vacuum_size_mb"], r["build_time_s"], r["delete_time_s"],
        )
        for r in results
    ]
    db = sqlite3.connect(db_path)
    try:
        db.execute("PRAGMA journal_mode=WAL")
        db.executescript(RESULTS_SCHEMA)
        db.executemany(
            "INSERT INTO delete_runs "
            "(config_name, index_type, params, dataset, subset_size, "
            " delete_pct, n_deleted, n_remaining, k, n_queries, seed, "
            " recall, query_mean_ms, query_median_ms, "
            " db_size_mb, vacuum_size_mb, build_time_s, delete_time_s) "
            "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)",
            rows,
        )
        db.commit()
    finally:
        # Release the handle even if the insert fails part-way.
        db.close()
    return db_path
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Reporting
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def print_report(all_results):
    """Print a fixed-width table of results, grouped by config name.

    The first row of each group supplies the reference recall; the delta
    column shows the difference to it for rows with ``delete_pct > 0`` and
    ``-`` otherwise. A blank line follows every group.
    """
    heading = (
        f"\n{'name':>22} {'del%':>5} {'deleted':>8} {'remain':>8} "
        f"{'recall':>7} {'delta':>7} {'qry(ms)':>8} "
        f"{'size(MB)':>9} {'vacuumed':>9} {'del(s)':>7}"
    )
    print(heading)
    print("-" * 110)

    # Bucket rows per config, keeping first-seen order.
    grouped = {}
    for entry in all_results:
        key = entry["name"]
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(entry)

    for bucket in grouped.values():
        reference = bucket[0]["recall"]  # the 0% row is expected to come first
        for entry in bucket:
            change = entry["recall"] - reference
            if entry["delete_pct"] > 0:
                delta_str = f"{change:+.4f}"
            else:
                delta_str = "-"
            line = (
                f"{entry['name']:>22} {entry['delete_pct']:>4}% "
                f"{entry['n_deleted']:>8} {entry['n_remaining']:>8} "
                f"{entry['recall']:>7.4f} {delta_str:>7} {entry['query_mean_ms']:>8.2f} "
                f"{entry['db_size_mb']:>9.1f} {entry['vacuum_size_mb']:>9.1f} "
                f"{entry['delete_time_s']:>7.2f}"
            )
            print(line)
        print()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Main
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point for the delete benchmark.

    Parses arguments, validates them (including every config spec) before any
    benchmark starts, runs each config, stores its rows in the results
    database and prints a combined report.

    Returns:
        Process exit code: 0 on success, 1 on invalid input or a missing dataset.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark recall degradation after random row deletion",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("configs", nargs="+",
                        help="config specs (name:type=X,key=val,...)")
    parser.add_argument("--subset-size", type=int, default=10000,
                        help="number of vectors to build (default: 10000)")
    parser.add_argument("--delete-pct", type=str, default="10,25,50",
                        help="comma-separated delete percentages (default: 10,25,50)")
    parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
    parser.add_argument("-n", type=int, default=50,
                        help="number of queries (default 50)")
    parser.add_argument("--dataset", default="cohere1m",
                        choices=list(DATASETS.keys()))
    parser.add_argument("--ext", default=EXT_PATH)
    parser.add_argument("-o", "--out-dir",
                        default=os.path.join(_SCRIPT_DIR, "runs"))
    parser.add_argument("--seed", type=int, default=42,
                        help="random seed for delete selection (default: 42)")
    args = parser.parse_args()

    ds = DATASETS[args.dataset]
    base_db = ds["base_db"]
    dims = ds["dimensions"]
    if not os.path.exists(base_db):
        print(f"Error: dataset not found at {base_db}")
        print(f"Run: make -C {os.path.dirname(base_db)}")
        return 1

    if args.subset_size <= 0 or args.k <= 0 or args.n <= 0:
        print("Error: --subset-size, -k and -n must be positive integers")
        return 1

    try:
        delete_pcts = [
            int(x.strip()) for x in args.delete_pct.split(",") if x.strip()
        ]
    except ValueError:
        print(f"Error: --delete-pct must be comma-separated integers, "
              f"got {args.delete_pct!r}")
        return 1
    if not delete_pcts:
        print("Error: --delete-pct must list at least one percentage")
        return 1
    for p in delete_pcts:
        if not 0 < p < 100:
            print(f"Error: delete percentage must be 1-99, got {p}")
            return 1
    # Drop duplicates so the same percentage is not benchmarked twice.
    delete_pcts = sorted(set(delete_pcts))

    # Parse every spec up front: a typo in a later spec should not surface
    # only after the earlier (slow) benchmarks have already run.
    parsed = []
    for spec in args.configs:
        try:
            parsed.append(parse_config(spec))
        except ValueError as exc:
            print(f"Error: {exc}")
            return 1

    out_dir = os.path.join(args.out_dir, args.dataset, str(args.subset_size))
    os.makedirs(out_dir, exist_ok=True)

    all_results = []
    for name, params in parsed:
        params_json = json.dumps(params)
        results = run_delete_benchmark(
            name, params, base_db, args.ext, args.subset_size, dims,
            delete_pcts, args.k, args.n, out_dir, args.seed,
        )
        all_results.extend(results)

        # Persist per config so finished runs survive a later failure.
        save_results(results, out_dir, args.dataset, args.subset_size,
                     params_json, args.k, args.n, args.seed)

    print_report(all_results)

    results_path = os.path.join(out_dir, "delete_results.db")
    print(f"\nResults saved to: {results_path}")
    print(f"Query: sqlite3 {results_path} "
          f"\"SELECT config_name, delete_pct, recall, vacuum_size_mb "
          f"FROM delete_runs ORDER BY config_name, delete_pct\"")
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
124
benchmarks-ann/bench-delete/test_smoke.py
Normal file
124
benchmarks-ann/bench-delete/test_smoke.py
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Quick self-contained smoke test using a synthetic dataset.
|
||||||
|
Creates a tiny base.db in a temp dir, runs the delete benchmark, verifies output.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
_ROOT_DIR = os.path.join(_SCRIPT_DIR, "..", "..")
|
||||||
|
EXT_PATH = os.path.join(_ROOT_DIR, "dist", "vec0")
|
||||||
|
|
||||||
|
DIMS = 8
|
||||||
|
N_TRAIN = 200
|
||||||
|
N_QUERIES = 10
|
||||||
|
K_NEIGHBORS = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _f32(vals):
|
||||||
|
return struct.pack(f"{len(vals)}f", *vals)
|
||||||
|
|
||||||
|
|
||||||
|
def make_synthetic_base_db(path):
    """Write a small base.db at ``path`` holding ``train`` and ``query_vectors`` tables.

    Both tables are filled with Gaussian random vectors of ``DIMS`` components
    drawn from one generator seeded with 123: first ``N_TRAIN`` train rows,
    then ``N_QUERIES`` query rows, each numbered from 0.
    """
    rng = random.Random(123)

    def random_rows(count):
        # Lazily yields (id, blob) pairs; consumption order fixes the RNG sequence.
        for row_id in range(count):
            components = [rng.gauss(0, 1) for _ in range(DIMS)]
            yield row_id, _f32(components)

    db = sqlite3.connect(path)
    db.execute("CREATE TABLE train(id INTEGER PRIMARY KEY, vector BLOB)")
    db.execute("CREATE TABLE query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")

    db.executemany("INSERT INTO train VALUES (?, ?)", random_rows(N_TRAIN))
    db.executemany("INSERT INTO query_vectors VALUES (?, ?)", random_rows(N_QUERIES))

    db.commit()
    db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the delete benchmark against a synthetic dataset for three index types.

    A temporary base.db is generated, then the flat, DiskANN and rescore
    configs are each run with 25% and 50% deletion. Flat recall must be
    exactly 1.0 everywhere and the DiskANN baseline recall must be non-zero.

    Returns:
        0 when every check passed (a failed check raises ``AssertionError``).
    """
    # Sibling module; imported here rather than at file level.
    import bench_delete

    def run_case(spec, base_db, out_dir):
        # Parse one config spec, run the benchmark on it and print its report.
        name, params = bench_delete.parse_config(spec)
        params["dimensions"] = DIMS
        results = bench_delete.run_delete_benchmark(
            name, params, base_db, EXT_PATH,
            subset_size=N_TRAIN, dims=DIMS,
            delete_pcts=[25, 50], k=K_NEIGHBORS, n_queries=N_QUERIES,
            out_dir=out_dir, seed_val=42,
        )
        bench_delete.print_report(results)
        return results

    with tempfile.TemporaryDirectory() as tmpdir:
        base_db = os.path.join(tmpdir, "base.db")
        make_synthetic_base_db(base_db)

        # Patch DATASETS to use our synthetic DB
        bench_delete.DATASETS["synthetic"] = {
            "base_db": base_db,
            "dimensions": DIMS,
        }

        out_dir = os.path.join(tmpdir, "runs")

        # Test flat index
        print("=== Testing flat index ===")
        results = run_case("flat:type=vec0-flat,variant=float", base_db, out_dir)

        # Flat recall should be 1.0 at all delete %
        for r in results:
            assert r["recall"] == 1.0, \
                f"Flat recall should be 1.0, got {r['recall']} at {r['delete_pct']}%"
        print("\n PASS: flat recall is 1.0 at all delete percentages\n")

        # Test DiskANN
        print("=== Testing DiskANN ===")
        results2 = run_case(
            "diskann:type=diskann,R=8,L=32,quantizer=binary", base_db, out_dir
        )

        # DiskANN baseline (0%) should have decent recall
        baseline = results2[0]
        assert baseline["recall"] > 0.0, \
            "DiskANN baseline recall is zero"
        print(f" PASS: DiskANN baseline recall={baseline['recall']}")

        # Test rescore
        print("\n=== Testing rescore ===")
        results3 = run_case(
            "rescore:type=rescore,quantizer=bit,oversample=4", base_db, out_dir
        )
        print(f" PASS: rescore baseline recall={results3[0]['recall']}")

        print("\n ALL SMOKE TESTS PASSED")
        return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
1350
benchmarks-ann/bench.py
Normal file
1350
benchmarks-ann/bench.py
Normal file
File diff suppressed because it is too large
Load diff
27
benchmarks-ann/datasets/cohere10m/Makefile
Normal file
27
benchmarks-ann/datasets/cohere10m/Makefile
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Remove a half-written target (truncated download, partial base.db) when its
# recipe fails, so it is never mistaken for an up-to-date file.
.DELETE_ON_ERROR:

BASE_URL = https://assets.zilliz.com/benchmark/cohere_large_10m

# Simple assignment (:=): printf runs once at parse time, not on every reference.
TRAIN_PARQUETS := $(shell printf 'train-%02d-of-10.parquet ' 0 1 2 3 4 5 6 7 8 9)
OTHER_PARQUETS = test.parquet neighbors.parquet
PARQUETS = $(TRAIN_PARQUETS) $(OTHER_PARQUETS)

.PHONY: all download clean

all: base.db

# Use: make -j12 download
download: $(PARQUETS)

# curl -f: fail on HTTP errors instead of saving the error page as a .parquet file.
train-%-of-10.parquet:
	curl -fL -o $@ $(BASE_URL)/$@

test.parquet:
	curl -fL -o $@ $(BASE_URL)/test.parquet

neighbors.parquet:
	curl -fL -o $@ $(BASE_URL)/neighbors.parquet

base.db: $(PARQUETS) build_base_db.py
	uv run --with pandas --with pyarrow python build_base_db.py

clean:
	rm -f base.db
|
||||||
134
benchmarks-ann/datasets/cohere10m/build_base_db.py
Normal file
134
benchmarks-ann/datasets/cohere10m/build_base_db.py
Normal file
|
|
@ -0,0 +1,134 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Build base.db from downloaded parquet files (10M dataset, 10 train shards).
|
||||||
|
|
||||||
|
Reads train-00-of-10.parquet .. train-09-of-10.parquet, test.parquet,
|
||||||
|
neighbors.parquet and creates a SQLite database with tables:
|
||||||
|
train, query_vectors, neighbors.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run --with pandas --with pyarrow python build_base_db.py
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
TRAIN_SHARDS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def float_list_to_blob(floats):
    """Serialize a sequence of floats as consecutive little-endian float32 values."""
    layout = "<%df" % len(floats)
    return struct.pack(layout, *floats)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Build base.db next to this script from the downloaded parquet files.

    Creates the tables ``query_vectors`` (from test.parquet), ``neighbors``
    (from neighbors.parquet, one row per query/rank) and ``train`` (from the
    ``TRAIN_SHARDS`` train shards). An existing base.db is replaced. Exits
    with status 1 if any input file is missing.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    db_path = os.path.join(script_dir, "base.db")

    train_paths = [
        os.path.join(script_dir, f"train-{i:02d}-of-{TRAIN_SHARDS}.parquet")
        for i in range(TRAIN_SHARDS)
    ]
    test_path = os.path.join(script_dir, "test.parquet")
    neighbors_path = os.path.join(script_dir, "neighbors.parquet")

    for p in train_paths + [test_path, neighbors_path]:
        if not os.path.exists(p):
            print(f"ERROR: {p} not found. Run 'make download' first.")
            sys.exit(1)

    # Remove the old database together with any WAL/SHM sidecars.
    for stale in (db_path, db_path + "-wal", db_path + "-shm"):
        if os.path.exists(stale):
            os.remove(stale)

    conn = sqlite3.connect(db_path)
    try:
        # page_size only takes effect before the database is initialised, so it
        # has to precede the switch to WAL.
        conn.execute("PRAGMA page_size=4096")
        conn.execute("PRAGMA journal_mode=WAL")

        # --- query_vectors (from test.parquet) ---
        print("Loading test.parquet (query vectors)...")
        t0 = time.perf_counter()
        df_test = pd.read_parquet(test_path)
        conn.execute(
            "CREATE TABLE query_vectors (id INTEGER PRIMARY KEY, vector BLOB)"
        )
        # Column-wise zip avoids building a Series per row as iterrows() does.
        rows = [
            (int(vid), float_list_to_blob(emb))
            for vid, emb in zip(df_test["id"], df_test["emb"])
        ]
        conn.executemany("INSERT INTO query_vectors (id, vector) VALUES (?, ?)", rows)
        conn.commit()
        print(f" {len(rows)} query vectors in {time.perf_counter() - t0:.1f}s")

        # --- neighbors (from neighbors.parquet) ---
        print("Loading neighbors.parquet...")
        t0 = time.perf_counter()
        df_neighbors = pd.read_parquet(neighbors_path)
        conn.execute(
            "CREATE TABLE neighbors ("
            " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
            " UNIQUE(query_vector_id, rank))"
        )
        rows = []
        for qid, nids in zip(df_neighbors["id"], df_neighbors["neighbors_id"]):
            qid = int(qid)
            # The column may hold a sequence or its JSON text.
            if isinstance(nids, str):
                nids = json.loads(nids)
            for rank, nid in enumerate(nids):
                rows.append((qid, rank, str(int(nid))))
        conn.executemany(
            "INSERT INTO neighbors (query_vector_id, rank, neighbors_id) VALUES (?, ?, ?)",
            rows,
        )
        conn.commit()
        print(f" {len(rows)} neighbor rows in {time.perf_counter() - t0:.1f}s")

        # --- train (from 10 shard parquets) ---
        print(f"Loading {TRAIN_SHARDS} train shards (10M vectors, this will take a while)...")
        conn.execute(
            "CREATE TABLE train (id INTEGER PRIMARY KEY, vector BLOB)"
        )

        global_t0 = time.perf_counter()
        total_inserted = 0
        batch_size = 10000

        for shard_idx, train_path in enumerate(train_paths):
            print(f" Shard {shard_idx + 1}/{TRAIN_SHARDS}: {os.path.basename(train_path)}")
            t0 = time.perf_counter()
            df = pd.read_parquet(train_path)
            shard_len = len(df)

            for start in range(0, shard_len, batch_size):
                chunk = df.iloc[start : start + batch_size]
                rows = [
                    (int(vid), float_list_to_blob(emb))
                    for vid, emb in zip(chunk["id"], chunk["emb"])
                ]
                conn.executemany("INSERT INTO train (id, vector) VALUES (?, ?)", rows)
                conn.commit()

                total_inserted += len(rows)
                # Report roughly every 100k rows.
                if total_inserted % 100000 < batch_size:
                    elapsed = time.perf_counter() - global_t0
                    rate = total_inserted / elapsed if elapsed > 0 else 0
                    print(
                        f" {total_inserted:>10} {elapsed:.0f}s {rate:.0f} rows/s",
                        flush=True,
                    )

            shard_elapsed = time.perf_counter() - t0
            print(f" shard done: {shard_len} rows in {shard_elapsed:.1f}s")

        elapsed = time.perf_counter() - global_t0
        print(f" {total_inserted} train vectors in {elapsed:.1f}s")
    finally:
        conn.close()

    size_mb = os.path.getsize(db_path) / (1024 * 1024)
    print(f"\nDone: {db_path} ({size_mb:.0f} MB)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
2
benchmarks-ann/datasets/cohere1m/.gitignore
vendored
Normal file
2
benchmarks-ann/datasets/cohere1m/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
*.parquet
|
||||||
|
base.db
|
||||||
24
benchmarks-ann/datasets/cohere1m/Makefile
Normal file
24
benchmarks-ann/datasets/cohere1m/Makefile
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Remove a half-written target (truncated download, partial base.db) when its
# recipe fails, so it is never mistaken for an up-to-date file.
.DELETE_ON_ERROR:

BASE_URL = https://assets.zilliz.com/benchmark/cohere_medium_1m

PARQUETS = train.parquet test.parquet neighbors.parquet

# base.db is a real file and must not be phony, otherwise it is rebuilt on
# every invocation even when nothing changed.
.PHONY: all download clean

all: base.db

download: $(PARQUETS)

# curl -f: fail on HTTP errors instead of saving the error page as a .parquet file.
train.parquet:
	curl -fL -o $@ $(BASE_URL)/train.parquet

test.parquet:
	curl -fL -o $@ $(BASE_URL)/test.parquet

neighbors.parquet:
	curl -fL -o $@ $(BASE_URL)/neighbors.parquet

base.db: $(PARQUETS) build_base_db.py
	uv run --with pandas --with pyarrow python build_base_db.py

clean:
	rm -f base.db
|
||||||
121
benchmarks-ann/datasets/cohere1m/build_base_db.py
Normal file
121
benchmarks-ann/datasets/cohere1m/build_base_db.py
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Build base.db from downloaded parquet files.
|
||||||
|
|
||||||
|
Reads train.parquet, test.parquet, neighbors.parquet and creates a SQLite
|
||||||
|
database with tables: train, query_vectors, neighbors.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run --with pandas --with pyarrow python build_base_db.py
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def float_list_to_blob(floats):
    """Encode ``floats`` as a blob of little-endian 32-bit floats, in order."""
    packer = struct.Struct("<" + str(len(floats)) + "f")
    return packer.pack(*floats)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Build base.db next to this script from the downloaded parquet files.

    Creates the tables ``query_vectors`` (from test.parquet), ``neighbors``
    (from neighbors.parquet, one row per query/rank) and ``train`` (from
    train.parquet). An existing base.db is replaced. Exits with status 1 if
    any input file is missing.
    """
    seed_dir = os.path.dirname(os.path.abspath(__file__))
    db_path = os.path.join(seed_dir, "base.db")

    train_path = os.path.join(seed_dir, "train.parquet")
    test_path = os.path.join(seed_dir, "test.parquet")
    neighbors_path = os.path.join(seed_dir, "neighbors.parquet")

    for p in (train_path, test_path, neighbors_path):
        if not os.path.exists(p):
            print(f"ERROR: {p} not found. Run 'make download' first.")
            sys.exit(1)

    # Remove the old database together with any WAL/SHM sidecars.
    for stale in (db_path, db_path + "-wal", db_path + "-shm"):
        if os.path.exists(stale):
            os.remove(stale)

    conn = sqlite3.connect(db_path)
    try:
        # page_size only takes effect before the database is initialised, so it
        # has to precede the switch to WAL.
        conn.execute("PRAGMA page_size=4096")
        conn.execute("PRAGMA journal_mode=WAL")

        # --- query_vectors (from test.parquet) ---
        print("Loading test.parquet (query vectors)...")
        t0 = time.perf_counter()
        df_test = pd.read_parquet(test_path)
        conn.execute(
            "CREATE TABLE query_vectors (id INTEGER PRIMARY KEY, vector BLOB)"
        )
        # Column-wise zip avoids building a Series per row as iterrows() does.
        rows = [
            (int(vid), float_list_to_blob(emb))
            for vid, emb in zip(df_test["id"], df_test["emb"])
        ]
        conn.executemany("INSERT INTO query_vectors (id, vector) VALUES (?, ?)", rows)
        conn.commit()
        print(f" {len(rows)} query vectors in {time.perf_counter() - t0:.1f}s")

        # --- neighbors (from neighbors.parquet) ---
        print("Loading neighbors.parquet...")
        t0 = time.perf_counter()
        df_neighbors = pd.read_parquet(neighbors_path)
        conn.execute(
            "CREATE TABLE neighbors ("
            " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
            " UNIQUE(query_vector_id, rank))"
        )
        rows = []
        for qid, nids in zip(df_neighbors["id"], df_neighbors["neighbors_id"]):
            qid = int(qid)
            # neighbors_id may be a numpy array or JSON string
            if isinstance(nids, str):
                nids = json.loads(nids)
            for rank, nid in enumerate(nids):
                rows.append((qid, rank, str(int(nid))))
        conn.executemany(
            "INSERT INTO neighbors (query_vector_id, rank, neighbors_id) VALUES (?, ?, ?)",
            rows,
        )
        conn.commit()
        print(f" {len(rows)} neighbor rows in {time.perf_counter() - t0:.1f}s")

        # --- train (from train.parquet) ---
        print("Loading train.parquet (1M vectors, this takes a few minutes)...")
        t0 = time.perf_counter()
        conn.execute(
            "CREATE TABLE train (id INTEGER PRIMARY KEY, vector BLOB)"
        )

        batch_size = 10000
        df_iter = pd.read_parquet(train_path)
        total = len(df_iter)

        for start in range(0, total, batch_size):
            chunk = df_iter.iloc[start : start + batch_size]
            rows = [
                (int(vid), float_list_to_blob(emb))
                for vid, emb in zip(chunk["id"], chunk["emb"])
            ]
            conn.executemany("INSERT INTO train (id, vector) VALUES (?, ?)", rows)
            conn.commit()

            # Progress line with throughput and a simple linear ETA.
            done = min(start + batch_size, total)
            elapsed = time.perf_counter() - t0
            rate = done / elapsed if elapsed > 0 else 0
            eta = (total - done) / rate if rate > 0 else 0
            print(
                f" {done:>8}/{total} {elapsed:.0f}s {rate:.0f} rows/s eta {eta:.0f}s",
                flush=True,
            )

        elapsed = time.perf_counter() - t0
        print(f" {total} train vectors in {elapsed:.1f}s")
    finally:
        conn.close()

    size_mb = os.path.getsize(db_path) / (1024 * 1024)
    print(f"\nDone: {db_path} ({size_mb:.0f} MB)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
30
benchmarks-ann/datasets/nyt-1024/Makefile
Normal file
30
benchmarks-ann/datasets/nyt-1024/Makefile
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Builds base.db (train vectors, query vectors, brute-force neighbors) from
# the NYT headline data shared with ../nyt.
#
# User-tunable knobs; override on the command line, e.g. `make K=10`.
MODEL ?= mixedbread-ai/mxbai-embed-large-v1
K ?= 100
BATCH_SIZE ?= 256
DATA_DIR ?= ../nyt/data

# build-base.py opens (and therefore creates) the output database before the
# long embedding step. Without this, a failed run leaves a partial base.db
# that make would consider up to date.
.DELETE_ON_ERROR:

all: base.db

# Reuse data from ../nyt
$(DATA_DIR):
	$(MAKE) -C ../nyt data

# The data directory is an order-only prerequisite: it must exist, but its
# mtime (which changes whenever an entry is added or removed inside it) must
# not trigger a rebuild of contents.db.
contents.db: | $(DATA_DIR)
	uv run ../nyt-768/build-contents.py --data-dir $(DATA_DIR) -o $@

base.db: contents.db queries.txt
	uv run build-base.py \
		--contents-db $< \
		--model $(MODEL) \
		--queries-file queries.txt \
		--batch-size $(BATCH_SIZE) \
		--k $(K) \
		-o $@

# Only runs when queries.txt is missing from this directory.
queries.txt:
	cp ../nyt/queries.txt $@

clean:
	rm -f base.db contents.db

.PHONY: all clean
|
||||||
163
benchmarks-ann/datasets/nyt-1024/build-base.py
Normal file
163
benchmarks-ann/datasets/nyt-1024/build-base.py
Normal file
|
|
@ -0,0 +1,163 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "sentence-transformers",
|
||||||
|
# "torch<=2.7",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sqlite3
|
||||||
|
from array import array
|
||||||
|
from itertools import batched
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Build base.db: embed headlines and queries, then compute exact KNN neighbors.

    Steps:
      1. Read ``(id, headline)`` rows from the contents database and embed them
         in batches into the ``train`` table.
      2. Embed every non-empty line of the queries file into ``query_vectors``
         (ids start at 1, in file order).
      3. Unless ``--skip-neighbors`` is given, rank all ``train`` rows for each
         query with sqlite-vec's ``vec_distance_cosine`` and store the ids of
         the ``--k`` closest rows in ``neighbors`` (rank starts at 0).

    Vectors are stored as blobs of packed C floats (``array("f")``).
    Exits with a usage error when ``--contents-db`` is missing or when
    ``--batch-size`` / ``--k`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(
        description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors",
    )
    parser.add_argument(
        "--contents-db", "-c", default=None,
        help="Path to contents.db (source of headlines and IDs)",
    )
    parser.add_argument(
        "--model", "-m", default="mixedbread-ai/mxbai-embed-large-v1",
        help="HuggingFace model ID (default: mixedbread-ai/mxbai-embed-large-v1)",
    )
    parser.add_argument(
        "--queries-file", "-q", default="queries.txt",
        help="Path to the queries file (default: queries.txt)",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Path to the output base.db",
    )
    parser.add_argument(
        "--batch-size", "-b", type=int, default=256,
        help="Batch size for embedding (default: 256)",
    )
    parser.add_argument(
        "--k", "-k", type=int, default=100,
        help="Number of nearest neighbors (default: 100)",
    )
    parser.add_argument(
        "--limit", "-l", type=int, default=0,
        help="Limit number of headlines to embed (0 = all)",
    )
    parser.add_argument(
        "--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0",
        help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)",
    )
    parser.add_argument(
        "--skip-neighbors", action="store_true",
        help="Skip the brute-force KNN neighbor computation",
    )
    args = parser.parse_args()

    # Validate before the model load so bad invocations fail immediately with
    # a usage message instead of a late TypeError/ValueError.
    if args.contents_db is None:
        parser.error("--contents-db is required")
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.k < 1:
        # A negative LIMIT means "no limit" in SQLite, which would store every
        # train row as a neighbor of every query.
        parser.error("--k must be at least 1")

    import os
    vec_path = os.path.expanduser(args.vec_path)

    print(f"Loading model {args.model}...")
    model = SentenceTransformer(args.model)

    # Read headlines from contents.db
    src = sqlite3.connect(args.contents_db)
    # args.limit is parsed as int, so interpolating it into the SQL is safe.
    limit_clause = f" LIMIT {args.limit}" if args.limit > 0 else ""
    headlines = src.execute(
        f"SELECT id, headline FROM contents ORDER BY id{limit_clause}"
    ).fetchall()
    src.close()
    print(f"Loaded {len(headlines)} headlines from {args.contents_db}")

    # Read queries (one per line, blank lines ignored)
    with open(args.queries_file) as f:
        queries = [line.strip() for line in f if line.strip()]
    print(f"Loaded {len(queries)} queries from {args.queries_file}")

    # Create output database
    db = sqlite3.connect(args.output)
    # Extension loading is enabled only for the duration of the load call.
    db.enable_load_extension(True)
    db.load_extension(vec_path)
    db.enable_load_extension(False)

    db.execute("CREATE TABLE IF NOT EXISTS train(id INTEGER PRIMARY KEY, vector BLOB)")
    db.execute("CREATE TABLE IF NOT EXISTS query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
    db.execute(
        "CREATE TABLE IF NOT EXISTS neighbors("
        "  query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
        "  UNIQUE(query_vector_id, rank))"
    )

    # Step 1: Embed headlines -> train table
    print("Embedding headlines...")
    for batch in tqdm(
        batched(headlines, args.batch_size),
        # Ceiling division: number of batches, counting a final partial one.
        total=(len(headlines) + args.batch_size - 1) // args.batch_size,
    ):
        ids = [r[0] for r in batch]
        texts = [r[1] for r in batch]
        embeddings = model.encode(texts, normalize_embeddings=True)

        params = [
            (int(rid), array("f", emb.tolist()).tobytes())
            for rid, emb in zip(ids, embeddings)
        ]
        db.executemany("INSERT INTO train VALUES (?, ?)", params)
        db.commit()

    # Release the headline list before the remaining steps.
    del headlines
    n = db.execute("SELECT count(*) FROM train").fetchone()[0]
    print(f"Embedded {n} headlines")

    # Step 2: Embed queries -> query_vectors table
    print("Embedding queries...")
    query_embeddings = model.encode(queries, normalize_embeddings=True)
    query_params = []
    for i, emb in enumerate(query_embeddings, 1):
        blob = array("f", emb.tolist()).tobytes()
        query_params.append((i, blob))
    db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params)
    db.commit()
    print(f"Embedded {len(queries)} queries")

    if args.skip_neighbors:
        db.close()
        print(f"Done (skipped neighbors). Wrote {args.output}")
        return

    # Step 3: Brute-force KNN via sqlite-vec -> neighbors table
    n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0]
    print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...")
    for query_id, query_blob in tqdm(
        db.execute("SELECT id, vector FROM query_vectors").fetchall()
    ):
        # train.id breaks distance ties so the stored neighbor order is
        # reproducible from run to run.
        results = db.execute(
            """
            SELECT
                train.id,
                vec_distance_cosine(train.vector, ?) AS distance
            FROM train
            WHERE distance IS NOT NULL
            ORDER BY distance ASC, train.id ASC
            LIMIT ?
            """,
            (query_blob, args.k),
        ).fetchall()

        params = [
            (query_id, rank, str(rid))
            for rank, (rid, _dist) in enumerate(results)
        ]
        db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params)

    db.commit()
    db.close()
    print(f"Done. Wrote {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
100
benchmarks-ann/datasets/nyt-1024/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-1024/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
latest news on climate change policy
|
||||||
|
presidential election results and analysis
|
||||||
|
stock market crash causes
|
||||||
|
coronavirus vaccine development updates
|
||||||
|
artificial intelligence breakthrough in healthcare
|
||||||
|
supreme court ruling on abortion rights
|
||||||
|
tech companies layoff announcements
|
||||||
|
earthquake damages in California
|
||||||
|
cybersecurity breach at major corporation
|
||||||
|
space exploration mission to Mars
|
||||||
|
immigration reform legislation debate
|
||||||
|
renewable energy investment trends
|
||||||
|
healthcare costs rising across America
|
||||||
|
protests against police brutality
|
||||||
|
wildfires destroy homes in the West
|
||||||
|
Olympic games highlights and records
|
||||||
|
celebrity scandal rocks Hollywood
|
||||||
|
breakthrough cancer treatment discovered
|
||||||
|
housing market bubble concerns
|
||||||
|
federal reserve interest rate decision
|
||||||
|
school shooting tragedy response
|
||||||
|
diplomatic tensions between superpowers
|
||||||
|
drone strike kills terrorist leader
|
||||||
|
social media platform faces regulation
|
||||||
|
archaeological discovery reveals ancient civilization
|
||||||
|
unemployment rate hits record low
|
||||||
|
autonomous vehicles testing expansion
|
||||||
|
streaming service launches original content
|
||||||
|
opioid crisis intervention programs
|
||||||
|
trade war tariffs impact economy
|
||||||
|
infrastructure bill passes Congress
|
||||||
|
data privacy concerns grow
|
||||||
|
minimum wage increase proposal
|
||||||
|
college admissions scandal exposed
|
||||||
|
NFL player protest during anthem
|
||||||
|
cryptocurrency regulation debate
|
||||||
|
pandemic lockdown restrictions eased
|
||||||
|
mass shooting gun control debate
|
||||||
|
tax reform legislation impact
|
||||||
|
ransomware attack cripples pipeline
|
||||||
|
climate activists stage demonstration
|
||||||
|
sports team wins championship
|
||||||
|
banking system collapse fears
|
||||||
|
pharmaceutical company fraud charges
|
||||||
|
genetic engineering ethical concerns
|
||||||
|
border wall funding controversy
|
||||||
|
impeachment proceedings begin
|
||||||
|
nuclear weapons treaty violation
|
||||||
|
artificial meat alternative launch
|
||||||
|
student loan debt forgiveness
|
||||||
|
venture capital funding decline
|
||||||
|
facial recognition ban proposed
|
||||||
|
election interference investigation
|
||||||
|
pandemic preparedness failures
|
||||||
|
police reform measures announced
|
||||||
|
wildfire prevention strategies
|
||||||
|
ocean pollution crisis worsens
|
||||||
|
manufacturing jobs returning
|
||||||
|
pension fund shortfall concerns
|
||||||
|
antitrust investigation launched
|
||||||
|
voting rights protection act
|
||||||
|
mental health awareness campaign
|
||||||
|
homeless population increasing
|
||||||
|
space debris collision risk
|
||||||
|
drug cartel violence escalates
|
||||||
|
renewable energy jobs growth
|
||||||
|
infrastructure deterioration report
|
||||||
|
vaccine mandate legal challenge
|
||||||
|
cryptocurrency market volatility
|
||||||
|
autonomous drone delivery service
|
||||||
|
deep fake technology dangers
|
||||||
|
Arctic ice melting accelerates
|
||||||
|
income inequality gap widens
|
||||||
|
election fraud claims disputed
|
||||||
|
corporate merger blocked
|
||||||
|
medical breakthrough extends life
|
||||||
|
transportation strike disrupts city
|
||||||
|
racial justice protests spread
|
||||||
|
carbon emissions reduction goals
|
||||||
|
financial crisis warning signs
|
||||||
|
cyberbullying prevention efforts
|
||||||
|
asteroid near miss with Earth
|
||||||
|
gene therapy approval granted
|
||||||
|
labor union organizing drive
|
||||||
|
surveillance technology expansion
|
||||||
|
education funding cuts proposed
|
||||||
|
disaster relief efforts underway
|
||||||
|
housing affordability crisis
|
||||||
|
clean water access shortage
|
||||||
|
artificial intelligence job displacement
|
||||||
|
trade agreement negotiations
|
||||||
|
prison reform initiative launched
|
||||||
|
species extinction accelerates
|
||||||
|
political corruption scandal
|
||||||
|
terrorism threat level raised
|
||||||
|
food safety contamination outbreak
|
||||||
|
ai model release
|
||||||
|
affordability interest rates
|
||||||
|
peanut allergies in newborns
|
||||||
|
breaking bad walter white
|
||||||
29
benchmarks-ann/datasets/nyt-384/Makefile
Normal file
29
benchmarks-ann/datasets/nyt-384/Makefile
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Builds base.db (train vectors, query vectors, brute-force neighbors) from
# the NYT headline data shared with ../nyt, reusing the scripts from the
# sibling nyt-768 and nyt-1024 directories.
#
# User-tunable knobs; override on the command line, e.g. `make K=10`.
MODEL ?= mixedbread-ai/mxbai-embed-xsmall-v1
K ?= 100
BATCH_SIZE ?= 512
DATA_DIR ?= ../nyt/data

# build-base.py opens (and therefore creates) the output database before the
# long embedding step. Without this, a failed run leaves a partial base.db
# that make would consider up to date.
.DELETE_ON_ERROR:

all: base.db

$(DATA_DIR):
	$(MAKE) -C ../nyt data

# The data directory is an order-only prerequisite: it must exist, but its
# mtime (which changes whenever an entry is added or removed inside it) must
# not trigger a rebuild of contents.db.
contents.db: | $(DATA_DIR)
	uv run ../nyt-768/build-contents.py --data-dir $(DATA_DIR) -o $@

base.db: contents.db queries.txt
	uv run ../nyt-1024/build-base.py \
		--contents-db $< \
		--model $(MODEL) \
		--queries-file queries.txt \
		--batch-size $(BATCH_SIZE) \
		--k $(K) \
		-o $@

# Only runs when queries.txt is missing from this directory.
queries.txt:
	cp ../nyt/queries.txt $@

clean:
	rm -f base.db contents.db

.PHONY: all clean
|
||||||
100
benchmarks-ann/datasets/nyt-384/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-384/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
latest news on climate change policy
|
||||||
|
presidential election results and analysis
|
||||||
|
stock market crash causes
|
||||||
|
coronavirus vaccine development updates
|
||||||
|
artificial intelligence breakthrough in healthcare
|
||||||
|
supreme court ruling on abortion rights
|
||||||
|
tech companies layoff announcements
|
||||||
|
earthquake damages in California
|
||||||
|
cybersecurity breach at major corporation
|
||||||
|
space exploration mission to Mars
|
||||||
|
immigration reform legislation debate
|
||||||
|
renewable energy investment trends
|
||||||
|
healthcare costs rising across America
|
||||||
|
protests against police brutality
|
||||||
|
wildfires destroy homes in the West
|
||||||
|
Olympic games highlights and records
|
||||||
|
celebrity scandal rocks Hollywood
|
||||||
|
breakthrough cancer treatment discovered
|
||||||
|
housing market bubble concerns
|
||||||
|
federal reserve interest rate decision
|
||||||
|
school shooting tragedy response
|
||||||
|
diplomatic tensions between superpowers
|
||||||
|
drone strike kills terrorist leader
|
||||||
|
social media platform faces regulation
|
||||||
|
archaeological discovery reveals ancient civilization
|
||||||
|
unemployment rate hits record low
|
||||||
|
autonomous vehicles testing expansion
|
||||||
|
streaming service launches original content
|
||||||
|
opioid crisis intervention programs
|
||||||
|
trade war tariffs impact economy
|
||||||
|
infrastructure bill passes Congress
|
||||||
|
data privacy concerns grow
|
||||||
|
minimum wage increase proposal
|
||||||
|
college admissions scandal exposed
|
||||||
|
NFL player protest during anthem
|
||||||
|
cryptocurrency regulation debate
|
||||||
|
pandemic lockdown restrictions eased
|
||||||
|
mass shooting gun control debate
|
||||||
|
tax reform legislation impact
|
||||||
|
ransomware attack cripples pipeline
|
||||||
|
climate activists stage demonstration
|
||||||
|
sports team wins championship
|
||||||
|
banking system collapse fears
|
||||||
|
pharmaceutical company fraud charges
|
||||||
|
genetic engineering ethical concerns
|
||||||
|
border wall funding controversy
|
||||||
|
impeachment proceedings begin
|
||||||
|
nuclear weapons treaty violation
|
||||||
|
artificial meat alternative launch
|
||||||
|
student loan debt forgiveness
|
||||||
|
venture capital funding decline
|
||||||
|
facial recognition ban proposed
|
||||||
|
election interference investigation
|
||||||
|
pandemic preparedness failures
|
||||||
|
police reform measures announced
|
||||||
|
wildfire prevention strategies
|
||||||
|
ocean pollution crisis worsens
|
||||||
|
manufacturing jobs returning
|
||||||
|
pension fund shortfall concerns
|
||||||
|
antitrust investigation launched
|
||||||
|
voting rights protection act
|
||||||
|
mental health awareness campaign
|
||||||
|
homeless population increasing
|
||||||
|
space debris collision risk
|
||||||
|
drug cartel violence escalates
|
||||||
|
renewable energy jobs growth
|
||||||
|
infrastructure deterioration report
|
||||||
|
vaccine mandate legal challenge
|
||||||
|
cryptocurrency market volatility
|
||||||
|
autonomous drone delivery service
|
||||||
|
deep fake technology dangers
|
||||||
|
Arctic ice melting accelerates
|
||||||
|
income inequality gap widens
|
||||||
|
election fraud claims disputed
|
||||||
|
corporate merger blocked
|
||||||
|
medical breakthrough extends life
|
||||||
|
transportation strike disrupts city
|
||||||
|
racial justice protests spread
|
||||||
|
carbon emissions reduction goals
|
||||||
|
financial crisis warning signs
|
||||||
|
cyberbullying prevention efforts
|
||||||
|
asteroid near miss with Earth
|
||||||
|
gene therapy approval granted
|
||||||
|
labor union organizing drive
|
||||||
|
surveillance technology expansion
|
||||||
|
education funding cuts proposed
|
||||||
|
disaster relief efforts underway
|
||||||
|
housing affordability crisis
|
||||||
|
clean water access shortage
|
||||||
|
artificial intelligence job displacement
|
||||||
|
trade agreement negotiations
|
||||||
|
prison reform initiative launched
|
||||||
|
species extinction accelerates
|
||||||
|
political corruption scandal
|
||||||
|
terrorism threat level raised
|
||||||
|
food safety contamination outbreak
|
||||||
|
ai model release
|
||||||
|
affordability interest rates
|
||||||
|
peanut allergies in newborns
|
||||||
|
breaking bad walter white
|
||||||
37
benchmarks-ann/datasets/nyt-768/Makefile
Normal file
37
benchmarks-ann/datasets/nyt-768/Makefile
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Builds base.db (train vectors, query vectors, brute-force neighbors) from
# the NYT headline data shared with ../nyt, using a locally distilled model.
#
# User-tunable knobs; override on the command line, e.g. `make K=10`.
# NOTE: distill-model.py always saves to "bge-base-en-v1.5-768", so the
# $(MODEL) rule below only produces its target for the default MODEL value.
MODEL ?= bge-base-en-v1.5-768
K ?= 100
BATCH_SIZE ?= 512
DATA_DIR ?= ../nyt/data

# Both Python scripts create their output database before they finish
# filling it. Without this, a failed run leaves a partial file that make
# would consider up to date.
.DELETE_ON_ERROR:

all: base.db

# Reuse data from ../nyt
$(DATA_DIR):
	$(MAKE) -C ../nyt data

# Distill model (separate step, may take a while)
$(MODEL):
	uv run distill-model.py

# Directories are order-only prerequisites (after `|`): they must exist, but
# their mtime (which changes whenever an entry is added or removed inside
# them) must not trigger rebuilds. After re-distilling the model, run
# `make clean` to force base.db to be rebuilt.
contents.db: | $(DATA_DIR)
	uv run build-contents.py --data-dir $(DATA_DIR) -o $@

base.db: contents.db queries.txt | $(MODEL)
	uv run ../nyt/build-base.py \
		--contents-db $< \
		--model $(MODEL) \
		--queries-file queries.txt \
		--batch-size $(BATCH_SIZE) \
		--k $(K) \
		-o $@

# Only runs when queries.txt is missing from this directory.
queries.txt:
	cp ../nyt/queries.txt $@

clean:
	rm -f base.db contents.db

# Also removes the distilled model directory.
clean-all: clean
	rm -rf $(MODEL)

.PHONY: all clean clean-all
|
||||||
64
benchmarks-ann/datasets/nyt-768/build-contents.py
Normal file
64
benchmarks-ann/datasets/nyt-768/build-contents.py
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "duckdb",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sqlite3
|
||||||
|
import duckdb
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Load NYT headline CSVs into a SQLite ``contents`` table.

    DuckDB reads every ``new_york_times_stories_*.csv`` under ``--data-dir``,
    drops NULL/blank headlines, and collapses duplicate headlines into one row
    carrying the latest ``pub_date``. Rows are numbered from 1, newest first,
    and the first ``--limit`` of them are written to ``--output`` as
    ``contents(id INTEGER PRIMARY KEY, headline TEXT)``.

    The output database must not already contain a ``contents`` table
    (``CREATE TABLE`` is issued without ``IF NOT EXISTS``).
    """
    parser = argparse.ArgumentParser(
        description="Load NYT headline CSVs into a SQLite contents database (most recent 1M, deduplicated)",
    )
    parser.add_argument(
        "--data-dir", "-d", default="../nyt/data",
        help="Directory containing NYT CSV files (default: ../nyt/data)",
    )
    parser.add_argument(
        "--limit", "-l", type=int, default=1_000_000,
        help="Maximum number of headlines to keep (default: 1000000)",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Path to the output SQLite database",
    )
    args = parser.parse_args()

    glob_pattern = f"{args.data_dir}/new_york_times_stories_*.csv"
    # The pattern is embedded in a SQL string literal below; double any single
    # quotes so a data directory containing "'" cannot break the query.
    glob_literal = glob_pattern.replace("'", "''")

    con = duckdb.connect()
    try:
        # Headlines are unique after GROUP BY, so (pub_date DESC, headline) is
        # a total order: ids are reproducible even when several headlines
        # share a pub_date. Ordering the outer query by id guarantees that
        # LIMIT keeps exactly ids 1..limit instead of re-breaking ties
        # independently of the window function.
        # args.limit is parsed as int, so interpolating it is safe.
        rows = con.execute(
            f"""
            WITH deduped AS (
                SELECT
                    headline,
                    max(pub_date) AS pub_date
                FROM read_csv('{glob_literal}', auto_detect=true, union_by_name=true)
                WHERE headline IS NOT NULL AND trim(headline) != ''
                GROUP BY headline
            )
            SELECT
                row_number() OVER (ORDER BY pub_date DESC, headline) AS id,
                headline
            FROM deduped
            ORDER BY id
            LIMIT {args.limit}
            """
        ).fetchall()
    finally:
        # Release the DuckDB connection even if the CSV scan fails.
        con.close()

    db = sqlite3.connect(args.output)
    db.execute("CREATE TABLE contents(id INTEGER PRIMARY KEY, headline TEXT)")
    db.executemany("INSERT INTO contents VALUES (?, ?)", rows)
    db.commit()
    db.close()

    print(f"Wrote {len(rows)} headlines to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
13
benchmarks-ann/datasets/nyt-768/distill-model.py
Normal file
13
benchmarks-ann/datasets/nyt-768/distill-model.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "model2vec[distill]",
|
||||||
|
# "torch<=2.7",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
from model2vec.distill import distill
|
||||||
|
|
||||||
|
# Run model2vec's distill() on BAAI/bge-base-en-v1.5, requesting pca_dims=768.
model = distill(
    model_name="BAAI/bge-base-en-v1.5",
    pca_dims=768,
)

# Save under the directory name that the nyt-768 Makefile uses as its default
# MODEL, then report where the result was written.
model.save_pretrained("bge-base-en-v1.5-768")
print("Saved distilled model to bge-base-en-v1.5-768/")
|
||||||
100
benchmarks-ann/datasets/nyt-768/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt-768/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
latest news on climate change policy
|
||||||
|
presidential election results and analysis
|
||||||
|
stock market crash causes
|
||||||
|
coronavirus vaccine development updates
|
||||||
|
artificial intelligence breakthrough in healthcare
|
||||||
|
supreme court ruling on abortion rights
|
||||||
|
tech companies layoff announcements
|
||||||
|
earthquake damages in California
|
||||||
|
cybersecurity breach at major corporation
|
||||||
|
space exploration mission to Mars
|
||||||
|
immigration reform legislation debate
|
||||||
|
renewable energy investment trends
|
||||||
|
healthcare costs rising across America
|
||||||
|
protests against police brutality
|
||||||
|
wildfires destroy homes in the West
|
||||||
|
Olympic games highlights and records
|
||||||
|
celebrity scandal rocks Hollywood
|
||||||
|
breakthrough cancer treatment discovered
|
||||||
|
housing market bubble concerns
|
||||||
|
federal reserve interest rate decision
|
||||||
|
school shooting tragedy response
|
||||||
|
diplomatic tensions between superpowers
|
||||||
|
drone strike kills terrorist leader
|
||||||
|
social media platform faces regulation
|
||||||
|
archaeological discovery reveals ancient civilization
|
||||||
|
unemployment rate hits record low
|
||||||
|
autonomous vehicles testing expansion
|
||||||
|
streaming service launches original content
|
||||||
|
opioid crisis intervention programs
|
||||||
|
trade war tariffs impact economy
|
||||||
|
infrastructure bill passes Congress
|
||||||
|
data privacy concerns grow
|
||||||
|
minimum wage increase proposal
|
||||||
|
college admissions scandal exposed
|
||||||
|
NFL player protest during anthem
|
||||||
|
cryptocurrency regulation debate
|
||||||
|
pandemic lockdown restrictions eased
|
||||||
|
mass shooting gun control debate
|
||||||
|
tax reform legislation impact
|
||||||
|
ransomware attack cripples pipeline
|
||||||
|
climate activists stage demonstration
|
||||||
|
sports team wins championship
|
||||||
|
banking system collapse fears
|
||||||
|
pharmaceutical company fraud charges
|
||||||
|
genetic engineering ethical concerns
|
||||||
|
border wall funding controversy
|
||||||
|
impeachment proceedings begin
|
||||||
|
nuclear weapons treaty violation
|
||||||
|
artificial meat alternative launch
|
||||||
|
student loan debt forgiveness
|
||||||
|
venture capital funding decline
|
||||||
|
facial recognition ban proposed
|
||||||
|
election interference investigation
|
||||||
|
pandemic preparedness failures
|
||||||
|
police reform measures announced
|
||||||
|
wildfire prevention strategies
|
||||||
|
ocean pollution crisis worsens
|
||||||
|
manufacturing jobs returning
|
||||||
|
pension fund shortfall concerns
|
||||||
|
antitrust investigation launched
|
||||||
|
voting rights protection act
|
||||||
|
mental health awareness campaign
|
||||||
|
homeless population increasing
|
||||||
|
space debris collision risk
|
||||||
|
drug cartel violence escalates
|
||||||
|
renewable energy jobs growth
|
||||||
|
infrastructure deterioration report
|
||||||
|
vaccine mandate legal challenge
|
||||||
|
cryptocurrency market volatility
|
||||||
|
autonomous drone delivery service
|
||||||
|
deep fake technology dangers
|
||||||
|
Arctic ice melting accelerates
|
||||||
|
income inequality gap widens
|
||||||
|
election fraud claims disputed
|
||||||
|
corporate merger blocked
|
||||||
|
medical breakthrough extends life
|
||||||
|
transportation strike disrupts city
|
||||||
|
racial justice protests spread
|
||||||
|
carbon emissions reduction goals
|
||||||
|
financial crisis warning signs
|
||||||
|
cyberbullying prevention efforts
|
||||||
|
asteroid near miss with Earth
|
||||||
|
gene therapy approval granted
|
||||||
|
labor union organizing drive
|
||||||
|
surveillance technology expansion
|
||||||
|
education funding cuts proposed
|
||||||
|
disaster relief efforts underway
|
||||||
|
housing affordability crisis
|
||||||
|
clean water access shortage
|
||||||
|
artificial intelligence job displacement
|
||||||
|
trade agreement negotiations
|
||||||
|
prison reform initiative launched
|
||||||
|
species extinction accelerates
|
||||||
|
political corruption scandal
|
||||||
|
terrorism threat level raised
|
||||||
|
food safety contamination outbreak
|
||||||
|
ai model release
|
||||||
|
affordability interest rates
|
||||||
|
peanut allergies in newborns
|
||||||
|
breaking bad walter white
|
||||||
1
benchmarks-ann/datasets/nyt/.gitignore
vendored
Normal file
1
benchmarks-ann/datasets/nyt/.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
data/
|
||||||
30
benchmarks-ann/datasets/nyt/Makefile
Normal file
30
benchmarks-ann/datasets/nyt/Makefile
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Downloads the NYT headline CSVs and builds contents.db and base.db
# (train vectors, query vectors, brute-force neighbors).
#
# User-tunable knobs; override on the command line, e.g. `make K=10`.
MODEL ?= minishlab/potion-base-8M
K ?= 100
BATCH_SIZE ?= 512
DATA_DIR ?= data

# build-base.py opens (and therefore creates) the output database before the
# embedding step. Without this, a failed run leaves a partial base.db that
# make would consider up to date.
.DELETE_ON_ERROR:

all: base.db contents.db

# Download NYT headlines CSVs from Kaggle (requires `kaggle` CLI + API token)
$(DATA_DIR):
	kaggle datasets download -d johnbandy/new-york-times-headlines -p $@ --unzip

# The data directory is an order-only prerequisite: it must exist, but its
# mtime (which changes whenever an entry is added or removed inside it) must
# not trigger a rebuild of contents.db.
contents.db: | $(DATA_DIR)
	uv run build-contents.py --data-dir $(DATA_DIR) -o $@

base.db: contents.db queries.txt
	uv run build-base.py \
		--contents-db $< \
		--model $(MODEL) \
		--queries-file queries.txt \
		--batch-size $(BATCH_SIZE) \
		--k $(K) \
		-o $@

clean:
	rm -f base.db contents.db

# Also removes the downloaded CSV data.
clean-all: clean
	rm -rf $(DATA_DIR)

.PHONY: all clean clean-all
|
||||||
165
benchmarks-ann/datasets/nyt/build-base.py
Normal file
165
benchmarks-ann/datasets/nyt/build-base.py
Normal file
|
|
@ -0,0 +1,165 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "model2vec",
|
||||||
|
# "torch<=2.7",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sqlite3
|
||||||
|
from array import array
|
||||||
|
from itertools import batched
|
||||||
|
|
||||||
|
from model2vec import StaticModel
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Build base.db with train vectors, query vectors, and exact KNN neighbors.

    Default mode embeds every headline from ``--contents-db`` into the
    ``train`` table, embeds every non-blank line of ``--queries-file`` into
    ``query_vectors`` (ids start at 1), and then fills ``neighbors`` with the
    ``--k`` closest train rows per query, ranked from 0 by cosine distance.

    With ``--rebuild-neighbors`` the embedding steps are skipped: the existing
    output database is opened, ``neighbors`` is dropped and recreated, and only
    the KNN step runs.

    Exits via ``parser.error`` (status 2) when ``--contents-db`` is missing in
    default mode.
    """
    parser = argparse.ArgumentParser(
        description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors",
    )
    parser.add_argument(
        "--contents-db", "-c", default=None,
        help="Path to contents.db (source of headlines and IDs)",
    )
    parser.add_argument(
        "--model", "-m", default="minishlab/potion-base-8M",
        help="HuggingFace model ID or local path (default: minishlab/potion-base-8M)",
    )
    parser.add_argument(
        "--queries-file", "-q", default="queries.txt",
        help="Path to the queries file (default: queries.txt)",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Path to the output base.db",
    )
    parser.add_argument(
        "--batch-size", "-b", type=int, default=512,
        help="Batch size for embedding (default: 512)",
    )
    parser.add_argument(
        "--k", "-k", type=int, default=100,
        help="Number of nearest neighbors (default: 100)",
    )
    parser.add_argument(
        "--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0",
        help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)",
    )
    parser.add_argument(
        "--rebuild-neighbors", action="store_true",
        help="Only rebuild the neighbors table (skip embedding steps)",
    )
    args = parser.parse_args()

    # --contents-db defaults to None, which sqlite3.connect() rejects only
    # after the model has already been loaded. Fail fast with a usage error.
    if not args.rebuild_neighbors and args.contents_db is None:
        parser.error("--contents-db is required unless --rebuild-neighbors is given")

    import os
    vec_path = os.path.expanduser(args.vec_path)

    # Same DDL is needed by both modes.
    create_neighbors_sql = (
        "CREATE TABLE neighbors("
        " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
        " UNIQUE(query_vector_id, rank))"
    )

    if args.rebuild_neighbors:
        # Skip embedding, just open existing DB and rebuild neighbors
        db = sqlite3.connect(args.output)
        db.enable_load_extension(True)
        db.load_extension(vec_path)
        db.enable_load_extension(False)
        db.execute("DROP TABLE IF EXISTS neighbors")
        db.execute(create_neighbors_sql)
        print(f"Rebuilding neighbors in {args.output}...")
    else:
        print(f"Loading model {args.model}...")
        model = StaticModel.from_pretrained(args.model)

        # Read headlines from contents.db
        src = sqlite3.connect(args.contents_db)
        headlines = src.execute("SELECT id, headline FROM contents ORDER BY id").fetchall()
        src.close()
        print(f"Loaded {len(headlines)} headlines from {args.contents_db}")

        # Read queries (blank lines are ignored)
        with open(args.queries_file) as f:
            queries = [line.strip() for line in f if line.strip()]
        print(f"Loaded {len(queries)} queries from {args.queries_file}")

        # Create output database
        db = sqlite3.connect(args.output)
        db.enable_load_extension(True)
        db.load_extension(vec_path)
        db.enable_load_extension(False)

        db.execute("CREATE TABLE train(id INTEGER PRIMARY KEY, vector BLOB)")
        db.execute("CREATE TABLE query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
        db.execute(create_neighbors_sql)

        # Step 1: Embed headlines -> train table
        print("Embedding headlines...")
        for batch in tqdm(
            batched(headlines, args.batch_size),
            total=(len(headlines) + args.batch_size - 1) // args.batch_size,
        ):
            ids = [r[0] for r in batch]
            texts = [r[1] for r in batch]
            embeddings = model.encode(texts)

            # Each embedding is stored as a packed array of C floats.
            params = [
                (int(rid), array("f", emb.tolist()).tobytes())
                for rid, emb in zip(ids, embeddings)
            ]
            db.executemany("INSERT INTO train VALUES (?, ?)", params)
            db.commit()

        del headlines
        n = db.execute("SELECT count(*) FROM train").fetchone()[0]
        print(f"Embedded {n} headlines")

        # Step 2: Embed queries -> query_vectors table (ids start at 1)
        print("Embedding queries...")
        query_embeddings = model.encode(queries)
        query_params = []
        for i, emb in enumerate(query_embeddings, 1):
            blob = array("f", emb.tolist()).tobytes()
            query_params.append((i, blob))
        db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params)
        db.commit()
        print(f"Embedded {len(queries)} queries")

    # Step 3: Brute-force KNN via sqlite-vec -> neighbors table
    n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0]
    print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...")
    for query_id, query_blob in tqdm(
        db.execute("SELECT id, vector FROM query_vectors").fetchall()
    ):
        results = db.execute(
            """
            SELECT
                train.id,
                vec_distance_cosine(train.vector, ?) AS distance
            FROM train
            WHERE distance IS NOT NULL
            ORDER BY distance ASC
            LIMIT ?
            """,
            (query_blob, args.k),
        ).fetchall()

        # rank is 0-based; neighbor ids are stored as text.
        params = [
            (query_id, rank, str(rid))
            for rank, (rid, _dist) in enumerate(results)
        ]
        db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params)

    db.commit()
    db.close()
    print(f"Done. Wrote {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
52
benchmarks-ann/datasets/nyt/build-contents.py
Normal file
52
benchmarks-ann/datasets/nyt/build-contents.py
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "duckdb",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import duckdb
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Load NYT headline CSVs into a SQLite ``contents`` table via DuckDB.

    Reads every ``new_york_times_stories_*.csv`` under ``--data-dir``, keeps
    rows whose ``headline`` is neither NULL nor empty, numbers them with
    ``row_number()``, and writes ``(id, headline)`` rows to a new ``contents``
    table in the ``--output`` SQLite database.
    """
    parser = argparse.ArgumentParser(
        description="Load NYT headline CSVs into a SQLite contents database via DuckDB",
    )
    parser.add_argument(
        "--data-dir", "-d", default="data",
        help="Directory containing NYT CSV files (default: data)",
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Path to the output SQLite database",
    )
    args = parser.parse_args()

    glob_pattern = os.path.join(args.data_dir, "new_york_times_stories_*.csv")
    # The pattern is embedded in a SQL string literal, so a single quote in the
    # directory name must be doubled or the statement is malformed.
    quoted_pattern = glob_pattern.replace("'", "''")

    con = duckdb.connect()
    try:
        rows = con.execute(
            f"""
            SELECT
                row_number() OVER () AS id,
                headline
            FROM read_csv('{quoted_pattern}', auto_detect=true, union_by_name=true)
            WHERE headline IS NOT NULL AND headline != ''
            """
        ).fetchall()
    finally:
        con.close()

    db = sqlite3.connect(args.output)
    try:
        db.execute("CREATE TABLE contents(id INTEGER PRIMARY KEY, headline TEXT)")
        db.executemany("INSERT INTO contents VALUES (?, ?)", rows)
        db.commit()
    finally:
        db.close()

    print(f"Wrote {len(rows)} headlines to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
100
benchmarks-ann/datasets/nyt/queries.txt
Normal file
100
benchmarks-ann/datasets/nyt/queries.txt
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
latest news on climate change policy
|
||||||
|
presidential election results and analysis
|
||||||
|
stock market crash causes
|
||||||
|
coronavirus vaccine development updates
|
||||||
|
artificial intelligence breakthrough in healthcare
|
||||||
|
supreme court ruling on abortion rights
|
||||||
|
tech companies layoff announcements
|
||||||
|
earthquake damages in California
|
||||||
|
cybersecurity breach at major corporation
|
||||||
|
space exploration mission to Mars
|
||||||
|
immigration reform legislation debate
|
||||||
|
renewable energy investment trends
|
||||||
|
healthcare costs rising across America
|
||||||
|
protests against police brutality
|
||||||
|
wildfires destroy homes in the West
|
||||||
|
Olympic games highlights and records
|
||||||
|
celebrity scandal rocks Hollywood
|
||||||
|
breakthrough cancer treatment discovered
|
||||||
|
housing market bubble concerns
|
||||||
|
federal reserve interest rate decision
|
||||||
|
school shooting tragedy response
|
||||||
|
diplomatic tensions between superpowers
|
||||||
|
drone strike kills terrorist leader
|
||||||
|
social media platform faces regulation
|
||||||
|
archaeological discovery reveals ancient civilization
|
||||||
|
unemployment rate hits record low
|
||||||
|
autonomous vehicles testing expansion
|
||||||
|
streaming service launches original content
|
||||||
|
opioid crisis intervention programs
|
||||||
|
trade war tariffs impact economy
|
||||||
|
infrastructure bill passes Congress
|
||||||
|
data privacy concerns grow
|
||||||
|
minimum wage increase proposal
|
||||||
|
college admissions scandal exposed
|
||||||
|
NFL player protest during anthem
|
||||||
|
cryptocurrency regulation debate
|
||||||
|
pandemic lockdown restrictions eased
|
||||||
|
mass shooting gun control debate
|
||||||
|
tax reform legislation impact
|
||||||
|
ransomware attack cripples pipeline
|
||||||
|
climate activists stage demonstration
|
||||||
|
sports team wins championship
|
||||||
|
banking system collapse fears
|
||||||
|
pharmaceutical company fraud charges
|
||||||
|
genetic engineering ethical concerns
|
||||||
|
border wall funding controversy
|
||||||
|
impeachment proceedings begin
|
||||||
|
nuclear weapons treaty violation
|
||||||
|
artificial meat alternative launch
|
||||||
|
student loan debt forgiveness
|
||||||
|
venture capital funding decline
|
||||||
|
facial recognition ban proposed
|
||||||
|
election interference investigation
|
||||||
|
pandemic preparedness failures
|
||||||
|
police reform measures announced
|
||||||
|
wildfire prevention strategies
|
||||||
|
ocean pollution crisis worsens
|
||||||
|
manufacturing jobs returning
|
||||||
|
pension fund shortfall concerns
|
||||||
|
antitrust investigation launched
|
||||||
|
voting rights protection act
|
||||||
|
mental health awareness campaign
|
||||||
|
homeless population increasing
|
||||||
|
space debris collision risk
|
||||||
|
drug cartel violence escalates
|
||||||
|
renewable energy jobs growth
|
||||||
|
infrastructure deterioration report
|
||||||
|
vaccine mandate legal challenge
|
||||||
|
cryptocurrency market volatility
|
||||||
|
autonomous drone delivery service
|
||||||
|
deep fake technology dangers
|
||||||
|
Arctic ice melting accelerates
|
||||||
|
income inequality gap widens
|
||||||
|
election fraud claims disputed
|
||||||
|
corporate merger blocked
|
||||||
|
medical breakthrough extends life
|
||||||
|
transportation strike disrupts city
|
||||||
|
racial justice protests spread
|
||||||
|
carbon emissions reduction goals
|
||||||
|
financial crisis warning signs
|
||||||
|
cyberbullying prevention efforts
|
||||||
|
asteroid near miss with Earth
|
||||||
|
gene therapy approval granted
|
||||||
|
labor union organizing drive
|
||||||
|
surveillance technology expansion
|
||||||
|
education funding cuts proposed
|
||||||
|
disaster relief efforts underway
|
||||||
|
housing affordability crisis
|
||||||
|
clean water access shortage
|
||||||
|
artificial intelligence job displacement
|
||||||
|
trade agreement negotiations
|
||||||
|
prison reform initiative launched
|
||||||
|
species extinction accelerates
|
||||||
|
political corruption scandal
|
||||||
|
terrorism threat level raised
|
||||||
|
food safety contamination outbreak
|
||||||
|
ai model release
|
||||||
|
affordability interest rates
|
||||||
|
peanut allergies in newbons
|
||||||
|
breaking bad walter white
|
||||||
101
benchmarks-ann/faiss_kmeans.py
Normal file
101
benchmarks-ann/faiss_kmeans.py
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Compute k-means centroids using FAISS and save to a centroids DB.
|
||||||
|
|
||||||
|
Reads the first N vectors from a base.db, runs FAISS k-means, and writes
|
||||||
|
the centroids to an output SQLite DB as float32 blobs.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python faiss_kmeans.py --base-db datasets/cohere10m/base.db --ntrain 100000 \
|
||||||
|
--nclusters 8192 -o centroids.db
|
||||||
|
|
||||||
|
Output schema:
|
||||||
|
CREATE TABLE centroids (
|
||||||
|
centroid_id INTEGER PRIMARY KEY,
|
||||||
|
centroid BLOB NOT NULL -- float32[D]
|
||||||
|
);
|
||||||
|
CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT);
|
||||||
|
-- keys: ntrain, nclusters, dimensions, niter, elapsed_s, seed
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import time
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run FAISS k-means over vectors from a base.db and save the centroids.

    Loads up to ``--ntrain`` float32 blobs from the ``train`` table (ordered by
    id), L2-normalizes them, trains ``faiss.Kmeans`` with ``--nclusters``
    clusters, and writes the centroids plus run metadata to a fresh SQLite
    database at ``--output`` (any existing file there is removed first).

    Exits via ``parser.error`` (status 2) when the train table yields no rows.
    """
    parser = argparse.ArgumentParser(description="FAISS k-means centroid computation")
    parser.add_argument("--base-db", required=True, help="path to base.db with train table")
    parser.add_argument("--ntrain", type=int, required=True, help="number of vectors to train on")
    parser.add_argument("--nclusters", type=int, required=True, help="number of clusters (nlist)")
    parser.add_argument("--niter", type=int, default=20, help="k-means iterations (default 20)")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("-o", "--output", required=True, help="output centroids DB path")
    args = parser.parse_args()

    # Load vectors
    print(f"Loading {args.ntrain} vectors from {args.base_db}...")
    conn = sqlite3.connect(args.base_db)
    try:
        rows = conn.execute(
            "SELECT vector FROM train ORDER BY id LIMIT ?", (args.ntrain,)
        ).fetchall()
    finally:
        conn.close()

    # The dimension is inferred from the first blob, so an empty result would
    # otherwise surface as a bare IndexError.
    if not rows:
        parser.error(f"no vectors found in the train table of {args.base_db}")

    # Parse float32 blobs to numpy
    first_blob = rows[0][0]
    D = len(first_blob) // 4  # 4 bytes per float32
    print(f" Dimensions: {D}, loaded {len(rows)} vectors")

    vectors = np.zeros((len(rows), D), dtype=np.float32)
    for i, (blob,) in enumerate(rows):
        vectors[i] = np.frombuffer(blob, dtype=np.float32)

    # Normalize for cosine distance (FAISS k-means on L2 of unit vectors ≈ cosine)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1  # leave all-zero vectors untouched instead of dividing by 0
    vectors /= norms

    # Run FAISS k-means
    print(f"Running k-means: {args.nclusters} clusters, {args.niter} iterations...")
    t0 = time.perf_counter()
    kmeans = faiss.Kmeans(
        D, args.nclusters,
        niter=args.niter,
        seed=args.seed,
        verbose=True,
        gpu=False,
    )
    kmeans.train(vectors)
    elapsed = time.perf_counter() - t0
    print(f" Done in {elapsed:.1f}s")

    centroids = kmeans.centroids  # (nclusters, D) float32

    # Write output DB (always start from a fresh file)
    if os.path.exists(args.output):
        os.remove(args.output)
    out = sqlite3.connect(args.output)
    try:
        out.execute("CREATE TABLE centroids (centroid_id INTEGER PRIMARY KEY, centroid BLOB NOT NULL)")
        out.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)")

        out.executemany(
            "INSERT INTO centroids (centroid_id, centroid) VALUES (?, ?)",
            ((i, centroids[i].tobytes()) for i in range(args.nclusters)),
        )

        out.execute("INSERT INTO meta VALUES ('ntrain', ?)", (str(args.ntrain),))
        out.execute("INSERT INTO meta VALUES ('nclusters', ?)", (str(args.nclusters),))
        out.execute("INSERT INTO meta VALUES ('dimensions', ?)", (str(D),))
        out.execute("INSERT INTO meta VALUES ('niter', ?)", (str(args.niter),))
        out.execute("INSERT INTO meta VALUES ('elapsed_s', ?)", (str(round(elapsed, 3)),))
        out.execute("INSERT INTO meta VALUES ('seed', ?)", (str(args.seed),))
        out.commit()
    finally:
        out.close()

    print(f"Wrote {args.nclusters} centroids to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
168
benchmarks-ann/ground_truth.py
Normal file
168
benchmarks-ann/ground_truth.py
Normal file
|
|
@ -0,0 +1,168 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Compute per-subset ground truth for ANN benchmarks.
|
||||||
|
|
||||||
|
For subset sizes < 1M, builds a temporary vec0 float table with the first N
|
||||||
|
vectors and runs brute-force KNN to get correct ground truth per subset.
|
||||||
|
|
||||||
|
For 1M (the full dataset), converts the existing `neighbors` table.
|
||||||
|
|
||||||
|
Output: ground_truth.{subset_size}.db with table:
|
||||||
|
ground_truth(query_vector_id, rank, neighbor_id, distance)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python ground_truth.py --subset-size 50000
|
||||||
|
python ground_truth.py --subset-size 1000000
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
|
||||||
|
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
EXT_PATH = os.path.join(_SCRIPT_DIR, "..", "dist", "vec0")
|
||||||
|
BASE_DB = os.path.join(_SCRIPT_DIR, "seed", "base.db")
|
||||||
|
FULL_DATASET_SIZE = 1_000_000
|
||||||
|
|
||||||
|
|
||||||
|
def gen_ground_truth_subset(base_db, ext_path, subset_size, n_queries, k, out_path):
    """Build ground truth by brute-force KNN over the first `subset_size` vectors.

    Args:
        base_db: path to the base database providing ``train`` and
            ``query_vectors`` tables.
        ext_path: path to the sqlite-vec loadable extension.
        subset_size: train rows with ``id < subset_size`` are indexed.
        n_queries: number of query vectors to evaluate (lowest ids first).
        k: number of neighbors requested per query.
        out_path: output database path; an existing file is replaced.

    Writes a ``ground_truth(query_vector_id, rank, neighbor_id, distance)``
    table to ``out_path`` with 0-based ranks.
    """
    if os.path.exists(out_path):
        os.remove(out_path)

    conn = sqlite3.connect(out_path)
    try:
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)
        # Extension loading is only needed for the line above.
        conn.enable_load_extension(False)

        conn.execute(
            "CREATE TABLE ground_truth ("
            " query_vector_id INTEGER NOT NULL,"
            " rank INTEGER NOT NULL,"
            " neighbor_id INTEGER NOT NULL,"
            " distance REAL NOT NULL,"
            " PRIMARY KEY (query_vector_id, rank)"
            ")"
        )

        # Bind the path instead of interpolating it so quotes in the file name
        # cannot break the statement.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))

        # Derive the vector width from the data (4 bytes per float32) rather
        # than assuming 768; fall back to 768 when it cannot be determined.
        # NOTE(review): assumes train.vector holds packed float32 blobs — confirm.
        sample_row = conn.execute(
            "SELECT length(vector), typeof(vector) FROM base.train ORDER BY id LIMIT 1"
        ).fetchone()
        if sample_row and sample_row[1] == "blob" and sample_row[0]:
            dimensions = sample_row[0] // 4
        else:
            dimensions = 768

        print(f" Building temp vec0 table with {subset_size} vectors...")
        conn.execute(
            "CREATE VIRTUAL TABLE tmp_vec USING vec0("
            " id integer primary key,"
            f" embedding float[{dimensions}] distance_metric=cosine"
            ")"
        )

        t0 = time.perf_counter()
        # NOTE(review): `id < :n` selects N rows only if train ids start at 0 — confirm.
        conn.execute(
            "INSERT INTO tmp_vec(id, embedding) "
            "SELECT id, vector FROM base.train WHERE id < :n",
            {"n": subset_size},
        )
        conn.commit()
        build_time = time.perf_counter() - t0
        print(f" Temp table built in {build_time:.1f}s")

        query_vectors = conn.execute(
            "SELECT id, vector FROM base.query_vectors ORDER BY id LIMIT :n",
            {"n": n_queries},
        ).fetchall()

        print(f" Running brute-force KNN for {len(query_vectors)} queries, k={k}...")
        t0 = time.perf_counter()

        for i, (qid, qvec) in enumerate(query_vectors):
            results = conn.execute(
                "SELECT id, distance FROM tmp_vec "
                "WHERE embedding MATCH :query AND k = :k",
                {"query": qvec, "k": k},
            ).fetchall()

            conn.executemany(
                "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id, distance) "
                "VALUES (?, ?, ?, ?)",
                [(qid, rank, nid, dist) for rank, (nid, dist) in enumerate(results)],
            )

            # Report after the first query and then every tenth.
            if (i + 1) % 10 == 0 or i == 0:
                elapsed = time.perf_counter() - t0
                eta = (elapsed / (i + 1)) * (len(query_vectors) - i - 1)
                print(
                    f" {i+1}/{len(query_vectors)} queries "
                    f"elapsed={elapsed:.1f}s eta={eta:.1f}s",
                    flush=True,
                )

        # Commit before DETACH: a database cannot be detached mid-transaction.
        conn.commit()
        conn.execute("DROP TABLE tmp_vec")
        conn.execute("DETACH DATABASE base")
        conn.commit()

        elapsed = time.perf_counter() - t0
        total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0]
    finally:
        conn.close()
    print(f" Ground truth: {total_rows} rows in {elapsed:.1f}s -> {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def gen_ground_truth_full(base_db, n_queries, k, out_path):
    """Convert the existing neighbors table for the full 1M dataset.

    Args:
        base_db: path to the base database containing a ``neighbors`` table.
        n_queries: rows with ``query_vector_id < n_queries`` are copied.
        k: rows with ``rank < k`` are copied.
        out_path: output database path; an existing file is replaced.

    Writes a ``ground_truth`` table whose ``neighbor_id`` is the integer cast
    of ``neighbors_id``; ``distance`` is left NULL because the source table
    does not carry it.
    """
    if os.path.exists(out_path):
        os.remove(out_path)

    conn = sqlite3.connect(out_path)
    try:
        # Bind the path instead of interpolating it so quotes in the file name
        # cannot break the statement.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))

        conn.execute(
            "CREATE TABLE ground_truth ("
            " query_vector_id INTEGER NOT NULL,"
            " rank INTEGER NOT NULL,"
            " neighbor_id INTEGER NOT NULL,"
            " distance REAL,"
            " PRIMARY KEY (query_vector_id, rank)"
            ")"
        )

        conn.execute(
            "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id) "
            "SELECT query_vector_id, rank, CAST(neighbors_id AS INTEGER) "
            "FROM base.neighbors "
            "WHERE query_vector_id < :n AND rank < :k",
            {"n": n_queries, "k": k},
        )
        conn.commit()

        total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0]
        conn.execute("DETACH DATABASE base")
    finally:
        conn.close()
    print(f" Ground truth (full): {total_rows} rows -> {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: write ground_truth.{subset_size}.db into the output dir.

    Subsets smaller than FULL_DATASET_SIZE are computed by brute-force KNN;
    anything at or above it is converted from the base DB's neighbors table.
    """
    cli = argparse.ArgumentParser(description="Generate per-subset ground truth")
    cli.add_argument(
        "--subset-size", type=int, required=True, help="number of vectors in subset"
    )
    cli.add_argument("-n", type=int, default=100, help="number of query vectors")
    cli.add_argument("-k", type=int, default=100, help="max k for ground truth")
    cli.add_argument("--base-db", default=BASE_DB)
    cli.add_argument("--ext", default=EXT_PATH)
    cli.add_argument(
        "-o", "--out-dir", default=os.path.join(_SCRIPT_DIR, "seed"),
        help="output directory for ground_truth.{N}.db",
    )
    opts = cli.parse_args()

    os.makedirs(opts.out_dir, exist_ok=True)
    target = os.path.join(opts.out_dir, f"ground_truth.{opts.subset_size}.db")

    # Partial subsets need a fresh brute-force pass; the full set does not.
    if opts.subset_size < FULL_DATASET_SIZE:
        gen_ground_truth_subset(
            opts.base_db, opts.ext, opts.subset_size, opts.n, opts.k, target
        )
        return

    gen_ground_truth_full(opts.base_db, opts.n, opts.k, target)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
440
benchmarks-ann/profile.py
Normal file
440
benchmarks-ann/profile.py
Normal file
|
|
@ -0,0 +1,440 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""CPU profiling for sqlite-vec KNN configurations using macOS `sample` tool.
|
||||||
|
|
||||||
|
Builds dist/sqlite3 (with -g3), generates a SQL workload (inserts + repeated
|
||||||
|
KNN queries) for each config, profiles the sqlite3 process with `sample`, and
|
||||||
|
prints the top-N hottest functions by self (exclusive) CPU samples.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
cd benchmarks-ann
|
||||||
|
uv run profile.py --subset-size 50000 -n 50 \\
|
||||||
|
"baseline-int8:type=baseline,variant=int8,oversample=8" \\
|
||||||
|
"rescore-int8:type=rescore,quantizer=int8,oversample=8"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
_PROJECT_ROOT = os.path.join(_SCRIPT_DIR, "..")
|
||||||
|
|
||||||
|
sys.path.insert(0, _SCRIPT_DIR)
|
||||||
|
from bench import (
|
||||||
|
BASE_DB,
|
||||||
|
DEFAULT_INSERT_SQL,
|
||||||
|
INDEX_REGISTRY,
|
||||||
|
INSERT_BATCH_SIZE,
|
||||||
|
parse_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
SQLITE3_PATH = os.path.join(_PROJECT_ROOT, "dist", "sqlite3")
|
||||||
|
EXT_PATH = os.path.join(_PROJECT_ROOT, "dist", "vec0")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SQL generation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _query_sql_for_config(params, query_id, k):
|
||||||
|
"""Return a SQL query string for a single KNN query by query_vector id."""
|
||||||
|
index_type = params["index_type"]
|
||||||
|
qvec = f"(SELECT vector FROM base.query_vectors WHERE id = {query_id})"
|
||||||
|
|
||||||
|
if index_type == "baseline":
|
||||||
|
variant = params.get("variant", "float")
|
||||||
|
oversample = params.get("oversample", 8)
|
||||||
|
oversample_k = k * oversample
|
||||||
|
|
||||||
|
if variant == "int8":
|
||||||
|
return (
|
||||||
|
f"WITH coarse AS ("
|
||||||
|
f" SELECT id, embedding FROM vec_items"
|
||||||
|
f" WHERE embedding_int8 MATCH vec_quantize_int8({qvec}, 'unit')"
|
||||||
|
f" LIMIT {oversample_k}"
|
||||||
|
f") "
|
||||||
|
f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance "
|
||||||
|
f"FROM coarse ORDER BY 2 LIMIT {k};"
|
||||||
|
)
|
||||||
|
elif variant == "bit":
|
||||||
|
return (
|
||||||
|
f"WITH coarse AS ("
|
||||||
|
f" SELECT id, embedding FROM vec_items"
|
||||||
|
f" WHERE embedding_bq MATCH vec_quantize_binary({qvec})"
|
||||||
|
f" LIMIT {oversample_k}"
|
||||||
|
f") "
|
||||||
|
f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance "
|
||||||
|
f"FROM coarse ORDER BY 2 LIMIT {k};"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default MATCH query (baseline-float, rescore, and others)
|
||||||
|
return (
|
||||||
|
f"SELECT id, distance FROM vec_items"
|
||||||
|
f" WHERE embedding MATCH {qvec} AND k = {k};"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sql(db_path, params, subset_size, n_queries, k, repeats):
    """Assemble the full sqlite3 shell script for one profiling run.

    The script loads the extension, attaches the base DB, creates the index
    table, inserts rows [0, subset_size) in INSERT_BATCH_SIZE chunks, and then
    issues queries for ids 0..n_queries-1, repeated `repeats` times.
    `db_path` is accepted for signature compatibility but not used here.
    Returns the script as a single newline-joined string.
    """
    entry = INDEX_REGISTRY[params["index_type"]]

    script = [
        ".bail on",
        f".load {EXT_PATH}",
        f"ATTACH DATABASE '{os.path.abspath(BASE_DB)}' AS base;",
        "PRAGMA page_size=8192;",
        # Create table
        entry["create_table_sql"](params) + ";",
    ]

    # Inserts: use the index-specific template when one is provided.
    make_insert = entry.get("insert_sql")
    template = make_insert(params) if make_insert else None
    if template is None:
        template = DEFAULT_INSERT_SQL

    for start in range(0, subset_size, INSERT_BATCH_SIZE):
        stop = min(start + INSERT_BATCH_SIZE, subset_size)
        script.append(template.replace(":lo", str(start)).replace(":hi", str(stop)) + ";")
        at_milestone = stop % 10000 == 0
        if at_milestone or stop == subset_size:
            script.append("-- progress: inserted %d/%d" % (stop, subset_size))

    # Queries (repeated)
    script.append("-- BEGIN QUERIES")
    script.extend(
        _query_sql_for_config(params, qid, k)
        for _pass in range(repeats)
        for qid in range(n_queries)
    )

    return "\n".join(script)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Profiling with macOS `sample`
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def run_profile(sqlite3_path, db_path, sql_file, sample_output, duration=120):
    """Run sqlite3 under macOS `sample` profiler.

    Starts sqlite3 directly with stdin from the SQL file, then immediately
    attaches `sample` to its PID with -mayDie (tolerates process exit).
    The workload must be long enough for sample to attach and capture useful data.

    Args:
        sqlite3_path: path to the sqlite3 shell binary to profile.
        db_path: database file passed to sqlite3.
        sql_file: path of the SQL script fed to sqlite3 on stdin.
        sample_output: file that `sample` writes its report to.
        duration: maximum sampling time in seconds passed to `sample`.

    Returns:
        True when sqlite3 exits with status 0, False otherwise.
    """
    # The context manager closes the script even if a process fails to start.
    with open(sql_file, "r") as sql_fd:
        proc = subprocess.Popen(
            [sqlite3_path, db_path],
            stdin=sql_fd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
        )

        pid = proc.pid
        print(f" sqlite3 PID: {pid}")

        # Attach sample immediately (1ms interval, -mayDie tolerates process exit).
        # Its stderr is discarded: an unread pipe could fill up and block it.
        try:
            sample_proc = subprocess.Popen(
                ["sample", str(pid), str(duration), "1", "-mayDie", "-file", sample_output],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except OSError:
            # `sample` could not be started: do not leave sqlite3 running.
            proc.kill()
            proc.communicate()
            raise

        # Wait for sqlite3 to finish
        _, stderr = proc.communicate()

    rc = proc.returncode
    if rc != 0:
        print(f" sqlite3 failed (rc={rc}):", file=sys.stderr)
        print(f" {stderr.decode(errors='replace').strip()}", file=sys.stderr)
        sample_proc.kill()
        sample_proc.wait()  # reap the killed sampler
        return False

    # Wait for sample to finish
    sample_proc.wait()
    return True
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Parse `sample` output
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Tree-drawing characters used by macOS `sample` to represent hierarchy.
|
||||||
|
# We replace them with spaces so indentation depth reflects tree depth.
|
||||||
|
_TREE_CHARS_RE = re.compile(r"[+!:|]")
|
||||||
|
|
||||||
|
# After tree chars are replaced with spaces, each call-graph line looks like:
|
||||||
|
# " 800 rescore_knn (in vec0.dylib) + 3808,3640,... [0x1a,0x2b,...] file.c:123"
|
||||||
|
# We extract just (indent, count, symbol, module) — everything after "(in ...)"
|
||||||
|
# is decoration we don't need.
|
||||||
|
_LEADING_RE = re.compile(r"^(\s+)(\d+)\s+(.+)")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_symbol_and_module(rest):
|
||||||
|
"""Given the text after 'count ', extract (symbol, module).
|
||||||
|
|
||||||
|
Handles patterns like:
|
||||||
|
'rescore_knn (in vec0.dylib) + 3808,3640,... [0x...]'
|
||||||
|
'pread (in libsystem_kernel.dylib) + 8 [0x...]'
|
||||||
|
'??? (in <unknown binary>) [0x...]'
|
||||||
|
'start (in dyld) + 2840 [0x198650274]'
|
||||||
|
'Thread_26759239 DispatchQueue_1: ...'
|
||||||
|
"""
|
||||||
|
# Try to find "(in ...)" to split symbol from module
|
||||||
|
m = re.match(r"^(.+?)\s+\(in\s+(.+?)\)", rest)
|
||||||
|
if m:
|
||||||
|
return m.group(1).strip(), m.group(2).strip()
|
||||||
|
# No module — return whole thing as symbol, strip trailing junk
|
||||||
|
sym = re.sub(r"\s+\[0x[0-9a-f].*", "", rest).strip()
|
||||||
|
return sym, ""
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_call_graph_lines(text):
|
||||||
|
"""Parse call-graph section into list of (depth, count, symbol, module)."""
|
||||||
|
entries = []
|
||||||
|
for raw_line in text.split("\n"):
|
||||||
|
# Strip tree-drawing characters, replace with spaces to preserve depth
|
||||||
|
line = _TREE_CHARS_RE.sub(" ", raw_line)
|
||||||
|
m = _LEADING_RE.match(line)
|
||||||
|
if not m:
|
||||||
|
continue
|
||||||
|
depth = len(m.group(1))
|
||||||
|
count = int(m.group(2))
|
||||||
|
rest = m.group(3)
|
||||||
|
symbol, module = _extract_symbol_and_module(rest)
|
||||||
|
entries.append((depth, count, symbol, module))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
def parse_sample_output(filepath):
    """Read a `sample` report and tally exclusive (self) samples per function.

    Only the text from "Call graph:" up to "Total number in stack" (or end of
    file) is parsed. Returns a dict of {display_name: self_sample_count}, where
    display_name is "symbol (module)" or just "symbol" when no module is known.
    An empty dict is returned, with a warning printed, when the section is
    missing or yields no entries.
    """
    with open(filepath, "r") as handle:
        report = handle.read()

    begin = report.find("Call graph:")
    if begin == -1:
        print(" Warning: no 'Call graph:' section found in sample output")
        return {}

    # The section ends at "Total number in stack" or EOF
    finish = report.find("\nTotal number in stack", begin)
    section = report[begin:] if finish == -1 else report[begin:finish]

    nodes = _parse_call_graph_lines(section)
    if not nodes:
        print(" Warning: no call graph entries parsed")
        return {}

    # self = count - sum(counts of direct children), where direct children are
    # the following deeper entries sharing the depth of the first deeper one.
    tally = {}
    total_nodes = len(nodes)
    for idx, (depth, count, sym, mod) in enumerate(nodes):
        direct_depth = None
        inherited = 0
        cursor = idx + 1
        while cursor < total_nodes and nodes[cursor][0] > depth:
            later_depth = nodes[cursor][0]
            if direct_depth is None:
                direct_depth = later_depth
            if later_depth == direct_depth:
                inherited += nodes[cursor][1]
            cursor += 1

        exclusive = count - inherited
        if exclusive <= 0:
            continue
        label = sym if not mod else f"{sym} ({mod})"
        tally[label] = tally.get(label, 0) + exclusive

    return tally
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Display
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def print_profile(title, self_samples, top_n=20):
    """Print the hottest functions of one profile.

    `self_samples` maps a display name to its self-sample count. Up to
    `top_n` entries are printed in descending count order, each with its
    percentage of the total. When the total is zero only a "(no samples)"
    banner is printed.
    """
    total = sum(self_samples.values())
    if total == 0:
        print(f"\n=== {title} (no samples) ===")
        return

    # Stable descending sort: equal counts keep their insertion order.
    ranked = sorted(self_samples.items(), key=lambda kv: kv[1], reverse=True)

    print(f"\n=== {title} (top {top_n}, {total} total self-samples) ===")
    for sym, count in ranked[:top_n]:
        share = 100.0 * count / total
        print(f" {share:5.1f}% {count:>6} {sym}")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Main
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _print_comparison(all_profiles, top_n):
    """Print a side-by-side table of self-sample percentages across configs.

    `all_profiles` is a list of (name, desc, {symbol: self_samples}) tuples.
    A symbol gets a row if it is among the `top_n` of at least one config;
    at most `top_n * 2` rows are printed, ordered by the highest percentage
    the symbol reaches in any config.
    """
    print("\n" + "=" * 80)
    print("COMPARISON")
    print("=" * 80)

    # Per-config totals do not depend on the symbol; compute them once.
    totals = [sum(prof.values()) for _name, _desc, prof in all_profiles]

    # Collect all symbols that appear in top-N of any config
    all_syms = set()
    for _name, _desc, prof in all_profiles:
        ranked = sorted(prof.items(), key=lambda x: -x[1])
        for sym, _count in ranked[:top_n]:
            all_syms.add(sym)

    # Build comparison table
    rows = []
    for sym in all_syms:
        row = [sym]
        for (_name, _desc, prof), total in zip(all_profiles, totals):
            count = prof.get(sym, 0)
            pct = 100.0 * count / total if total > 0 else 0.0
            row.append((pct, count))
        max_pct = max(r[0] for r in row[1:])
        rows.append((max_pct, row))

    # Break ties on the symbol name: set iteration order varies between runs,
    # and without a tie-break the table order (and the cut-off) would too.
    rows.sort(key=lambda x: (-x[0], x[1][0]))

    # Header
    header = f"{'function':>40}"
    for name, _desc, _prof in all_profiles:
        header += f" {name:>14}"
    print(header)
    print("-" * len(header))

    for _sort_key, row in rows[: top_n * 2]:
        sym = row[0]
        display_sym = sym if len(sym) <= 40 else sym[:37] + "..."
        line = f"{display_sym:>40}"
        for pct, count in row[1:]:
            if count > 0:
                line += f" {pct:>13.1f}%"
            else:
                line += f" {'-':>14}"
        print(line)


def main():
    """Command-line entry point: profile each config and print the results.

    Steps: parse arguments, check that the base DB and the `sample` tool
    exist, run `make cli`, then for every config generate a SQL workload,
    profile it, parse the sample output and print its profile. With more
    than one config a comparison table follows. The temp directory is
    removed on the way out (even on failure) unless --keep-temp is given.

    Exits with status 1 when a prerequisite is missing or the build fails.
    """
    parser = argparse.ArgumentParser(
        description="CPU profiling for sqlite-vec KNN configurations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "configs", nargs="+", help="config specs (name:type=X,key=val,...)"
    )
    parser.add_argument("--subset-size", type=int, required=True)
    parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
    parser.add_argument(
        "-n", type=int, default=50, help="number of distinct queries (default 50)"
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=10,
        help="repeat query set N times for more samples (default 10)",
    )
    parser.add_argument(
        "--top", type=int, default=20, help="show top N functions (default 20)"
    )
    parser.add_argument("--base-db", default=BASE_DB)
    parser.add_argument("--sqlite3", default=SQLITE3_PATH)
    parser.add_argument(
        "--keep-temp",
        action="store_true",
        help="keep temp directory with DBs, SQL, and sample output",
    )
    args = parser.parse_args()

    # Check prerequisites
    if not os.path.exists(args.base_db):
        print(f"Error: base DB not found at {args.base_db}", file=sys.stderr)
        print("Run 'make seed' in benchmarks-ann/ first.", file=sys.stderr)
        sys.exit(1)

    if not shutil.which("sample"):
        print("Error: macOS 'sample' tool not found.", file=sys.stderr)
        sys.exit(1)

    # Build CLI
    print("Building dist/sqlite3...")
    result = subprocess.run(
        ["make", "cli"], cwd=_PROJECT_ROOT, capture_output=True, text=True
    )
    if result.returncode != 0:
        print(f"Error: make cli failed:\n{result.stderr}", file=sys.stderr)
        sys.exit(1)
    print(" done.")

    if not os.path.exists(args.sqlite3):
        print(f"Error: sqlite3 not found at {args.sqlite3}", file=sys.stderr)
        sys.exit(1)

    configs = [parse_config(c) for c in args.configs]

    tmpdir = tempfile.mkdtemp(prefix="sqlite-vec-profile-")
    print(f"Working directory: {tmpdir}")

    # mkdtemp() leaves deletion to the caller, so everything that can raise
    # from here on runs under try/finally to avoid leaking the directory.
    try:
        all_profiles = []

        for i, (name, params) in enumerate(configs, 1):
            reg = INDEX_REGISTRY[params["index_type"]]
            desc = reg["describe"](params)
            print(f"\n[{i}/{len(configs)}] {name} ({desc})")

            # Generate SQL workload
            db_path = os.path.join(tmpdir, f"{name}.db")
            sql_text = generate_sql(
                db_path, params, args.subset_size, args.n, args.k, args.repeats
            )
            sql_file = os.path.join(tmpdir, f"{name}.sql")
            with open(sql_file, "w") as f:
                f.write(sql_text)

            total_queries = args.n * args.repeats
            print(
                f" SQL workload: {args.subset_size} inserts + "
                f"{total_queries} queries ({args.n} x {args.repeats} repeats)"
            )

            # Profile
            sample_file = os.path.join(tmpdir, f"{name}.sample.txt")
            print(" Profiling...")
            ok = run_profile(args.sqlite3, db_path, sql_file, sample_file)
            if not ok:
                print(f" FAILED — skipping {name}")
                # Keep an empty profile so the config still gets a column
                all_profiles.append((name, desc, {}))
                continue

            if not os.path.exists(sample_file):
                print(" Warning: sample output not created")
                all_profiles.append((name, desc, {}))
                continue

            # Parse
            self_samples = parse_sample_output(sample_file)
            all_profiles.append((name, desc, self_samples))

            # Show individual profile
            print_profile(f"{name} ({desc})", self_samples, args.top)

        # Side-by-side comparison if multiple configs
        if len(all_profiles) > 1:
            _print_comparison(all_profiles, args.top)
    finally:
        if args.keep_temp:
            print(f"\nTemp files kept at: {tmpdir}")
        else:
            # Best-effort: a cleanup error must not mask the real failure.
            shutil.rmtree(tmpdir, ignore_errors=True)
            print("\nTemp files cleaned up. Use --keep-temp to preserve.")


if __name__ == "__main__":
    main()
|
||||||
76
benchmarks-ann/results_schema.sql
Normal file
76
benchmarks-ann/results_schema.sql
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
-- Comprehensive results schema for vec0 KNN benchmark runs.
|
||||||
|
-- Created in WAL mode: PRAGMA journal_mode=WAL
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS runs (
  run_id        INTEGER PRIMARY KEY AUTOINCREMENT,
  config_name   TEXT    NOT NULL,
  index_type    TEXT    NOT NULL,
  params        TEXT    NOT NULL,  -- JSON object, e.g. {"R":48,"L":128,"quantizer":"binary"}
  dataset       TEXT    NOT NULL,  -- e.g. "cohere1m"
  subset_size   INTEGER NOT NULL,
  k             INTEGER NOT NULL,
  n_queries     INTEGER NOT NULL,
  -- phase is one of: 'build', 'query', 'both'
  phase         TEXT    NOT NULL DEFAULT 'both',
  -- status lifecycle: pending → inserting → training → querying → done | built | error
  status        TEXT    NOT NULL DEFAULT 'pending',
  created_at_ns INTEGER NOT NULL   -- from time.time_ns()
);
|
||||||
|
|
||||||
|
-- run_id is both the primary key and a reference to runs(run_id).
CREATE TABLE IF NOT EXISTS run_results (
  run_id             INTEGER PRIMARY KEY REFERENCES runs(run_id),

  -- build timings
  insert_started_ns  INTEGER,
  insert_ended_ns    INTEGER,
  insert_duration_ns INTEGER,
  train_started_ns   INTEGER,  -- NULL when there is no training
  train_ended_ns     INTEGER,
  train_duration_ns  INTEGER,
  build_duration_ns  INTEGER,  -- insert + train

  -- database file
  db_file_size_bytes INTEGER,
  db_file_path       TEXT,

  -- SQL text used by the run
  create_sql         TEXT,     -- CREATE VIRTUAL TABLE ...
  insert_sql         TEXT,     -- INSERT INTO vec_items ...
  train_sql          TEXT,     -- NULL when there is no training step
  query_sql          TEXT,     -- SELECT ... WHERE embedding MATCH ...

  -- denormalized copies and aggregates
  k                  INTEGER,  -- copied from runs for easy filtering
  query_mean_ms      REAL,
  query_median_ms    REAL,
  query_p99_ms       REAL,
  query_total_ms     REAL,
  qps                REAL,
  recall             REAL
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS insert_batches (
  batch_id        INTEGER PRIMARY KEY AUTOINCREMENT,
  run_id          INTEGER NOT NULL REFERENCES runs(run_id),
  batch_lo        INTEGER NOT NULL,  -- first index of the batch (inclusive)
  batch_hi        INTEGER NOT NULL,  -- end index of the batch (exclusive)
  rows_in_batch   INTEGER NOT NULL,
  started_ns      INTEGER NOT NULL,
  ended_ns        INTEGER NOT NULL,
  duration_ns     INTEGER NOT NULL,
  cumulative_rows INTEGER NOT NULL,  -- rows inserted so far, in total
  rate_rows_per_s REAL    NOT NULL   -- cumulative insert rate
);
|
||||||
|
|
||||||
|
-- At most one row per (run_id, k, query_vector_id).
CREATE TABLE IF NOT EXISTS queries (
  query_id         INTEGER PRIMARY KEY AUTOINCREMENT,
  run_id           INTEGER NOT NULL REFERENCES runs(run_id),
  k                INTEGER NOT NULL,
  query_vector_id  INTEGER NOT NULL,
  started_ns       INTEGER NOT NULL,
  ended_ns         INTEGER NOT NULL,
  duration_ms      REAL    NOT NULL,
  result_ids       TEXT    NOT NULL,  -- JSON array
  result_distances TEXT    NOT NULL,  -- JSON array
  ground_truth_ids TEXT    NOT NULL,  -- JSON array
  recall           REAL    NOT NULL,
  UNIQUE(run_id, k, query_vector_id)
);
|
||||||
|
|
||||||
|
-- Lookup indexes on runs
CREATE INDEX IF NOT EXISTS idx_runs_config  ON runs(config_name);
CREATE INDEX IF NOT EXISTS idx_runs_type    ON runs(index_type);
CREATE INDEX IF NOT EXISTS idx_runs_status  ON runs(status);

-- Child tables are indexed by their owning run
CREATE INDEX IF NOT EXISTS idx_batches_run  ON insert_batches(run_id);
CREATE INDEX IF NOT EXISTS idx_queries_run  ON queries(run_id);
|
||||||
60
benchmarks-ann/schema.sql
Normal file
60
benchmarks-ann/schema.sql
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
-- Canonical results schema for vec0 KNN benchmark comparisons.
|
||||||
|
-- The index_type column is a free-form TEXT field. Baseline configs use
|
||||||
|
-- "baseline"; index-specific branches add their own types (registered
|
||||||
|
-- via INDEX_REGISTRY in bench.py).
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS runs (
  run_id             INTEGER PRIMARY KEY AUTOINCREMENT,
  config_name        TEXT    NOT NULL,
  index_type         TEXT    NOT NULL,
  subset_size        INTEGER NOT NULL,
  phase              TEXT    NOT NULL DEFAULT 'both',  -- one of 'build', 'query', 'both'
  status             TEXT    NOT NULL DEFAULT 'pending',
  k                  INTEGER,
  n                  INTEGER,
  db_path            TEXT,

  -- build metrics
  insert_time_s      REAL,
  train_time_s       REAL,
  total_build_time_s REAL,
  rows               INTEGER,
  file_size_mb       REAL,

  -- query metrics
  mean_ms            REAL,
  median_ms          REAL,
  p99_ms             REAL,
  total_query_ms     REAL,
  qps                REAL,
  recall             REAL,

  created_at         TEXT    NOT NULL DEFAULT (datetime('now')),
  finished_at        TEXT
);
|
||||||
|
|
||||||
|
-- Keyed by (config_name, subset_size).
CREATE TABLE IF NOT EXISTS build_results (
  config_name   TEXT    NOT NULL,
  index_type    TEXT    NOT NULL,
  subset_size   INTEGER NOT NULL,
  db_path       TEXT    NOT NULL,
  insert_time_s REAL    NOT NULL,
  train_time_s  REAL,              -- NULL when no training/build step is needed
  total_time_s  REAL    NOT NULL,
  rows          INTEGER NOT NULL,
  file_size_mb  REAL    NOT NULL,
  created_at    TEXT    NOT NULL DEFAULT (datetime('now')),
  PRIMARY KEY (config_name, subset_size)
);
|
||||||
|
|
||||||
|
-- Keyed by (config_name, subset_size, k).
CREATE TABLE IF NOT EXISTS bench_results (
  config_name TEXT    NOT NULL,
  index_type  TEXT    NOT NULL,
  subset_size INTEGER NOT NULL,
  k           INTEGER NOT NULL,
  n           INTEGER NOT NULL,
  mean_ms     REAL    NOT NULL,
  median_ms   REAL    NOT NULL,
  p99_ms      REAL    NOT NULL,
  total_ms    REAL    NOT NULL,
  qps         REAL    NOT NULL,
  recall      REAL    NOT NULL,
  db_path     TEXT    NOT NULL,
  created_at  TEXT    NOT NULL DEFAULT (datetime('now')),
  PRIMARY KEY (config_name, subset_size, k)
);
|
||||||
|
|
@ -248,59 +248,6 @@ def bench_libsql(base, query, page_size, k) -> BenchResult:
|
||||||
return BenchResult(f"libsql ({page_size})", build_time, times)
|
return BenchResult(f"libsql ({page_size})", build_time, times)
|
||||||
|
|
||||||
|
|
||||||
def register_np(db, array, name):
|
|
||||||
ptr = array.__array_interface__["data"][0]
|
|
||||||
nvectors, dimensions = array.__array_interface__["shape"]
|
|
||||||
element_type = array.__array_interface__["typestr"]
|
|
||||||
|
|
||||||
assert element_type == "<f4"
|
|
||||||
|
|
||||||
name_escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]
|
|
||||||
|
|
||||||
db.execute(
|
|
||||||
"insert into temp.vec_static_blobs(name, data) select ?, vec_static_blob_from_raw(?, ?, ?, ?)",
|
|
||||||
[name, ptr, element_type, dimensions, nvectors],
|
|
||||||
)
|
|
||||||
|
|
||||||
db.execute(
|
|
||||||
f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
|
|
||||||
)
|
|
||||||
|
|
||||||
def bench_sqlite_vec_static(base, query, k) -> BenchResult:
|
|
||||||
print(f"sqlite-vec static...")
|
|
||||||
|
|
||||||
db = sqlite3.connect(":memory:")
|
|
||||||
db.enable_load_extension(True)
|
|
||||||
db.load_extension("../../dist/vec0")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
t = time.time()
|
|
||||||
register_np(db, base, "base")
|
|
||||||
build_time = time.time() - t
|
|
||||||
|
|
||||||
times = []
|
|
||||||
results = []
|
|
||||||
for (
|
|
||||||
idx,
|
|
||||||
q,
|
|
||||||
) in enumerate(query):
|
|
||||||
t0 = time.time()
|
|
||||||
result = db.execute(
|
|
||||||
"""
|
|
||||||
select
|
|
||||||
rowid
|
|
||||||
from base
|
|
||||||
where vector match ?
|
|
||||||
and k = ?
|
|
||||||
order by distance
|
|
||||||
""",
|
|
||||||
[q.tobytes(), k],
|
|
||||||
).fetchall()
|
|
||||||
assert len(result) == k
|
|
||||||
times.append(time.time() - t0)
|
|
||||||
return BenchResult(f"sqlite-vec static", build_time, times)
|
|
||||||
|
|
||||||
def bench_faiss(base, query, k) -> BenchResult:
|
def bench_faiss(base, query, k) -> BenchResult:
|
||||||
import faiss
|
import faiss
|
||||||
dimensions = base.shape[1]
|
dimensions = base.shape[1]
|
||||||
|
|
@ -438,8 +385,6 @@ def suite(name, base, query, k, benchmarks):
|
||||||
for b in benchmarks:
|
for b in benchmarks:
|
||||||
if b == "faiss":
|
if b == "faiss":
|
||||||
results.append(bench_faiss(base, query, k=k))
|
results.append(bench_faiss(base, query, k=k))
|
||||||
elif b == "vec-static":
|
|
||||||
results.append(bench_sqlite_vec_static(base, query, k=k))
|
|
||||||
elif b.startswith("vec-scalar"):
|
elif b.startswith("vec-scalar"):
|
||||||
_, page_size = b.split('.')
|
_, page_size = b.split('.')
|
||||||
results.append(bench_sqlite_vec_scalar(base, query, page_size, k=k))
|
results.append(bench_sqlite_vec_scalar(base, query, page_size, k=k))
|
||||||
|
|
@ -541,7 +486,7 @@ def parse_args():
|
||||||
help="Number of queries to use. Defaults all",
|
help="Number of queries to use. Defaults all",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-x", help="type of runs to make", default="faiss,vec-scalar.4096,vec-static,vec-vec0.4096.16,usearch,duckdb,hnswlib,numpy"
|
"-x", help="type of runs to make", default="faiss,vec-scalar.4096,vec-vec0.4096.16,usearch,duckdb,hnswlib,numpy"
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,3 @@ create virtual table vec_items using vec0(
|
||||||
embedding float[1536]
|
embedding float[1536]
|
||||||
);
|
);
|
||||||
|
|
||||||
-- 65s (limit 1e5), ~615MB on disk
|
|
||||||
insert into vec_items
|
|
||||||
select
|
|
||||||
rowid,
|
|
||||||
vector
|
|
||||||
from vec_npy_each(vec_npy_file('examples/dbpedia-openai/data/vectors.npy'))
|
|
||||||
limit 1e5;
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ def connect(path):
|
||||||
db = sqlite3.connect(path)
|
db = sqlite3.connect(path)
|
||||||
db.enable_load_extension(True)
|
db.enable_load_extension(True)
|
||||||
db.load_extension("../dist/vec0")
|
db.load_extension("../dist/vec0")
|
||||||
db.execute("select load_extension('../dist/vec0', 'sqlite3_vec_fs_read_init')")
|
|
||||||
db.enable_load_extension(False)
|
db.enable_load_extension(False)
|
||||||
return db
|
return db
|
||||||
|
|
||||||
|
|
@ -18,8 +17,6 @@ page_sizes = [ # 4096, 8192,
|
||||||
chunk_sizes = [128, 256, 1024, 2048]
|
chunk_sizes = [128, 256, 1024, 2048]
|
||||||
types = ["f32", "int8", "bit"]
|
types = ["f32", "int8", "bit"]
|
||||||
|
|
||||||
SRC = "../examples/dbpedia-openai/data/vectors.npy"
|
|
||||||
|
|
||||||
for page_size in page_sizes:
|
for page_size in page_sizes:
|
||||||
for chunk_size in chunk_sizes:
|
for chunk_size in chunk_sizes:
|
||||||
for t in types:
|
for t in types:
|
||||||
|
|
@ -42,15 +39,8 @@ for page_size in page_sizes:
|
||||||
func = "vec_quantize_i8(vector, 'unit')"
|
func = "vec_quantize_i8(vector, 'unit')"
|
||||||
if t == "bit":
|
if t == "bit":
|
||||||
func = "vec_quantize_binary(vector)"
|
func = "vec_quantize_binary(vector)"
|
||||||
db.execute(
|
# TODO: replace with non-npy data loading
|
||||||
f"""
|
pass
|
||||||
insert into vec_items
|
|
||||||
select rowid, {func}
|
|
||||||
from vec_npy_each(vec_npy_file(?))
|
|
||||||
limit 100000
|
|
||||||
""",
|
|
||||||
[SRC],
|
|
||||||
)
|
|
||||||
elapsed = time.time() - t0
|
elapsed = time.time() - t0
|
||||||
print(elapsed)
|
print(elapsed)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ index ed2aaec..4cc0b0e 100755
|
||||||
-Wl,--initial-memory=327680 \
|
-Wl,--initial-memory=327680 \
|
||||||
-D_HAVE_SQLITE_CONFIG_H \
|
-D_HAVE_SQLITE_CONFIG_H \
|
||||||
-DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \
|
-DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \
|
||||||
+ -DSQLITE_VEC_OMIT_FS=1 \
|
|
||||||
$(awk '{print "-Wl,--export="$0}' exports.txt)
|
$(awk '{print "-Wl,--export="$0}' exports.txt)
|
||||||
|
|
||||||
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
|
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
from struct import pack
|
from struct import pack
|
||||||
from sqlite3 import Connection
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_float32(vector: List[float]) -> bytes:
|
def serialize_float32(vector: List[float]) -> bytes:
|
||||||
|
|
@ -13,33 +12,3 @@ def serialize_int8(vector: List[int]) -> bytes:
|
||||||
return pack("%sb" % len(vector), *vector)
|
return pack("%sb" % len(vector), *vector)
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
import numpy.typing as npt
|
|
||||||
|
|
||||||
def register_numpy(db: Connection, name: str, array: npt.NDArray):
|
|
||||||
"""ayoo"""
|
|
||||||
|
|
||||||
ptr = array.__array_interface__["data"][0]
|
|
||||||
nvectors, dimensions = array.__array_interface__["shape"]
|
|
||||||
element_type = array.__array_interface__["typestr"]
|
|
||||||
|
|
||||||
assert element_type == "<f4"
|
|
||||||
|
|
||||||
name_escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]
|
|
||||||
|
|
||||||
db.execute(
|
|
||||||
"""
|
|
||||||
insert into temp.vec_static_blobs(name, data)
|
|
||||||
select ?, vec_static_blob_from_raw(?, ?, ?, ?)
|
|
||||||
""",
|
|
||||||
[name, ptr, element_type, dimensions, nvectors],
|
|
||||||
)
|
|
||||||
|
|
||||||
db.execute(
|
|
||||||
f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
|
|
||||||
)
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
|
|
||||||
def register_numpy(db: Connection, name: str, array):
|
|
||||||
raise Exception("numpy package is required for register_numpy")
|
|
||||||
|
|
|
||||||
119
scripts/amalgamate.py
Normal file
119
scripts/amalgamate.py
Normal file
|
|
@ -0,0 +1,119 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Amalgamate sqlite-vec into a single distributable .c file.
|
||||||
|
|
||||||
|
Reads the dev sqlite-vec.c and inlines any #include "sqlite-vec-*.c" files,
|
||||||
|
stripping LSP-support blocks and per-file include guards.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 scripts/amalgamate.py sqlite-vec.c > dist/sqlite-vec.c
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def strip_lsp_block(content):
    """Delete every editor-support self-include stanza from `content`.

    The stanza is three consecutive lines of the form:
        #ifndef SQLITE_VEC_H
        #include "sqlite-vec.c" // ...
        #endif
    Leading whitespace (including blank lines) before the stanza is removed
    together with it. Returns the resulting text.
    """
    lsp_stanza = (
        r'^\s*#ifndef\s+SQLITE_VEC_H\s*\n'
        r'\s*#include\s+"sqlite-vec\.c"[^\n]*\n'
        r'\s*#endif[^\n]*\n'
    )
    return re.sub(lsp_stanza, '', content, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_include_guard(content, guard_macro):
    """Remove the include guard named `guard_macro` from `content`.

    The guard is the pair
        #ifndef GUARD_MACRO
        #define GUARD_MACRO
    plus the last `#endif` line in the text, which is taken to be the one
    closing the guard.

    If the opening pair is not found, `content` is returned unchanged: the
    trailing `#endif` is only dropped when the matching header was actually
    removed, so unrelated conditionals are never left unbalanced.

    Returns the stripped text, ending in exactly one newline.
    """
    macro = re.escape(guard_macro)
    header_pattern = re.compile(
        r'^\s*#ifndef\s+' + macro + r'\s*\n'
        r'\s*#define\s+' + macro + r'\s*\n',
        re.MULTILINE,
    )
    content, removed = header_pattern.subn('', content, count=1)
    if not removed:
        # No guard header present, so there is no guard #endif to remove.
        return content

    # Drop the last #endif in the file (the one that closed the guard).
    lines = content.rstrip('\n').split('\n')
    for i in range(len(lines) - 1, -1, -1):
        if re.match(r'^\s*#endif', lines[i]):
            del lines[i]
            break

    return '\n'.join(lines) + '\n'
|
||||||
|
|
||||||
|
|
||||||
|
def detect_include_guard(content):
    """Return the include-guard macro at the top of `content`, or None.

    Recognizes a guard of the form SQLITE_VEC_<name>_C (for example
    SQLITE_VEC_IVF_C): an `#ifndef` line immediately followed by a `#define`
    of the same macro, optionally preceded by whitespace and one block
    comment. The `#define` must name exactly the same macro; a longer
    identifier that merely starts with the guard name does not count.
    """
    m = re.match(
        r'\s*(?:/\*[\s\S]*?\*/\s*)?'  # optional block comment
        r'#ifndef\s+(SQLITE_VEC_\w+_C)\s*\n'
        r'#define\s+\1(?!\w)',  # same macro, not just a prefix of another one
        content,
    )
    return m.group(1) if m else None
|
||||||
|
|
||||||
|
|
||||||
|
def inline_include(match, base_dir):
    """Expand one matched `#include "sqlite-vec-*.c"` into the file's text.

    `match` is a regex match whose group 1 is the included filename, which
    is resolved relative to `base_dir`. When the file does not exist, a
    warning goes to stderr and the original #include text is returned as is.
    Otherwise the file is read, its LSP-support block and include guard (if
    one is detected) are stripped, and the result is wrapped in
    "Begin inlined" / "End inlined" banner comments.
    """
    filename = match.group(1)
    filepath = os.path.join(base_dir, filename)

    if not os.path.exists(filepath):
        print(f"Warning: {filepath} not found, leaving #include in place", file=sys.stderr)
        return match.group(0)

    with open(filepath, 'r') as f:
        content = strip_lsp_block(f.read())

    guard = detect_include_guard(content)
    if guard:
        content = strip_include_guard(content, guard)

    rule = '/' * 78
    body = content.strip('\n')
    return (
        f'\n{rule}\n// Begin inlined: {filename}\n{rule}\n\n'
        f'{body}'
        f'\n{rule}\n// End inlined: {filename}\n{rule}\n'
    )
|
||||||
|
|
||||||
|
|
||||||
|
def amalgamate(input_path):
    """Return the text of `input_path` with its sqlite-vec includes inlined.

    Every line of the form `#include "sqlite-vec-<something>.c"` is replaced
    by the output of `inline_include`, with included files looked up in the
    directory that contains `input_path`.
    """
    base_dir = os.path.dirname(os.path.abspath(input_path))

    with open(input_path, 'r') as f:
        source = f.read()

    def _expand(m):
        return inline_include(m, base_dir)

    # Replace #include "sqlite-vec-*.c" with inlined contents
    return re.sub(
        r'^#include\s+"(sqlite-vec-[^"]+\.c)"\s*$',
        _expand,
        source,
        flags=re.MULTILINE,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: amalgamate the given file and write it to stdout.

    Expects exactly one argument (the input file); otherwise prints a usage
    line to stderr and exits with status 1.
    """
    argv = sys.argv
    if len(argv) != 2:
        print(f"Usage: {argv[0]} <input-file>", file=sys.stderr)
        sys.exit(1)

    sys.stdout.write(amalgamate(argv[1]))


if __name__ == '__main__':
    main()
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
mkdir -p vendor
|
mkdir -p vendor
|
||||||
curl -o sqlite-amalgamation.zip https://www.sqlite.org/2024/sqlite-amalgamation-3450300.zip
|
curl -o sqlite-amalgamation.zip https://www.sqlite.org/2024/sqlite-amalgamation-3450300.zip
|
||||||
unzip -d
|
|
||||||
unzip sqlite-amalgamation.zip
|
unzip sqlite-amalgamation.zip
|
||||||
mv sqlite-amalgamation-3450300/* vendor/
|
mv sqlite-amalgamation-3450300/* vendor/
|
||||||
rmdir sqlite-amalgamation-3450300
|
rmdir sqlite-amalgamation-3450300
|
||||||
|
|
|
||||||
|
|
@ -568,65 +568,6 @@ select 'todo';
|
||||||
-- 'todo'
|
-- 'todo'
|
||||||
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
## NumPy Utilities {#numpy}
|
|
||||||
|
|
||||||
Functions to read data from or work with [NumPy arrays](https://numpy.org/doc/stable/reference/generated/numpy.array.html).
|
|
||||||
|
|
||||||
### `vec_npy_each(vector)` {#vec_npy_each}
|
|
||||||
|
|
||||||
xxx
|
|
||||||
|
|
||||||
|
|
||||||
```sql
|
|
||||||
-- db.execute('select quote(?)', [to_npy(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))]).fetchone()
|
|
||||||
select
|
|
||||||
rowid,
|
|
||||||
vector,
|
|
||||||
vec_type(vector),
|
|
||||||
vec_to_json(vector)
|
|
||||||
from vec_npy_each(
|
|
||||||
X'934E554D5059010076007B276465736372273A20273C6634272C2027666F727472616E5F6F72646572273A2046616C73652C20277368617065273A2028332C2031292C207D202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020200A0000803F0000004000004040'
|
|
||||||
)
|
|
||||||
/*
|
|
||||||
┌───────┬─────────────┬──────────────────┬─────────────────────┐
|
|
||||||
│ rowid │ vector │ vec_type(vector) │ vec_to_json(vector) │
|
|
||||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
|
||||||
│ 0 │ X'0000803F' │ 'float32' │ '[1.000000]' │
|
|
||||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
|
||||||
│ 1 │ X'00000040' │ 'float32' │ '[2.000000]' │
|
|
||||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
|
||||||
│ 2 │ X'00004040' │ 'float32' │ '[3.000000]' │
|
|
||||||
└───────┴─────────────┴──────────────────┴─────────────────────┘
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
-- db.execute('select quote(?)', [to_npy(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))]).fetchone()
|
|
||||||
select
|
|
||||||
rowid,
|
|
||||||
vector,
|
|
||||||
vec_type(vector),
|
|
||||||
vec_to_json(vector)
|
|
||||||
from vec_npy_each(
|
|
||||||
X'934E554D5059010076007B276465736372273A20273C6634272C2027666F727472616E5F6F72646572273A2046616C73652C20277368617065273A2028332C2031292C207D202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020200A0000803F0000004000004040'
|
|
||||||
)
|
|
||||||
/*
|
|
||||||
┌───────┬─────────────┬──────────────────┬─────────────────────┐
|
|
||||||
│ rowid │ vector │ vec_type(vector) │ vec_to_json(vector) │
|
|
||||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
|
||||||
│ 0 │ X'0000803F' │ 'float32' │ '[1.000000]' │
|
|
||||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
|
||||||
│ 1 │ X'00000040' │ 'float32' │ '[2.000000]' │
|
|
||||||
├───────┼─────────────┼──────────────────┼─────────────────────┤
|
|
||||||
│ 2 │ X'00004040' │ 'float32' │ '[3.000000]' │
|
|
||||||
└───────┴─────────────┴──────────────────┴─────────────────────┘
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Meta {#meta}
|
## Meta {#meta}
|
||||||
|
|
|
||||||
|
|
@ -59,5 +59,4 @@ The current compile-time flags are:
|
||||||
|
|
||||||
- `SQLITE_VEC_ENABLE_AVX`, enables AVX CPU instructions for some vector search operations
|
- `SQLITE_VEC_ENABLE_AVX`, enables AVX CPU instructions for some vector search operations
|
||||||
- `SQLITE_VEC_ENABLE_NEON`, enables NEON CPU instructions for some vector search operations
|
- `SQLITE_VEC_ENABLE_NEON`, enables NEON CPU instructions for some vector search operations
|
||||||
- `SQLITE_VEC_OMIT_FS`, removes some obsure SQL functions and features that use the filesystem, meant for some WASM builds where there's no available filesystem
|
|
||||||
- `SQLITE_VEC_STATIC`, meant for statically linking `sqlite-vec`
|
- `SQLITE_VEC_STATIC`, meant for statically linking `sqlite-vec`
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "sqlite-vec"
|
name = "sqlite-vec"
|
||||||
license = "MIT OR Apache"
|
license = "MIT OR Apache-2.0"
|
||||||
homepage = "https://alexgarcia.xyz/sqlite-vec"
|
homepage = "https://alexgarcia.xyz/sqlite-vec"
|
||||||
repo = "https://github.com/asg017/sqlite-vec"
|
repo = "https://github.com/asg017/sqlite-vec"
|
||||||
description = "A vector search SQLite extension."
|
description = "A vector search SQLite extension."
|
||||||
|
|
|
||||||
1889
sqlite-vec-diskann.c
Normal file
1889
sqlite-vec-diskann.c
Normal file
File diff suppressed because it is too large
Load diff
214
sqlite-vec-ivf-kmeans.c
Normal file
214
sqlite-vec-ivf-kmeans.c
Normal file
|
|
@ -0,0 +1,214 @@
|
||||||
|
/**
|
||||||
|
* sqlite-vec-ivf-kmeans.c — Pure k-means clustering algorithm.
|
||||||
|
*
|
||||||
|
* No SQLite dependency. Operates on float arrays in memory.
|
||||||
|
* #include'd into sqlite-vec.c after struct definitions.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef SQLITE_VEC_IVF_KMEANS_C
|
||||||
|
#define SQLITE_VEC_IVF_KMEANS_C
|
||||||
|
|
||||||
|
// When opened standalone in an editor, pull in types so the LSP is happy.
|
||||||
|
// When #include'd from sqlite-vec.c, SQLITE_VEC_H is already defined.
|
||||||
|
#ifndef SQLITE_VEC_H
|
||||||
|
#include "sqlite-vec.c" // IWYU pragma: keep
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <float.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define VEC0_IVF_KMEANS_MAX_ITER 25
|
||||||
|
#define VEC0_IVF_KMEANS_DEFAULT_SEED 0
|
||||||
|
|
||||||
|
// xorshift32 pseudo-random generator (shifts 13, 17, 5).
// Advances *state in place and returns the new state value.
static uint32_t ivf_xorshift32(uint32_t *state) {
  uint32_t s = *state;
  s = s ^ (s << 13);
  s = s ^ (s >> 17);
  s = s ^ (s << 5);
  *state = s;
  return s;
}
|
||||||
|
|
||||||
|
// Squared Euclidean (L2) distance between two D-element float vectors.
// The sum is accumulated in single precision, element 0 first.
static float ivf_l2_dist(const float *a, const float *b, int D) {
  float acc = 0.0f;
  int d = 0;
  while (d < D) {
    const float delta = a[d] - b[d];
    acc += delta * delta;
    d++;
  }
  return acc;
}
|
||||||
|
|
||||||
|
// Find the centroid nearest to `vec` by squared L2 distance.
// `centroids` is read as k consecutive rows of D floats each.
// Returns the index of the nearest row; on equal distances the lowest index
// wins, and 0 is returned when k <= 0.
static int ivf_nearest_centroid(const float *vec, const float *centroids,
                                int D, int k) {
  float min_dist = FLT_MAX;
  int best = 0;
  for (int c = 0; c < k; c++) {
    // Row offset computed in size_t so c * D cannot overflow int.
    float dist = ivf_l2_dist(vec, &centroids[(size_t)c * (size_t)D], D);
    if (dist < min_dist) {
      min_dist = dist;
      best = c;
    }
  }
  return best;
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* K-means++ initialization.
|
||||||
|
* Picks k initial centroids from the data with probability proportional
|
||||||
|
* to squared distance from nearest existing centroid.
|
||||||
|
*/
|
||||||
|
static int ivf_kmeans_init_plusplus(const float *vectors, int N, int D,
|
||||||
|
int k, uint32_t seed, float *centroids) {
|
||||||
|
if (N <= 0 || k <= 0 || D <= 0)
|
||||||
|
return -1;
|
||||||
|
if (seed == 0)
|
||||||
|
seed = 42;
|
||||||
|
|
||||||
|
// Pick first centroid randomly
|
||||||
|
int first = ivf_xorshift32(&seed) % N;
|
||||||
|
memcpy(centroids, &vectors[first * D], D * sizeof(float));
|
||||||
|
|
||||||
|
if (k == 1)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
// Allocate distance array
|
||||||
|
float *dists = sqlite3_malloc64((i64)N * sizeof(float));
|
||||||
|
if (!dists)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
for (int c = 1; c < k; c++) {
|
||||||
|
// Compute D(x) = distance to nearest existing centroid
|
||||||
|
double total = 0.0;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
float d = ivf_l2_dist(&vectors[i * D], ¢roids[(c - 1) * D], D);
|
||||||
|
if (c == 1 || d < dists[i]) {
|
||||||
|
dists[i] = d;
|
||||||
|
}
|
||||||
|
total += dists[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Weighted random selection
|
||||||
|
if (total <= 0.0) {
|
||||||
|
// All distances zero — pick randomly
|
||||||
|
int pick = ivf_xorshift32(&seed) % N;
|
||||||
|
memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float));
|
||||||
|
} else {
|
||||||
|
double threshold = ((double)ivf_xorshift32(&seed) / (double)0xFFFFFFFF) * total;
|
||||||
|
double cumulative = 0.0;
|
||||||
|
int pick = N - 1;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
cumulative += dists[i];
|
||||||
|
if (cumulative >= threshold) {
|
||||||
|
pick = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
memcpy(¢roids[c * D], &vectors[pick * D], D * sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_free(dists);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lloyd's k-means algorithm.
|
||||||
|
*
|
||||||
|
* @param vectors N*D float array (row-major)
|
||||||
|
* @param N number of vectors
|
||||||
|
* @param D dimensionality
|
||||||
|
* @param k number of clusters
|
||||||
|
* @param max_iter maximum iterations
|
||||||
|
* @param seed PRNG seed for initialization
|
||||||
|
* @param out_centroids output: k*D float array (caller-allocated)
|
||||||
|
* @return 0 on success, -1 on error
|
||||||
|
*/
|
||||||
|
static int ivf_kmeans(const float *vectors, int N, int D, int k,
|
||||||
|
int max_iter, uint32_t seed, float *out_centroids) {
|
||||||
|
if (N <= 0 || D <= 0 || k <= 0)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
// Clamp k to N
|
||||||
|
if (k > N)
|
||||||
|
k = N;
|
||||||
|
|
||||||
|
// Allocate working memory
|
||||||
|
int *assignments = sqlite3_malloc64((i64)N * sizeof(int));
|
||||||
|
float *new_centroids = sqlite3_malloc64((i64)k * D * sizeof(float));
|
||||||
|
int *counts = sqlite3_malloc64((i64)k * sizeof(int));
|
||||||
|
|
||||||
|
if (!assignments || !new_centroids || !counts) {
|
||||||
|
sqlite3_free(assignments);
|
||||||
|
sqlite3_free(new_centroids);
|
||||||
|
sqlite3_free(counts);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
memset(assignments, -1, N * sizeof(int));
|
||||||
|
|
||||||
|
// Initialize centroids via k-means++
|
||||||
|
if (ivf_kmeans_init_plusplus(vectors, N, D, k, seed, out_centroids) != 0) {
|
||||||
|
sqlite3_free(assignments);
|
||||||
|
sqlite3_free(new_centroids);
|
||||||
|
sqlite3_free(counts);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int iter = 0; iter < max_iter; iter++) {
|
||||||
|
// Assignment step
|
||||||
|
int changed = 0;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
int nearest = ivf_nearest_centroid(&vectors[i * D], out_centroids, D, k);
|
||||||
|
if (nearest != assignments[i]) {
|
||||||
|
assignments[i] = nearest;
|
||||||
|
changed++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (changed == 0)
|
||||||
|
break;
|
||||||
|
|
||||||
|
// Update step
|
||||||
|
memset(new_centroids, 0, (size_t)k * D * sizeof(float));
|
||||||
|
memset(counts, 0, k * sizeof(int));
|
||||||
|
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
int c = assignments[i];
|
||||||
|
counts[c]++;
|
||||||
|
for (int d = 0; d < D; d++) {
|
||||||
|
new_centroids[c * D + d] += vectors[i * D + d];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int c = 0; c < k; c++) {
|
||||||
|
if (counts[c] == 0) {
|
||||||
|
// Empty cluster: reassign to farthest point from its nearest centroid
|
||||||
|
float max_dist = -1.0f;
|
||||||
|
int farthest = 0;
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
float d = ivf_l2_dist(&vectors[i * D],
|
||||||
|
&out_centroids[assignments[i] * D], D);
|
||||||
|
if (d > max_dist) {
|
||||||
|
max_dist = d;
|
||||||
|
farthest = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
memcpy(&out_centroids[c * D], &vectors[farthest * D],
|
||||||
|
D * sizeof(float));
|
||||||
|
} else {
|
||||||
|
for (int d = 0; d < D; d++) {
|
||||||
|
out_centroids[c * D + d] = new_centroids[c * D + d] / counts[c];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_free(assignments);
|
||||||
|
sqlite3_free(new_centroids);
|
||||||
|
sqlite3_free(counts);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif /* SQLITE_VEC_IVF_KMEANS_C */
|
||||||
1445
sqlite-vec-ivf.c
Normal file
1445
sqlite-vec-ivf.c
Normal file
File diff suppressed because it is too large
Load diff
687
sqlite-vec-rescore.c
Normal file
687
sqlite-vec-rescore.c
Normal file
|
|
@ -0,0 +1,687 @@
|
||||||
|
/**
|
||||||
|
* sqlite-vec-rescore.c — Rescore index logic for sqlite-vec.
|
||||||
|
*
|
||||||
|
* This file is #included into sqlite-vec.c after the vec0_vtab definition.
|
||||||
|
* All functions receive a vec0_vtab *p and access p->vector_columns[i].rescore.
|
||||||
|
*
|
||||||
|
* Shadow tables per rescore-enabled vector column:
|
||||||
|
* _rescore_chunks{NN} — quantized vectors in chunk layout (for coarse scan)
|
||||||
|
* _rescore_vectors{NN} — float vectors keyed by rowid (for fast rescore lookup)
|
||||||
|
*/
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Shadow table lifecycle
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/**
 * Create the rescore shadow tables for every vector column whose index type
 * is VEC0_INDEX_TYPE_RESCORE. Two tables are created per such column, each
 * suffixed with the zero-padded column index:
 *   <table>_rescore_chunks{NN}  (rowid PRIMARY KEY, vectors BLOB NOT NULL)
 *   <table>_rescore_vectors{NN} (rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL)
 *
 * @param p      vtab supplying schemaName, tableName and the column list
 * @param db     connection the CREATE TABLE statements run on
 * @param pzErr  receives a sqlite3_mprintf'd message when a CREATE fails
 * @return SQLITE_OK, SQLITE_NOMEM if the SQL text cannot be allocated, or
 *         SQLITE_ERROR (with *pzErr set) if a statement fails to prepare or run
 */
static int rescore_create_tables(vec0_vtab *p, sqlite3 *db, char **pzErr) {
  for (int colIdx = 0; colIdx < p->numVectorColumns; colIdx++) {
    if (p->vector_columns[colIdx].index_type != VEC0_INDEX_TYPE_RESCORE)
      continue;

    sqlite3_stmt *pStmt = NULL;
    int rc;

    // Quantized chunk table (same layout as _vector_chunks)
    char *zCreate = sqlite3_mprintf(
        "CREATE TABLE \"%w\".\"%w_rescore_chunks%02d\""
        "(rowid PRIMARY KEY, vectors BLOB NOT NULL)",
        p->schemaName, p->tableName, colIdx);
    if (zCreate == NULL)
      return SQLITE_NOMEM;
    rc = sqlite3_prepare_v2(db, zCreate, -1, &pStmt, 0);
    sqlite3_free(zCreate);
    if (rc == SQLITE_OK)
      rc = sqlite3_step(pStmt); // only stepped when prepare succeeded
    if (rc != SQLITE_DONE) {
      *pzErr = sqlite3_mprintf(
          "Could not create '_rescore_chunks%02d' shadow table: %s", colIdx,
          sqlite3_errmsg(db));
      sqlite3_finalize(pStmt);
      return SQLITE_ERROR;
    }
    sqlite3_finalize(pStmt);
    pStmt = NULL;

    // Float vector table (rowid-keyed for fast random access)
    zCreate = sqlite3_mprintf(
        "CREATE TABLE \"%w\".\"%w_rescore_vectors%02d\""
        "(rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL)",
        p->schemaName, p->tableName, colIdx);
    if (zCreate == NULL)
      return SQLITE_NOMEM;
    rc = sqlite3_prepare_v2(db, zCreate, -1, &pStmt, 0);
    sqlite3_free(zCreate);
    if (rc == SQLITE_OK)
      rc = sqlite3_step(pStmt);
    if (rc != SQLITE_DONE) {
      *pzErr = sqlite3_mprintf(
          "Could not create '_rescore_vectors%02d' shadow table: %s", colIdx,
          sqlite3_errmsg(db));
      sqlite3_finalize(pStmt);
      return SQLITE_ERROR;
    }
    sqlite3_finalize(pStmt);
  }
  return SQLITE_OK;
}
|
||||||
|
|
||||||
|
/**
 * Drop the rescore shadow tables of every vector column that has a shadow
 * table name recorded: first the chunks table, then the vectors table.
 *
 * @param p  vtab supplying db, schemaName and the shadow table names
 * @return SQLITE_OK, SQLITE_NOMEM if the SQL text cannot be allocated, or
 *         SQLITE_ERROR if a DROP fails to prepare or run (remaining tables
 *         are then left in place)
 */
static int rescore_drop_tables(vec0_vtab *p) {
  for (int colIdx = 0; colIdx < p->numVectorColumns; colIdx++) {
    const char *azShadow[2];
    azShadow[0] = p->shadowRescoreChunksNames[colIdx];
    azShadow[1] = p->shadowRescoreVectorsNames[colIdx];

    for (int t = 0; t < 2; t++) {
      if (!azShadow[t])
        continue; // column has no such shadow table

      char *zDrop = sqlite3_mprintf("DROP TABLE IF EXISTS \"%w\".\"%w\"",
                                    p->schemaName, azShadow[t]);
      if (zDrop == NULL)
        return SQLITE_NOMEM;

      sqlite3_stmt *pStmt = NULL;
      int rc = sqlite3_prepare_v2(p->db, zDrop, -1, &pStmt, 0);
      sqlite3_free(zDrop);
      if (rc == SQLITE_OK)
        rc = sqlite3_step(pStmt);
      sqlite3_finalize(pStmt);
      if (rc != SQLITE_DONE)
        return SQLITE_ERROR;
    }
  }
  return SQLITE_OK;
}
|
||||||
|
|
||||||
|
static size_t rescore_quantized_byte_size(struct VectorColumnDefinition *col) {
|
||||||
|
switch (col->rescore.quantizer_type) {
|
||||||
|
case VEC0_RESCORE_QUANTIZER_BIT:
|
||||||
|
return col->dimensions / CHAR_BIT;
|
||||||
|
case VEC0_RESCORE_QUANTIZER_INT8:
|
||||||
|
return col->dimensions;
|
||||||
|
default:
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Insert a new chunk row into each _rescore_chunks{NN} table with a zeroblob.
|
||||||
|
*/
|
||||||
|
static int rescore_new_chunk(vec0_vtab *p, i64 chunk_rowid) {
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE)
|
||||||
|
continue;
|
||||||
|
size_t quantized_size =
|
||||||
|
rescore_quantized_byte_size(&p->vector_columns[i]);
|
||||||
|
i64 blob_size = (i64)p->chunk_size * (i64)quantized_size;
|
||||||
|
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"INSERT INTO \"%w\".\"%w\"(_rowid_, rowid, vectors) VALUES (?, ?, ?)",
|
||||||
|
p->schemaName, p->shadowRescoreChunksNames[i]);
|
||||||
|
if (!zSql)
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
sqlite3_bind_int64(stmt, 1, chunk_rowid);
|
||||||
|
sqlite3_bind_int64(stmt, 2, chunk_rowid);
|
||||||
|
sqlite3_bind_zeroblob64(stmt, 3, blob_size);
|
||||||
|
rc = sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
if (rc != SQLITE_DONE)
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Quantization
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/**
 * Quantize a float vector to one bit per dimension.
 *
 * Bit (i % CHAR_BIT) of dst[i / CHAR_BIT] is set when src[i] >= 0.0f and left
 * clear otherwise (NaN compares false, so it leaves the bit clear).
 *
 * @param src         `dimensions` floats
 * @param dst         output; the first dimensions / CHAR_BIT bytes are zeroed
 *                    before bits are set
 * @param dimensions  number of elements in src
 */
static void rescore_quantize_float_to_bit(const float *src, uint8_t *dst,
                                          size_t dimensions) {
  memset(dst, 0, dimensions / CHAR_BIT);
  for (size_t idx = 0; idx != dimensions; ++idx) {
    if (src[idx] >= 0.0f) {
      const size_t byteIdx = idx / CHAR_BIT;
      const size_t bitIdx = idx % CHAR_BIT;
      dst[byteIdx] |= (1 << bitIdx);
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * Quantize a float vector to int8 by mapping the range [-1, 1] onto
 * [-128, 127] in steps of 2/255.
 *
 * Values above the range, and NaN, become 127; values below become -128.
 * The conversion to int8_t truncates toward zero (no rounding).
 *
 * @param src         `dimensions` floats
 * @param dst         output, `dimensions` int8 values
 * @param dimensions  number of elements
 */
static void rescore_quantize_float_to_int8(const float *src, int8_t *dst,
                                           size_t dimensions) {
  const float step = 2.0f / 255.0f;
  const float *in = src;
  int8_t *out = dst;
  for (size_t remaining = dimensions; remaining > 0; remaining--) {
    float scaled = (*in++ - (-1.0f)) / step - 128.0f;
    if (!(scaled <= 127.0f)) {
      scaled = 127.0f; // also catches NaN, which fails every comparison
    } else if (scaled < -128.0f) {
      scaled = -128.0f;
    }
    *out++ = (int8_t)scaled;
  }
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Insert path
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Quantize float vector to _rescore_chunks and store in _rescore_vectors.
|
||||||
|
*/
|
||||||
|
static int rescore_on_insert(vec0_vtab *p, i64 chunk_rowid, i64 chunk_offset,
|
||||||
|
i64 rowid, void *vectorDatas[]) {
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
struct VectorColumnDefinition *col = &p->vector_columns[i];
|
||||||
|
size_t qsize = rescore_quantized_byte_size(col);
|
||||||
|
size_t fsize = vector_column_byte_size(*col);
|
||||||
|
int rc;
|
||||||
|
|
||||||
|
// 1. Write quantized vector to _rescore_chunks blob
|
||||||
|
{
|
||||||
|
void *qbuf = sqlite3_malloc(qsize);
|
||||||
|
if (!qbuf)
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
|
||||||
|
switch (col->rescore.quantizer_type) {
|
||||||
|
case VEC0_RESCORE_QUANTIZER_BIT:
|
||||||
|
rescore_quantize_float_to_bit((const float *)vectorDatas[i],
|
||||||
|
(uint8_t *)qbuf, col->dimensions);
|
||||||
|
break;
|
||||||
|
case VEC0_RESCORE_QUANTIZER_INT8:
|
||||||
|
rescore_quantize_float_to_int8((const float *)vectorDatas[i],
|
||||||
|
(int8_t *)qbuf, col->dimensions);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_blob *blob = NULL;
|
||||||
|
rc = sqlite3_blob_open(p->db, p->schemaName,
|
||||||
|
p->shadowRescoreChunksNames[i], "vectors",
|
||||||
|
chunk_rowid, 1, &blob);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_free(qbuf);
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
rc = sqlite3_blob_write(blob, qbuf, qsize, chunk_offset * qsize);
|
||||||
|
sqlite3_free(qbuf);
|
||||||
|
int brc = sqlite3_blob_close(blob);
|
||||||
|
if (rc != SQLITE_OK)
|
||||||
|
return rc;
|
||||||
|
if (brc != SQLITE_OK)
|
||||||
|
return brc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Insert float vector into _rescore_vectors (rowid-keyed)
|
||||||
|
{
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"INSERT INTO \"%w\".\"%w\"(rowid, vector) VALUES (?, ?)",
|
||||||
|
p->schemaName, p->shadowRescoreVectorsNames[i]);
|
||||||
|
if (!zSql)
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
sqlite3_bind_int64(stmt, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vectorDatas[i], fsize, SQLITE_TRANSIENT);
|
||||||
|
rc = sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
if (rc != SQLITE_DONE)
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Delete path
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Zero out quantized vector in _rescore_chunks and delete from _rescore_vectors.
|
||||||
|
*/
|
||||||
|
static int rescore_on_delete(vec0_vtab *p, i64 chunk_id, u64 chunk_offset,
|
||||||
|
i64 rowid) {
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE)
|
||||||
|
continue;
|
||||||
|
int rc;
|
||||||
|
|
||||||
|
// 1. Zero out quantized data in _rescore_chunks
|
||||||
|
{
|
||||||
|
size_t qsize = rescore_quantized_byte_size(&p->vector_columns[i]);
|
||||||
|
void *zeroBuf = sqlite3_malloc(qsize);
|
||||||
|
if (!zeroBuf)
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
memset(zeroBuf, 0, qsize);
|
||||||
|
|
||||||
|
sqlite3_blob *blob = NULL;
|
||||||
|
rc = sqlite3_blob_open(p->db, p->schemaName,
|
||||||
|
p->shadowRescoreChunksNames[i], "vectors",
|
||||||
|
chunk_id, 1, &blob);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_free(zeroBuf);
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
rc = sqlite3_blob_write(blob, zeroBuf, qsize, chunk_offset * qsize);
|
||||||
|
sqlite3_free(zeroBuf);
|
||||||
|
int brc = sqlite3_blob_close(blob);
|
||||||
|
if (rc != SQLITE_OK)
|
||||||
|
return rc;
|
||||||
|
if (brc != SQLITE_OK)
|
||||||
|
return brc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Delete from _rescore_vectors
|
||||||
|
{
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"DELETE FROM \"%w\".\"%w\" WHERE rowid = ?",
|
||||||
|
p->schemaName, p->shadowRescoreVectorsNames[i]);
|
||||||
|
if (!zSql)
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if (rc != SQLITE_OK)
|
||||||
|
return rc;
|
||||||
|
sqlite3_bind_int64(stmt, 1, rowid);
|
||||||
|
rc = sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
if (rc != SQLITE_DONE)
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete a chunk row from _rescore_chunks{NN} tables.
|
||||||
|
* (_rescore_vectors rows were already deleted per-row in rescore_on_delete)
|
||||||
|
*/
|
||||||
|
static int rescore_delete_chunk(vec0_vtab *p, i64 chunk_id) {
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (!p->shadowRescoreChunksNames[i])
|
||||||
|
continue;
|
||||||
|
char *zSql = sqlite3_mprintf(
|
||||||
|
"DELETE FROM \"%w\".\"%w\" WHERE rowid = ?",
|
||||||
|
p->schemaName, p->shadowRescoreChunksNames[i]);
|
||||||
|
if (!zSql)
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free(zSql);
|
||||||
|
if (rc != SQLITE_OK)
|
||||||
|
return rc;
|
||||||
|
sqlite3_bind_int64(stmt, 1, chunk_id);
|
||||||
|
rc = sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
if (rc != SQLITE_DONE)
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// KNN rescore query
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Phase 1: Coarse scan of quantized chunks → top k*oversample candidates (rowids).
|
||||||
|
* Phase 2: For each candidate, blob_open _rescore_vectors by rowid, read float
|
||||||
|
* vector, compute float distance. Sort, return top k.
|
||||||
|
*
|
||||||
|
* Phase 2 is fast because _rescore_vectors has INTEGER PRIMARY KEY, so
|
||||||
|
* sqlite3_blob_open/reopen addresses rows directly by rowid — no index lookup.
|
||||||
|
*/
|
||||||
|
static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur,
|
||||||
|
struct VectorColumnDefinition *vector_column,
|
||||||
|
int vectorColumnIdx, struct Array *arrayRowidsIn,
|
||||||
|
struct Array *aMetadataIn, const char *idxStr, int argc,
|
||||||
|
sqlite3_value **argv, void *queryVector, i64 k,
|
||||||
|
struct vec0_query_knn_data *knn_data) {
|
||||||
|
(void)pCur;
|
||||||
|
(void)aMetadataIn;
|
||||||
|
int rc = SQLITE_OK;
|
||||||
|
int oversample = vector_column->rescore.oversample_search > 0
|
||||||
|
? vector_column->rescore.oversample_search
|
||||||
|
: vector_column->rescore.oversample;
|
||||||
|
i64 k_oversample = k * oversample;
|
||||||
|
if (k_oversample > 4096)
|
||||||
|
k_oversample = 4096;
|
||||||
|
|
||||||
|
size_t qdim = vector_column->dimensions;
|
||||||
|
size_t qsize = rescore_quantized_byte_size(vector_column);
|
||||||
|
size_t fsize = vector_column_byte_size(*vector_column);
|
||||||
|
|
||||||
|
// Quantize the query vector
|
||||||
|
void *quantizedQuery = sqlite3_malloc(qsize);
|
||||||
|
if (!quantizedQuery)
|
||||||
|
return SQLITE_NOMEM;
|
||||||
|
|
||||||
|
switch (vector_column->rescore.quantizer_type) {
|
||||||
|
case VEC0_RESCORE_QUANTIZER_BIT:
|
||||||
|
rescore_quantize_float_to_bit((const float *)queryVector,
|
||||||
|
(uint8_t *)quantizedQuery, qdim);
|
||||||
|
break;
|
||||||
|
case VEC0_RESCORE_QUANTIZER_INT8:
|
||||||
|
rescore_quantize_float_to_int8((const float *)queryVector,
|
||||||
|
(int8_t *)quantizedQuery, qdim);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 1: Scan quantized chunks for k*oversample candidates
|
||||||
|
sqlite3_stmt *stmtChunks = NULL;
|
||||||
|
rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_free(quantizedQuery);
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
i64 *cand_rowids = sqlite3_malloc(k_oversample * sizeof(i64));
|
||||||
|
f32 *cand_distances = sqlite3_malloc(k_oversample * sizeof(f32));
|
||||||
|
i64 *tmp_rowids = sqlite3_malloc(k_oversample * sizeof(i64));
|
||||||
|
f32 *tmp_distances = sqlite3_malloc(k_oversample * sizeof(f32));
|
||||||
|
f32 *chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(f32));
|
||||||
|
i32 *chunk_topk_idxs = sqlite3_malloc(k_oversample * sizeof(i32));
|
||||||
|
u8 *b = sqlite3_malloc(p->chunk_size / CHAR_BIT);
|
||||||
|
u8 *bTaken = sqlite3_malloc(p->chunk_size / CHAR_BIT);
|
||||||
|
u8 *bmRowids = NULL;
|
||||||
|
void *baseVectors = sqlite3_malloc((i64)p->chunk_size * (i64)qsize);
|
||||||
|
|
||||||
|
if (!cand_rowids || !cand_distances || !tmp_rowids || !tmp_distances ||
|
||||||
|
!chunk_distances || !chunk_topk_idxs || !b || !bTaken || !baseVectors) {
|
||||||
|
rc = SQLITE_NOMEM;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
memset(cand_rowids, 0, k_oversample * sizeof(i64));
|
||||||
|
memset(cand_distances, 0, k_oversample * sizeof(f32));
|
||||||
|
|
||||||
|
if (arrayRowidsIn) {
|
||||||
|
bmRowids = sqlite3_malloc(p->chunk_size / CHAR_BIT);
|
||||||
|
if (!bmRowids) {
|
||||||
|
rc = SQLITE_NOMEM;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
i64 cand_used = 0;
|
||||||
|
|
||||||
|
while (1) {
|
||||||
|
rc = sqlite3_step(stmtChunks);
|
||||||
|
if (rc == SQLITE_DONE)
|
||||||
|
break;
|
||||||
|
if (rc != SQLITE_ROW) {
|
||||||
|
rc = SQLITE_ERROR;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
i64 chunk_id = sqlite3_column_int64(stmtChunks, 0);
|
||||||
|
unsigned char *chunkValidity =
|
||||||
|
(unsigned char *)sqlite3_column_blob(stmtChunks, 1);
|
||||||
|
i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2);
|
||||||
|
int validityBytes = sqlite3_column_bytes(stmtChunks, 1);
|
||||||
|
int rowidsBytes = sqlite3_column_bytes(stmtChunks, 2);
|
||||||
|
if (!chunkValidity || !chunkRowids) {
|
||||||
|
rc = SQLITE_ERROR;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
// Validate blob sizes match chunk_size expectations
|
||||||
|
if (validityBytes < (p->chunk_size + 7) / 8 ||
|
||||||
|
rowidsBytes < p->chunk_size * (int)sizeof(i64)) {
|
||||||
|
rc = SQLITE_ERROR;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
memset(chunk_distances, 0, p->chunk_size * sizeof(f32));
|
||||||
|
memset(chunk_topk_idxs, 0, k_oversample * sizeof(i32));
|
||||||
|
bitmap_copy(b, chunkValidity, p->chunk_size);
|
||||||
|
|
||||||
|
if (arrayRowidsIn) {
|
||||||
|
bitmap_clear(bmRowids, p->chunk_size);
|
||||||
|
for (int j = 0; j < p->chunk_size; j++) {
|
||||||
|
if (!bitmap_get(chunkValidity, j))
|
||||||
|
continue;
|
||||||
|
i64 rid = chunkRowids[j];
|
||||||
|
void *found = bsearch(&rid, arrayRowidsIn->z, arrayRowidsIn->length,
|
||||||
|
sizeof(i64), _cmp);
|
||||||
|
bitmap_set(bmRowids, j, found ? 1 : 0);
|
||||||
|
}
|
||||||
|
bitmap_and_inplace(b, bmRowids, p->chunk_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read quantized vectors
|
||||||
|
sqlite3_blob *blobQ = NULL;
|
||||||
|
rc = sqlite3_blob_open(p->db, p->schemaName,
|
||||||
|
p->shadowRescoreChunksNames[vectorColumnIdx],
|
||||||
|
"vectors", chunk_id, 0, &blobQ);
|
||||||
|
if (rc != SQLITE_OK)
|
||||||
|
goto cleanup;
|
||||||
|
rc = sqlite3_blob_read(blobQ, baseVectors,
|
||||||
|
(i64)p->chunk_size * (i64)qsize, 0);
|
||||||
|
sqlite3_blob_close(blobQ);
|
||||||
|
if (rc != SQLITE_OK)
|
||||||
|
goto cleanup;
|
||||||
|
|
||||||
|
// Compute quantized distances
|
||||||
|
for (int j = 0; j < p->chunk_size; j++) {
|
||||||
|
if (!bitmap_get(b, j))
|
||||||
|
continue;
|
||||||
|
f32 dist = FLT_MAX;
|
||||||
|
switch (vector_column->rescore.quantizer_type) {
|
||||||
|
case VEC0_RESCORE_QUANTIZER_BIT: {
|
||||||
|
const u8 *base_j = ((u8 *)baseVectors) + (j * (qdim / CHAR_BIT));
|
||||||
|
dist = distance_hamming(base_j, (u8 *)quantizedQuery, &qdim);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case VEC0_RESCORE_QUANTIZER_INT8: {
|
||||||
|
const i8 *base_j = ((i8 *)baseVectors) + (j * qdim);
|
||||||
|
switch (vector_column->distance_metric) {
|
||||||
|
case VEC0_DISTANCE_METRIC_L2:
|
||||||
|
dist = distance_l2_sqr_int8(base_j, (i8 *)quantizedQuery, &qdim);
|
||||||
|
break;
|
||||||
|
case VEC0_DISTANCE_METRIC_COSINE:
|
||||||
|
dist = distance_cosine_int8(base_j, (i8 *)quantizedQuery, &qdim);
|
||||||
|
break;
|
||||||
|
case VEC0_DISTANCE_METRIC_L1:
|
||||||
|
dist = (f32)distance_l1_int8(base_j, (i8 *)quantizedQuery, &qdim);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunk_distances[j] = dist;
|
||||||
|
}
|
||||||
|
|
||||||
|
int used1;
|
||||||
|
min_idx(chunk_distances, p->chunk_size, b, chunk_topk_idxs,
|
||||||
|
min(k_oversample, p->chunk_size), bTaken, &used1);
|
||||||
|
|
||||||
|
i64 merged_used;
|
||||||
|
merge_sorted_lists(cand_distances, cand_rowids, cand_used, chunk_distances,
|
||||||
|
chunkRowids, chunk_topk_idxs,
|
||||||
|
min(min(k_oversample, p->chunk_size), used1),
|
||||||
|
tmp_distances, tmp_rowids, k_oversample, &merged_used);
|
||||||
|
|
||||||
|
for (i64 j = 0; j < merged_used; j++) {
|
||||||
|
cand_rowids[j] = tmp_rowids[j];
|
||||||
|
cand_distances[j] = tmp_distances[j];
|
||||||
|
}
|
||||||
|
cand_used = merged_used;
|
||||||
|
}
|
||||||
|
rc = SQLITE_OK;
|
||||||
|
|
||||||
|
// Phase 2: Rescore candidates using _rescore_vectors (rowid-keyed)
|
||||||
|
if (cand_used == 0) {
|
||||||
|
knn_data->current_idx = 0;
|
||||||
|
knn_data->k = 0;
|
||||||
|
knn_data->rowids = NULL;
|
||||||
|
knn_data->distances = NULL;
|
||||||
|
knn_data->k_used = 0;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
f32 *float_distances = sqlite3_malloc(cand_used * sizeof(f32));
|
||||||
|
void *fBuf = sqlite3_malloc(fsize);
|
||||||
|
if (!float_distances || !fBuf) {
|
||||||
|
sqlite3_free(float_distances);
|
||||||
|
sqlite3_free(fBuf);
|
||||||
|
rc = SQLITE_NOMEM;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open blob on _rescore_vectors, then reopen for each candidate rowid.
|
||||||
|
// blob_reopen is O(1) for INTEGER PRIMARY KEY tables.
|
||||||
|
sqlite3_blob *blobFloat = NULL;
|
||||||
|
rc = sqlite3_blob_open(p->db, p->schemaName,
|
||||||
|
p->shadowRescoreVectorsNames[vectorColumnIdx],
|
||||||
|
"vector", cand_rowids[0], 0, &blobFloat);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_free(float_distances);
|
||||||
|
sqlite3_free(fBuf);
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
|
||||||
|
rc = sqlite3_blob_read(blobFloat, fBuf, fsize, 0);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_blob_close(blobFloat);
|
||||||
|
sqlite3_free(float_distances);
|
||||||
|
sqlite3_free(fBuf);
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
float_distances[0] =
|
||||||
|
vec0_distance_full(fBuf, queryVector, vector_column->dimensions,
|
||||||
|
vector_column->element_type,
|
||||||
|
vector_column->distance_metric);
|
||||||
|
|
||||||
|
for (i64 j = 1; j < cand_used; j++) {
|
||||||
|
rc = sqlite3_blob_reopen(blobFloat, cand_rowids[j]);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_blob_close(blobFloat);
|
||||||
|
sqlite3_free(float_distances);
|
||||||
|
sqlite3_free(fBuf);
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
rc = sqlite3_blob_read(blobFloat, fBuf, fsize, 0);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
sqlite3_blob_close(blobFloat);
|
||||||
|
sqlite3_free(float_distances);
|
||||||
|
sqlite3_free(fBuf);
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
float_distances[j] =
|
||||||
|
vec0_distance_full(fBuf, queryVector, vector_column->dimensions,
|
||||||
|
vector_column->element_type,
|
||||||
|
vector_column->distance_metric);
|
||||||
|
}
|
||||||
|
sqlite3_blob_close(blobFloat);
|
||||||
|
sqlite3_free(fBuf);
|
||||||
|
|
||||||
|
// Sort by float distance
|
||||||
|
for (i64 a = 0; a + 1 < cand_used; a++) {
|
||||||
|
i64 minIdx = a;
|
||||||
|
for (i64 c = a + 1; c < cand_used; c++) {
|
||||||
|
if (float_distances[c] < float_distances[minIdx])
|
||||||
|
minIdx = c;
|
||||||
|
}
|
||||||
|
if (minIdx != a) {
|
||||||
|
f32 td = float_distances[a];
|
||||||
|
float_distances[a] = float_distances[minIdx];
|
||||||
|
float_distances[minIdx] = td;
|
||||||
|
i64 tr = cand_rowids[a];
|
||||||
|
cand_rowids[a] = cand_rowids[minIdx];
|
||||||
|
cand_rowids[minIdx] = tr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
i64 result_k = min(k, cand_used);
|
||||||
|
i64 *out_rowids = sqlite3_malloc(result_k * sizeof(i64));
|
||||||
|
f32 *out_distances = sqlite3_malloc(result_k * sizeof(f32));
|
||||||
|
if (!out_rowids || !out_distances) {
|
||||||
|
sqlite3_free(out_rowids);
|
||||||
|
sqlite3_free(out_distances);
|
||||||
|
sqlite3_free(float_distances);
|
||||||
|
rc = SQLITE_NOMEM;
|
||||||
|
goto cleanup;
|
||||||
|
}
|
||||||
|
for (i64 j = 0; j < result_k; j++) {
|
||||||
|
out_rowids[j] = cand_rowids[j];
|
||||||
|
out_distances[j] = float_distances[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
knn_data->current_idx = 0;
|
||||||
|
knn_data->k = result_k;
|
||||||
|
knn_data->rowids = out_rowids;
|
||||||
|
knn_data->distances = out_distances;
|
||||||
|
knn_data->k_used = result_k;
|
||||||
|
|
||||||
|
sqlite3_free(float_distances);
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtChunks);
|
||||||
|
sqlite3_free(quantizedQuery);
|
||||||
|
sqlite3_free(cand_rowids);
|
||||||
|
sqlite3_free(cand_distances);
|
||||||
|
sqlite3_free(tmp_rowids);
|
||||||
|
sqlite3_free(tmp_distances);
|
||||||
|
sqlite3_free(chunk_distances);
|
||||||
|
sqlite3_free(chunk_topk_idxs);
|
||||||
|
sqlite3_free(b);
|
||||||
|
sqlite3_free(bTaken);
|
||||||
|
sqlite3_free(bmRowids);
|
||||||
|
sqlite3_free(baseVectors);
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle FTS5-style command dispatch for rescore parameters.
|
||||||
|
* Returns SQLITE_OK if handled, SQLITE_EMPTY if not a rescore command.
|
||||||
|
*/
|
||||||
|
static int rescore_handle_command(vec0_vtab *p, const char *command) {
|
||||||
|
if (strncmp(command, "oversample=", 11) == 0) {
|
||||||
|
int val = atoi(command + 11);
|
||||||
|
if (val < 1) {
|
||||||
|
vtab_set_error(&p->base, "oversample must be >= 1");
|
||||||
|
return SQLITE_ERROR;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < p->numVectorColumns; i++) {
|
||||||
|
if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) {
|
||||||
|
p->vector_columns[i].rescore.oversample_search = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SQLITE_OK;
|
||||||
|
}
|
||||||
|
return SQLITE_EMPTY;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef SQLITE_VEC_TEST
// Non-static entry points, compiled only when SQLITE_VEC_TEST is defined,
// that forward to the file-local rescore helpers.

// Forwards to rescore_quantize_float_to_bit.
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) {
  rescore_quantize_float_to_bit(src, dst, dim);
}

// Forwards to rescore_quantize_float_to_int8.
void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim) {
  rescore_quantize_float_to_int8(src, dst, dim);
}

// Quantized byte size of a zeroed column definition with the given
// dimensions and the bit quantizer.
size_t _test_rescore_quantized_byte_size_bit(size_t dimensions) {
  struct VectorColumnDefinition probe;
  memset(&probe, 0, sizeof(probe));
  probe.rescore.quantizer_type = VEC0_RESCORE_QUANTIZER_BIT;
  probe.dimensions = dimensions;
  return rescore_quantized_byte_size(&probe);
}

// Quantized byte size of a zeroed column definition with the given
// dimensions and the int8 quantizer.
size_t _test_rescore_quantized_byte_size_int8(size_t dimensions) {
  struct VectorColumnDefinition probe;
  memset(&probe, 0, sizeof(probe));
  probe.rescore.quantizer_type = VEC0_RESCORE_QUANTIZER_INT8;
  probe.dimensions = dimensions;
  return rescore_quantized_byte_size(&probe);
}
#endif
|
||||||
3643
sqlite-vec.c
3643
sqlite-vec.c
File diff suppressed because it is too large
Load diff
File diff suppressed because one or more lines are too long
|
|
@ -27,8 +27,8 @@
|
||||||
OrderedDict({
|
OrderedDict({
|
||||||
'chunk_id': 1,
|
'chunk_id': 1,
|
||||||
'size': 8,
|
'size': 8,
|
||||||
'validity': b'\x06',
|
'validity': b'\x02',
|
||||||
'rowids': b'\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
'rowids': b'\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
|
|
@ -37,7 +37,7 @@
|
||||||
'rows': list([
|
'rows': list([
|
||||||
OrderedDict({
|
OrderedDict({
|
||||||
'rowid': 1,
|
'rowid': 1,
|
||||||
'data': b'\x06',
|
'data': b'\x02',
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
|
|
@ -46,7 +46,7 @@
|
||||||
'rows': list([
|
'rows': list([
|
||||||
OrderedDict({
|
OrderedDict({
|
||||||
'rowid': 1,
|
'rowid': 1,
|
||||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
|
|
@ -55,7 +55,7 @@
|
||||||
'rows': list([
|
'rows': list([
|
||||||
OrderedDict({
|
OrderedDict({
|
||||||
'rowid': 1,
|
'rowid': 1,
|
||||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x9a\x99\x99\x99\x99\x99\x01@ffffff\n@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x9a\x99\x99\x99\x99\x99\x01@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
|
|
@ -64,17 +64,13 @@
|
||||||
'rows': list([
|
'rows': list([
|
||||||
OrderedDict({
|
OrderedDict({
|
||||||
'rowid': 1,
|
'rowid': 1,
|
||||||
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\x00test2\x00\x00\x00\x00\x00\x00\x00\r\x00\x00\x00123456789012\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
'data': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\x00test2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
'v_metadatatext03': OrderedDict({
|
'v_metadatatext03': OrderedDict({
|
||||||
'sql': 'select * from v_metadatatext03',
|
'sql': 'select * from v_metadatatext03',
|
||||||
'rows': list([
|
'rows': list([
|
||||||
OrderedDict({
|
|
||||||
'rowid': 3,
|
|
||||||
'data': '1234567890123',
|
|
||||||
}),
|
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
'v_rowids': OrderedDict({
|
'v_rowids': OrderedDict({
|
||||||
|
|
@ -86,12 +82,6 @@
|
||||||
'chunk_id': 1,
|
'chunk_id': 1,
|
||||||
'chunk_offset': 1,
|
'chunk_offset': 1,
|
||||||
}),
|
}),
|
||||||
OrderedDict({
|
|
||||||
'rowid': 3,
|
|
||||||
'id': None,
|
|
||||||
'chunk_id': 1,
|
|
||||||
'chunk_offset': 2,
|
|
||||||
}),
|
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
'v_vector_chunks00': OrderedDict({
|
'v_vector_chunks00': OrderedDict({
|
||||||
|
|
@ -99,7 +89,7 @@
|
||||||
'rows': list([
|
'rows': list([
|
||||||
OrderedDict({
|
OrderedDict({
|
||||||
'rowid': 1,
|
'rowid': 1,
|
||||||
'vectors': b'\x00\x00\x00\x00""""3333\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
'vectors': b'\x00\x00\x00\x00""""\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
|
||||||
}),
|
}),
|
||||||
]),
|
]),
|
||||||
}),
|
}),
|
||||||
|
|
@ -370,14 +360,6 @@
|
||||||
'f': 2.2,
|
'f': 2.2,
|
||||||
't': 'test2',
|
't': 'test2',
|
||||||
}),
|
}),
|
||||||
OrderedDict({
|
|
||||||
'rowid': 3,
|
|
||||||
'vector': b'3333',
|
|
||||||
'b': 1,
|
|
||||||
'n': 3,
|
|
||||||
'f': 3.3,
|
|
||||||
't': '1234567890123',
|
|
||||||
}),
|
|
||||||
]),
|
]),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,29 @@
|
||||||
import pytest
|
import pytest
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def _vec_debug():
    """Return the text produced by ``SELECT vec_debug()``.

    Opens a throwaway in-memory database, loads the ``dist/vec0`` extension
    into it and runs the query. The connection is always closed before
    returning, including when loading the extension or the query raises, so
    no SQLite handle is leaked (unclosed connections also trigger
    ``ResourceWarning`` on recent Python versions).
    """
    db = sqlite3.connect(":memory:")
    try:
        db.enable_load_extension(True)
        db.load_extension("dist/vec0")
        db.enable_load_extension(False)
        return db.execute("SELECT vec_debug()").fetchone()[0]
    finally:
        db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _has_build_flag(flag):
    """Tell whether ``flag`` occurs in the build-flag part of ``vec_debug()``.

    Only the text after the last ``"Build flags:"`` marker is searched; when
    the marker is absent the whole output is searched.
    """
    _, _, build_flags = _vec_debug().rpartition("Build flags:")
    return flag in build_flags
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(config, items):
    """Skip IVF test files when the loaded extension lacks the "ivf" build flag.

    Every collected item whose file name starts with ``test-ivf`` gets a skip
    marker; nothing is changed when the flag is present.
    """
    if _has_build_flag("ivf"):
        return

    ivf_prefixes = ("test-ivf",)
    skip_ivf = pytest.mark.skip(reason="IVF not enabled (compile with -DSQLITE_VEC_EXPERIMENTAL_IVF_ENABLE=1)")
    for item in items:
        # str.startswith accepts a tuple and matches if any prefix matches.
        if item.fspath.basename.startswith(ivf_prefixes):
            item.add_marker(skip_ivf)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,6 @@ import json
|
||||||
db = sqlite3.connect(":memory:")
|
db = sqlite3.connect(":memory:")
|
||||||
db.enable_load_extension(True)
|
db.enable_load_extension(True)
|
||||||
db.load_extension("../../dist/vec0")
|
db.load_extension("../../dist/vec0")
|
||||||
db.execute("select load_extension('../../dist/vec0', 'sqlite3_vec_fs_read_init')")
|
|
||||||
db.enable_load_extension(False)
|
db.enable_load_extension(False)
|
||||||
|
|
||||||
results = db.execute(
|
results = db.execute(
|
||||||
|
|
@ -75,17 +74,21 @@ print(b)
|
||||||
|
|
||||||
db.execute('PRAGMA page_size=16384')
|
db.execute('PRAGMA page_size=16384')
|
||||||
|
|
||||||
print("Loading into sqlite-vec vec0 table...")
|
|
||||||
t0 = time.time()
|
|
||||||
db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)")
|
|
||||||
db.execute('insert into v select rowid, vector from vec_npy_each(vec_npy_file("dbpedia_openai_3_large_00.npy"))')
|
|
||||||
print(time.time() - t0)
|
|
||||||
|
|
||||||
print("loading numpy array...")
|
print("loading numpy array...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
base = np.load('dbpedia_openai_3_large_00.npy')
|
base = np.load('dbpedia_openai_3_large_00.npy')
|
||||||
print(time.time() - t0)
|
print(time.time() - t0)
|
||||||
|
|
||||||
|
print("Loading into sqlite-vec vec0 table...")
|
||||||
|
t0 = time.time()
|
||||||
|
db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)")
|
||||||
|
with db:
|
||||||
|
db.executemany(
|
||||||
|
"insert into v(rowid, a) values (?, ?)",
|
||||||
|
[(i, row.tobytes()) for i, row in enumerate(base)],
|
||||||
|
)
|
||||||
|
print(time.time() - t0)
|
||||||
|
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
queries = base[np.random.choice(base.shape[0], 20, replace=False), :]
|
queries = base[np.random.choice(base.shape[0], 20, replace=False), :]
|
||||||
|
|
||||||
|
|
|
||||||
BIN
tests/fixtures/legacy-v0.1.6.db
vendored
Normal file
BIN
tests/fixtures/legacy-v0.1.6.db
vendored
Normal file
Binary file not shown.
5
tests/fuzz/.gitignore
vendored
5
tests/fuzz/.gitignore
vendored
|
|
@ -1,2 +1,7 @@
|
||||||
*.dSYM
|
*.dSYM
|
||||||
targets/
|
targets/
|
||||||
|
corpus/
|
||||||
|
crash-*
|
||||||
|
leak-*
|
||||||
|
timeout-*
|
||||||
|
*.log
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ FUZZ_LDFLAGS ?= $(shell \
|
||||||
echo "-Wl,-ld_classic"; \
|
echo "-Wl,-ld_classic"; \
|
||||||
fi)
|
fi)
|
||||||
|
|
||||||
FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -g $(FUZZ_LDFLAGS)
|
FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -DSQLITE_VEC_ENABLE_DISKANN=1 -g $(FUZZ_LDFLAGS)
|
||||||
FUZZ_SRCS = ../../vendor/sqlite3.c ../../sqlite-vec.c
|
FUZZ_SRCS = ../../vendor/sqlite3.c ../../sqlite-vec.c
|
||||||
|
|
||||||
TARGET_DIR = ./targets
|
TARGET_DIR = ./targets
|
||||||
|
|
@ -72,10 +72,94 @@ $(TARGET_DIR)/vec_mismatch: vec-mismatch.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
$(TARGET_DIR)/vec0_delete_completeness: vec0-delete-completeness.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
$(TARGET_DIR)/vec0_delete_completeness: vec0-delete-completeness.c $(FUZZ_SRCS) | $(TARGET_DIR)
|
||||||
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
|
# --- rescore fuzzers -------------------------------------------------------
# All are built with -DSQLITE_VEC_ENABLE_RESCORE. The two quantize targets
# also pass -DSQLITE_VEC_TEST, which exposes the _test_rescore_* wrappers.
$(TARGET_DIR)/rescore_operations: rescore-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/rescore_create: rescore-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/rescore_quantize: rescore-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/rescore_shadow_corrupt: rescore-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/rescore_knn_deep: rescore-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/rescore_quantize_edge: rescore-quantize-edge.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@

# --- IVF fuzzers -----------------------------------------------------------
$(TARGET_DIR)/ivf_create: ivf-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/ivf_operations: ivf-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/ivf_quantize: ivf-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/ivf_kmeans: ivf-kmeans.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/ivf_shadow_corrupt: ivf-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/ivf_knn_deep: ivf-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/ivf_cell_overflow: ivf-cell-overflow.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

# This rule previously had no recipe, so the binary was never produced.
# NOTE(review): uses the same recipe as the other ivf_* targets; confirm
# whether ivf-rescore.c needs extra -D flags (e.g. the rescore define).
$(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

# --- DiskANN fuzzers -------------------------------------------------------
$(TARGET_DIR)/diskann_operations: diskann-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_create: diskann-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_graph_corrupt: diskann-graph-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_deep_search: diskann-deep-search.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_blob_truncate: diskann-blob-truncate.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_delete_stress: diskann-delete-stress.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_buffer_flush: diskann-buffer-flush.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_int8_quant: diskann-int8-quant.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_prune_direct: diskann-prune-direct.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@

$(TARGET_DIR)/diskann_command_inject: diskann-command-inject.c $(FUZZ_SRCS) | $(TARGET_DIR)
	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
|
||||||
|
|
||||||
# Names of every fuzz binary built by `all` (each is prefixed with
# $(TARGET_DIR)/). Every line except the last must end with a backslash;
# a missing one silently drops all the names that follow it from the list.
FUZZ_TARGETS = vec0_create exec json numpy \
               shadow_corrupt vec0_operations scalar_functions \
               vec0_create_full metadata_columns vec_each vec_mismatch \
               vec0_delete_completeness \
               rescore_operations rescore_create rescore_quantize \
               rescore_shadow_corrupt rescore_knn_deep \
               rescore_quantize_edge rescore_interleave \
               ivf_create ivf_operations \
               ivf_quantize ivf_kmeans ivf_shadow_corrupt \
               ivf_knn_deep ivf_cell_overflow ivf_rescore \
               diskann_operations diskann_create diskann_graph_corrupt \
               diskann_deep_search diskann_blob_truncate \
               diskann_delete_stress diskann_buffer_flush \
               diskann_int8_quant diskann_prune_direct \
               diskann_command_inject
|
||||||
|
|
||||||
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
|
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
|
||||||
|
|
||||||
|
|
|
||||||
250
tests/fuzz/diskann-blob-truncate.c
Normal file
250
tests/fuzz/diskann-blob-truncate.c
Normal file
|
|
@ -0,0 +1,250 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN shadow table blob size mismatches.
|
||||||
|
*
|
||||||
|
* The critical vulnerability: diskann_node_read() copies whatever blob size
|
||||||
|
* SQLite returns, but diskann_search/insert/delete index into those blobs
|
||||||
|
* using cfg->n_neighbors * sizeof(i64) etc. If the blob is truncated,
|
||||||
|
* extended, or has wrong size, this causes out-of-bounds reads/writes.
|
||||||
|
*
|
||||||
|
* This fuzzer:
|
||||||
|
* 1. Creates a valid DiskANN graph with several nodes
|
||||||
|
* 2. Uses fuzz data to directly write malformed blobs to shadow tables:
|
||||||
|
* - Truncated neighbor_ids (fewer bytes than n_neighbors * 8)
|
||||||
|
* - Truncated validity bitmaps
|
||||||
|
* - Oversized blobs with garbage trailing data
|
||||||
|
* - Zero-length blobs
|
||||||
|
* - Blobs with valid headers but corrupted neighbor rowids
|
||||||
|
* 3. Runs INSERT, DELETE, and KNN operations that traverse the corrupted graph
|
||||||
|
*
|
||||||
|
* Key code paths targeted:
|
||||||
|
* - diskann_node_read with mismatched blob sizes
|
||||||
|
* - diskann_validity_get / diskann_neighbor_id_get on truncated blobs
|
||||||
|
* - diskann_add_reverse_edge reading corrupted neighbor data
|
||||||
|
* - diskann_repair_reverse_edges traversing corrupted neighbor lists
|
||||||
|
* - diskann_search iterating neighbors from corrupted blobs
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/*
 * Pop one byte from the front of the fuzz input.
 *
 * On success *data is advanced by one and *size is decremented. When the
 * input is exhausted (*size == 0) nothing is modified and `def` is returned.
 */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size > 0) {
    const uint8_t *cursor = *data;
    *data = cursor + 1;
    *size -= 1;
    return *cursor;
  }
  return def;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 32) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* Use binary quantizer, float[16], n_neighbors=8 for predictable blob sizes:
|
||||||
|
* validity: 8/8 = 1 byte
|
||||||
|
* neighbor_ids: 8 * 8 = 64 bytes
|
||||||
|
* qvecs: 8 * (16/8) = 16 bytes (binary: 2 bytes per qvec)
|
||||||
|
*/
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
/* Insert 12 vectors to create a valid graph structure */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
|
for (int i = 1; i <= 12; i++) {
|
||||||
|
float vec[16];
|
||||||
|
for (int j = 0; j < 16; j++) {
|
||||||
|
vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmt);
|
||||||
|
sqlite3_bind_int64(stmt, 1, i);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Now corrupt shadow table blobs using fuzz data */
|
||||||
|
const char *columns[] = {
|
||||||
|
"neighbors_validity",
|
||||||
|
"neighbor_ids",
|
||||||
|
"neighbor_quantized_vectors"
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Expected sizes for n_neighbors=8, dims=16, binary quantizer */
|
||||||
|
int expected_sizes[] = {1, 64, 16};
|
||||||
|
|
||||||
|
while (size >= 4) {
|
||||||
|
int target_row = (fuzz_byte(&data, &size, 0) % 12) + 1;
|
||||||
|
int col_idx = fuzz_byte(&data, &size, 0) % 3;
|
||||||
|
uint8_t corrupt_mode = fuzz_byte(&data, &size, 0) % 6;
|
||||||
|
uint8_t extra = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
char sqlbuf[256];
|
||||||
|
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||||
|
"UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?",
|
||||||
|
columns[col_idx]);
|
||||||
|
|
||||||
|
sqlite3_stmt *writeStmt;
|
||||||
|
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL);
|
||||||
|
if (rc != SQLITE_OK) continue;
|
||||||
|
|
||||||
|
int expected = expected_sizes[col_idx];
|
||||||
|
unsigned char *blob = NULL;
|
||||||
|
int blob_size = 0;
|
||||||
|
|
||||||
|
switch (corrupt_mode) {
|
||||||
|
case 0: {
|
||||||
|
/* Truncated blob: 0 to expected-1 bytes */
|
||||||
|
blob_size = extra % expected;
|
||||||
|
if (blob_size == 0) blob_size = 0; /* zero-length is interesting */
|
||||||
|
blob = sqlite3_malloc(blob_size > 0 ? blob_size : 1);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
for (int i = 0; i < blob_size; i++) {
|
||||||
|
blob[i] = fuzz_byte(&data, &size, 0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
/* Oversized blob: expected + extra bytes */
|
||||||
|
blob_size = expected + (extra % 64);
|
||||||
|
blob = sqlite3_malloc(blob_size);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
for (int i = 0; i < blob_size; i++) {
|
||||||
|
blob[i] = fuzz_byte(&data, &size, 0xFF);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
/* Zero-length blob */
|
||||||
|
blob_size = 0;
|
||||||
|
blob = NULL;
|
||||||
|
sqlite3_bind_zeroblob(writeStmt, 1, 0);
|
||||||
|
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||||
|
sqlite3_step(writeStmt);
|
||||||
|
sqlite3_finalize(writeStmt);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
/* Correct size but all-ones validity (all slots "valid") with
|
||||||
|
* garbage neighbor IDs -- forces reading non-existent nodes */
|
||||||
|
blob_size = expected;
|
||||||
|
blob = sqlite3_malloc(blob_size);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
memset(blob, 0xFF, blob_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: {
|
||||||
|
/* neighbor_ids with very large rowid values (near INT64_MAX) */
|
||||||
|
blob_size = expected;
|
||||||
|
blob = sqlite3_malloc(blob_size);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
memset(blob, 0x7F, blob_size); /* fills with large positive values */
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: {
|
||||||
|
/* neighbor_ids with negative rowid values (rowid=0 is sentinel) */
|
||||||
|
blob_size = expected;
|
||||||
|
blob = sqlite3_malloc(blob_size);
|
||||||
|
if (!blob) { sqlite3_finalize(writeStmt); continue; }
|
||||||
|
memset(blob, 0x80, blob_size); /* fills with large negative values */
|
||||||
|
/* Flip some bytes from fuzz data */
|
||||||
|
for (int i = 0; i < blob_size && size > 0; i++) {
|
||||||
|
blob[i] ^= fuzz_byte(&data, &size, 0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (blob) {
|
||||||
|
sqlite3_bind_blob(writeStmt, 1, blob, blob_size, SQLITE_TRANSIENT);
|
||||||
|
} else {
|
||||||
|
sqlite3_bind_blob(writeStmt, 1, "", 0, SQLITE_STATIC);
|
||||||
|
}
|
||||||
|
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||||
|
sqlite3_step(writeStmt);
|
||||||
|
sqlite3_finalize(writeStmt);
|
||||||
|
sqlite3_free(blob);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Exercise the corrupted graph with various operations */
|
||||||
|
|
||||||
|
/* KNN query */
|
||||||
|
{
|
||||||
|
float qvec[16];
|
||||||
|
for (int j = 0; j < 16; j++) qvec[j] = (float)j * 0.1f;
|
||||||
|
sqlite3_stmt *knnStmt;
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5",
|
||||||
|
-1, &knnStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knnStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Insert into corrupted graph (triggers add_reverse_edge on corrupted nodes) */
|
||||||
|
{
|
||||||
|
float vec[16];
|
||||||
|
for (int j = 0; j < 16; j++) vec[j] = 0.5f;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
|
if (stmt) {
|
||||||
|
sqlite3_bind_int64(stmt, 1, 100);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Delete from corrupted graph (triggers repair_reverse_edges) */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmt, NULL);
|
||||||
|
if (stmt) {
|
||||||
|
sqlite3_bind_int64(stmt, 1, 5);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Another KNN to traverse the post-mutation graph */
|
||||||
|
{
|
||||||
|
float qvec[16];
|
||||||
|
for (int j = 0; j < 16; j++) qvec[j] = -0.5f + (float)j * 0.07f;
|
||||||
|
sqlite3_stmt *knnStmt;
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 12",
|
||||||
|
-1, &knnStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knnStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Full scan */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
164
tests/fuzz/diskann-buffer-flush.c
Normal file
164
tests/fuzz/diskann-buffer-flush.c
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN buffered insert and flush paths.
|
||||||
|
*
|
||||||
|
* When buffer_threshold > 0, inserts go into a flat buffer table and
|
||||||
|
* are flushed into the graph in batch. This fuzzer exercises:
|
||||||
|
*
|
||||||
|
* - diskann_buffer_write / diskann_buffer_delete / diskann_buffer_exists
|
||||||
|
* - diskann_flush_buffer (batch graph insertion)
|
||||||
|
* - diskann_insert with buffer_threshold (batching logic)
|
||||||
|
* - Buffer-graph merge in vec0Filter_knn_diskann (unflushed vectors
|
||||||
|
* must be scanned during KNN and merged with graph results)
|
||||||
|
* - Delete of a buffered (not yet flushed) vector
|
||||||
|
* - Delete of a graph vector while buffer has pending inserts
|
||||||
|
* - Interaction: insert to buffer, query (triggers buffer scan), flush,
|
||||||
|
* query again (now from graph)
|
||||||
|
*
|
||||||
|
* The buffer merge path in vec0Filter_knn_diskann is particularly
|
||||||
|
* interesting because it does a brute-force scan of buffer vectors and
|
||||||
|
* merges with the top-k from graph search.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/*
 * Consume and return the next fuzz input byte, advancing *data and shrinking
 * *size by one. Returns `def` without touching either when *size is zero.
 */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t value = def;
  if (*size != 0) {
    value = (*data)[0];
    *data += 1;
    *size -= 1;
  }
  return value;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 16) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* buffer_threshold: small (3-8) to trigger frequent flushes */
|
||||||
|
int buf_threshold = 3 + (fuzz_byte(&data, &size, 0) % 6);
|
||||||
|
int dims = 8;
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] INDEXED BY diskann("
|
||||||
|
"neighbor_quantizer=binary, n_neighbors=8, "
|
||||||
|
"search_list_size=16, buffer_threshold=%d"
|
||||||
|
"))", dims, buf_threshold);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||||
|
|
||||||
|
float vec[8];
|
||||||
|
int next_rowid = 1;
|
||||||
|
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0) % 6;
|
||||||
|
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: { /* Insert: accumulates in buffer until threshold */
|
||||||
|
int64_t rowid = next_rowid++;
|
||||||
|
if (next_rowid > 64) next_rowid = 1; /* wrap around for reuse */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: { /* KNN query while buffer may have unflushed vectors */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
int k = (param % 10) + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* Delete a potentially-buffered vector */
|
||||||
|
int64_t rowid = (int64_t)(param % 64) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* Insert several at once to trigger flush mid-batch */
|
||||||
|
for (int i = 0; i < buf_threshold + 1 && size >= 2; i++) {
|
||||||
|
int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 64) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: { /* Insert then immediately delete (still in buffer) */
|
||||||
|
int64_t rowid = (int64_t)(param % 64) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) vec[j] = 0.1f * param;
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: { /* Query with k=0 and k=1 (boundary) */
|
||||||
|
for (int j = 0; j < dims; j++) vec[j] = 0.0f;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, param % 2); /* k=0 or k=1 */
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Final query to exercise post-operation state */
|
||||||
|
{
|
||||||
|
float qvec[8] = {1.0f, -1.0f, 0.5f, -0.5f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, 20);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
158
tests/fuzz/diskann-command-inject.c
Normal file
158
tests/fuzz/diskann-command-inject.c
Normal file
|
|
@ -0,0 +1,158 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN runtime command dispatch (diskann_handle_command).
|
||||||
|
*
|
||||||
|
* The command handler parses strings like "search_list_size_search=42" and
|
||||||
|
* modifies live DiskANN config. This fuzzer exercises:
|
||||||
|
*
|
||||||
|
* - atoi on fuzz-controlled strings (integer overflow, negative, non-numeric)
|
||||||
|
* - strncmp boundary with fuzz data (near-matches to valid commands)
|
||||||
|
* - Changing search_list_size mid-operation (affects subsequent queries)
|
||||||
|
* - Setting search_list_size to 1 (minimum - single-candidate beam search)
|
||||||
|
* - Setting search_list_size very large (memory pressure)
|
||||||
|
* - Interleaving command changes with inserts and queries
|
||||||
|
*
|
||||||
|
 * Commands are sent through the vtable via INSERT INTO v(v) VALUES (?).
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Pop the next byte off the fuzz input.
 *
 * On success the cursor (*data) is advanced by one and the remaining length
 * (*size) is decremented. When no input is left, `def` is returned and
 * neither the cursor nor the length is touched. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  if (*size > 0) {
    const uint8_t next = **data;
    *data += 1;
    *size -= 1;
    return next;
  }
  return def;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 20) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
/* Insert some vectors first */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
|
for (int i = 1; i <= 8; i++) {
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||||
|
sqlite3_reset(stmt);
|
||||||
|
sqlite3_bind_int64(stmt, 1, i);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtCmd = NULL;
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
|
||||||
|
/* Commands are dispatched via INSERT INTO t(t) VALUES ('cmd_string') */
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v) VALUES (?)", -1, &stmtCmd, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtCmd || !stmtInsert || !stmtKnn) goto cleanup;
|
||||||
|
|
||||||
|
/* Fuzz-driven command + operation interleaving */
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0) % 5;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: { /* Send fuzz command string */
|
||||||
|
int cmd_len = fuzz_byte(&data, &size, 0) % 64;
|
||||||
|
char cmd[65];
|
||||||
|
for (int i = 0; i < cmd_len && size > 0; i++) {
|
||||||
|
cmd[i] = (char)fuzz_byte(&data, &size, 0);
|
||||||
|
}
|
||||||
|
cmd[cmd_len] = '\0';
|
||||||
|
sqlite3_reset(stmtCmd);
|
||||||
|
sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtCmd); /* May fail -- that's expected */
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: { /* Send valid-looking command with fuzz value */
|
||||||
|
const char *prefixes[] = {
|
||||||
|
"search_list_size=",
|
||||||
|
"search_list_size_search=",
|
||||||
|
"search_list_size_insert=",
|
||||||
|
};
|
||||||
|
int prefix_idx = fuzz_byte(&data, &size, 0) % 3;
|
||||||
|
int val = (int)(int8_t)fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
char cmd[128];
|
||||||
|
snprintf(cmd, sizeof(cmd), "%s%d", prefixes[prefix_idx], val);
|
||||||
|
sqlite3_reset(stmtCmd);
|
||||||
|
sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtCmd);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* KNN query (uses whatever search_list_size is set) */
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
qvec[0] = (float)((int8_t)fuzz_byte(&data, &size, 127)) / 10.0f;
|
||||||
|
int k = fuzz_byte(&data, &size, 3) % 10 + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* Insert (uses whatever search_list_size_insert is set) */
|
||||||
|
int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 32) + 1;
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: { /* Set search_list_size to extreme values */
|
||||||
|
const char *extreme_cmds[] = {
|
||||||
|
"search_list_size=1",
|
||||||
|
"search_list_size=2",
|
||||||
|
"search_list_size=1000",
|
||||||
|
"search_list_size_search=1",
|
||||||
|
"search_list_size_insert=1",
|
||||||
|
};
|
||||||
|
int idx = fuzz_byte(&data, &size, 0) % 5;
|
||||||
|
sqlite3_reset(stmtCmd);
|
||||||
|
sqlite3_bind_text(stmtCmd, 1, extreme_cmds[idx], -1, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmtCmd);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtCmd);
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
44
tests/fuzz/diskann-create.c
Normal file
44
tests/fuzz/diskann-create.c
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN CREATE TABLE config parsing.
|
||||||
|
* Feeds fuzz data as the INDEXED BY diskann(...) option string.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size > 4096) return 0; /* Limit input size */
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||||
|
assert(s);
|
||||||
|
sqlite3_str_appendall(s,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[64] INDEXED BY diskann(");
|
||||||
|
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||||
|
sqlite3_str_appendall(s, "))");
|
||||||
|
const char *zSql = sqlite3_str_finish(s);
|
||||||
|
assert(zSql);
|
||||||
|
|
||||||
|
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free((char *)zSql);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
187
tests/fuzz/diskann-deep-search.c
Normal file
187
tests/fuzz/diskann-deep-search.c
Normal file
|
|
@ -0,0 +1,187 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN greedy beam search deep paths.
|
||||||
|
*
|
||||||
|
* Builds a graph with enough nodes to force multi-hop traversal, then
|
||||||
|
* uses fuzz data to control: query vector values, k, search_list_size
|
||||||
|
* overrides, and interleaved insert/delete/query sequences that stress
|
||||||
|
* the candidate list growth, visited set hash collisions, and the
|
||||||
|
* re-ranking logic.
|
||||||
|
*
|
||||||
|
* Key code paths targeted:
|
||||||
|
* - diskann_candidate_list_insert (sorted insert, dedup, eviction)
|
||||||
|
* - diskann_visited_set (hash collisions, capacity)
|
||||||
|
* - diskann_search (full beam search loop, re-ranking with exact dist)
|
||||||
|
* - diskann_distance_quantized_precomputed (both binary and int8)
|
||||||
|
* - Buffer merge in vec0Filter_knn_diskann
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Read one byte from the fuzz stream and advance past it; if the stream is
 * empty, hand back `def` and leave the stream state unchanged. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t out = def;
  if (*size != 0) {
    out = (*data)[0];
    ++*data;
    --*size;
  }
  return out;
}
|
||||||
|
|
||||||
|
/* Consume two bytes from the fuzz stream and combine them little-endian
 * (first byte is the low half). A missing byte is read as 0. */
static uint16_t fuzz_u16(const uint8_t **data, size_t *size) {
  const uint16_t low = fuzz_byte(data, size, 0);
  const uint16_t high = fuzz_byte(data, size, 0);
  return (uint16_t)((high << 8) | low);
}
|
||||||
|
|
||||||
|
/* Consume one byte, reinterpret it as a signed 8-bit value and scale it by
 * 1/10. An exhausted stream yields 0.0f (fuzz_byte's default of 0). */
static float fuzz_float(const uint8_t **data, size_t *size) {
  const int8_t tenths = (int8_t)fuzz_byte(data, size, 0);
  return (float)tenths / 10.0f;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 32) return 0;
|
||||||
|
|
||||||
|
/* Use first bytes to pick quantizer type and dimensions */
|
||||||
|
uint8_t quantizer_choice = fuzz_byte(&data, &size, 0) % 2;
|
||||||
|
const char *quantizer = quantizer_choice ? "int8" : "binary";
|
||||||
|
|
||||||
|
/* Dimensions must be divisible by 8. Pick from {8, 16, 32} */
|
||||||
|
int dim_choices[] = {8, 16, 32};
|
||||||
|
int dims = dim_choices[fuzz_byte(&data, &size, 0) % 3];
|
||||||
|
|
||||||
|
/* n_neighbors: 8 or 16 -- small to force full-neighbor scenarios quickly */
|
||||||
|
int n_neighbors = (fuzz_byte(&data, &size, 0) % 2) ? 16 : 8;
|
||||||
|
|
||||||
|
/* search_list_size: small so beam search terminates quickly but still exercises loops */
|
||||||
|
int search_list_size = 8 + (fuzz_byte(&data, &size, 0) % 24);
|
||||||
|
|
||||||
|
/* alpha: vary to test RobustPrune pruning logic */
|
||||||
|
float alpha_choices[] = {1.0f, 1.2f, 1.5f, 2.0f};
|
||||||
|
float alpha = alpha_choices[fuzz_byte(&data, &size, 0) % 4];
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] INDEXED BY diskann("
|
||||||
|
"neighbor_quantizer=%s, n_neighbors=%d, "
|
||||||
|
"search_list_size=%d"
|
||||||
|
"))", dims, quantizer, n_neighbors, search_list_size);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
|
||||||
|
char knn_sql[256];
|
||||||
|
snprintf(knn_sql, sizeof(knn_sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?");
|
||||||
|
sqlite3_prepare_v2(db, knn_sql, -1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||||
|
|
||||||
|
/* Phase 1: Seed the graph with enough nodes to create multi-hop structure.
|
||||||
|
* Insert 2*n_neighbors nodes so the graph is dense enough for search
|
||||||
|
* to actually traverse multiple hops. */
|
||||||
|
int seed_count = n_neighbors * 2;
|
||||||
|
if (seed_count > 64) seed_count = 64; /* Bound for performance */
|
||||||
|
{
|
||||||
|
float *vec = malloc(dims * sizeof(float));
|
||||||
|
if (!vec) goto cleanup;
|
||||||
|
for (int i = 1; i <= seed_count; i++) {
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
}
|
||||||
|
free(vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Phase 2: Fuzz-driven operations on the seeded graph */
|
||||||
|
float *vec = malloc(dims * sizeof(float));
|
||||||
|
if (!vec) goto cleanup;
|
||||||
|
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0) % 5;
|
||||||
|
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: { /* INSERT with fuzz-controlled vector and rowid */
|
||||||
|
int64_t rowid = (int64_t)(param % 128) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: { /* DELETE */
|
||||||
|
int64_t rowid = (int64_t)(param % 128) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* KNN with fuzz query vector and variable k */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_float(&data, &size);
|
||||||
|
}
|
||||||
|
int k = (param % 20) + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* KNN with k > number of nodes (boundary) */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, 1000); /* k >> graph size */
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: { /* INSERT duplicate rowid (triggers OR REPLACE path) */
|
||||||
|
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = (float)(param + j) / 50.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
free(vec);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
175
tests/fuzz/diskann-delete-stress.c
Normal file
175
tests/fuzz/diskann-delete-stress.c
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN delete path and graph connectivity maintenance.
|
||||||
|
*
|
||||||
|
* The delete path is the most complex graph mutation:
|
||||||
|
* 1. Read deleted node's neighbor list
|
||||||
|
* 2. For each neighbor, remove deleted node from their list
|
||||||
|
* 3. Try to fill the gap with one of deleted node's other neighbors
|
||||||
|
* 4. Handle medoid deletion (pick new medoid)
|
||||||
|
*
|
||||||
|
* Edge cases this targets:
|
||||||
|
* - Delete the medoid (entry point) -- forces medoid reassignment
|
||||||
|
* - Delete all nodes except one -- graph degenerates
|
||||||
|
* - Delete nodes in a chain -- cascading dangling edges
|
||||||
|
* - Re-insert at deleted rowids -- stale graph edges to old data
|
||||||
|
* - Delete nonexistent rowids -- should be no-op
|
||||||
|
* - Insert-delete-insert same rowid rapidly
|
||||||
|
* - Delete when graph has exactly n_neighbors entries (full nodes)
|
||||||
|
*
|
||||||
|
* Key code paths:
|
||||||
|
* - diskann_delete -> diskann_repair_reverse_edges
|
||||||
|
* - diskann_medoid_handle_delete
|
||||||
|
* - diskann_node_clear_neighbor
|
||||||
|
* - Interaction between delete and concurrent search
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Take the byte under the cursor and step the cursor forward. With nothing
 * left to read, `def` is returned and cursor/length stay as they were. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  const uint8_t *cursor = *data;
  if (!*size) {
    return def;
  }
  *data = cursor + 1;
  *size = *size - 1;
  return *cursor;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 20) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* int8 quantizer to exercise that distance code path */
|
||||||
|
uint8_t quant = fuzz_byte(&data, &size, 0) % 2;
|
||||||
|
const char *qname = quant ? "int8" : "binary";
|
||||||
|
|
||||||
|
char sql[256];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] INDEXED BY diskann(neighbor_quantizer=%s, n_neighbors=8))",
|
||||||
|
qname);
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
|
||||||
|
|
||||||
|
/* Phase 1: Build a graph of exactly n_neighbors+2 = 10 nodes.
|
||||||
|
* This makes every node nearly full, maximizing the chance that
|
||||||
|
* inserts trigger the "full node" path in add_reverse_edge. */
|
||||||
|
for (int i = 1; i <= 10; i++) {
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(i*13+j*7))) / 20.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Phase 2: Fuzz-driven delete-heavy workload */
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0);
|
||||||
|
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
switch (op % 6) {
|
||||||
|
case 0: /* Delete existing node */
|
||||||
|
case 1: { /* (weighted toward deletes) */
|
||||||
|
int64_t rowid = (int64_t)(param % 16) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* Delete then immediately re-insert same rowid */
|
||||||
|
int64_t rowid = (int64_t)(param % 10) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(rowid+j))) / 15.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* KNN query on potentially sparse/empty graph */
|
||||||
|
float qvec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
qvec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
int k = (param % 15) + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: { /* Insert new node */
|
||||||
|
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: { /* Delete ALL remaining nodes, then insert fresh */
|
||||||
|
for (int i = 1; i <= 32; i++) {
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, i);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
}
|
||||||
|
/* Now insert one node into empty graph */
|
||||||
|
float vec[8] = {1.0f, 0, 0, 0, 0, 0, 0, 0};
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, 1);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Final KNN on whatever state the graph is in */
|
||||||
|
{
|
||||||
|
float qvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, 10);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
123
tests/fuzz/diskann-graph-corrupt.c
Normal file
123
tests/fuzz/diskann-graph-corrupt.c
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN shadow table corruption resilience.
|
||||||
|
* Creates and populates a DiskANN table, then corrupts shadow table blobs
|
||||||
|
* using fuzz data and runs queries.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 16) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
/* Insert a few vectors to create graph structure */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
|
||||||
|
for (int i = 1; i <= 10; i++) {
|
||||||
|
float vec[8];
|
||||||
|
for (int j = 0; j < 8; j++) {
|
||||||
|
vec[j] = (float)i * 0.1f + (float)j * 0.01f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmt);
|
||||||
|
sqlite3_bind_int64(stmt, 1, i);
|
||||||
|
sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Corrupt shadow table data using fuzz bytes */
|
||||||
|
size_t offset = 0;
|
||||||
|
|
||||||
|
/* Determine which row and column to corrupt */
|
||||||
|
int target_row = (data[offset++] % 10) + 1;
|
||||||
|
int corrupt_type = data[offset++] % 3; /* 0=validity, 1=neighbor_ids, 2=qvecs */
|
||||||
|
|
||||||
|
const char *column_name;
|
||||||
|
switch (corrupt_type) {
|
||||||
|
case 0: column_name = "neighbors_validity"; break;
|
||||||
|
case 1: column_name = "neighbor_ids"; break;
|
||||||
|
default: column_name = "neighbor_quantized_vectors"; break;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Read the blob, corrupt it, write it back */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *readStmt;
|
||||||
|
char sqlbuf[256];
|
||||||
|
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||||
|
"SELECT %s FROM v_diskann_nodes00 WHERE rowid = ?", column_name);
|
||||||
|
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &readStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_int64(readStmt, 1, target_row);
|
||||||
|
if (sqlite3_step(readStmt) == SQLITE_ROW) {
|
||||||
|
const void *blob = sqlite3_column_blob(readStmt, 0);
|
||||||
|
int blobSize = sqlite3_column_bytes(readStmt, 0);
|
||||||
|
if (blob && blobSize > 0) {
|
||||||
|
unsigned char *corrupt = sqlite3_malloc(blobSize);
|
||||||
|
if (corrupt) {
|
||||||
|
memcpy(corrupt, blob, blobSize);
|
||||||
|
/* Apply fuzz bytes as XOR corruption */
|
||||||
|
size_t remaining = size - offset;
|
||||||
|
for (size_t i = 0; i < remaining && i < (size_t)blobSize; i++) {
|
||||||
|
corrupt[i % blobSize] ^= data[offset + i];
|
||||||
|
}
|
||||||
|
/* Write back */
|
||||||
|
sqlite3_stmt *writeStmt;
|
||||||
|
snprintf(sqlbuf, sizeof(sqlbuf),
|
||||||
|
"UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?", column_name);
|
||||||
|
rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(writeStmt, 1, corrupt, blobSize, SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int64(writeStmt, 2, target_row);
|
||||||
|
sqlite3_step(writeStmt);
|
||||||
|
sqlite3_finalize(writeStmt);
|
||||||
|
}
|
||||||
|
sqlite3_free(corrupt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_finalize(readStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Run queries on corrupted graph -- should not crash */
|
||||||
|
{
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_stmt *knnStmt;
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5",
|
||||||
|
-1, &knnStmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knnStmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Full scan */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
164
tests/fuzz/diskann-int8-quant.c
Normal file
164
tests/fuzz/diskann-int8-quant.c
Normal file
|
|
@ -0,0 +1,164 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN int8 quantizer edge cases.
|
||||||
|
*
|
||||||
|
* The binary quantizer is simple (sign bit), but the int8 quantizer has
|
||||||
|
* interesting arithmetic:
|
||||||
|
* i8_val = (i8)(((src - (-1.0f)) / step) - 128.0f)
|
||||||
|
* where step = 2.0f / 255.0f
|
||||||
|
*
|
||||||
|
* Edge cases in this formula:
|
||||||
|
* - src values outside [-1, 1] cause clamping issues (no explicit clamp!)
|
||||||
|
* - src = NaN, +Inf, -Inf (from corrupted vectors or div-by-zero)
|
||||||
|
* - src very close to boundaries (-1.0, 1.0) -- rounding
|
||||||
|
* - The cast to i8 can overflow for extreme src values
|
||||||
|
*
|
||||||
|
* Also exercises int8 distance functions:
|
||||||
|
* - distance_l2_sqr_int8: accumulates squared differences, possible overflow
|
||||||
|
* - distance_cosine_int8: dot product with normalization
|
||||||
|
* - distance_l1_int8: absolute differences
|
||||||
|
*
|
||||||
|
* This fuzzer also tests the cosine distance metric path which the
|
||||||
|
* other fuzzers (using L2 default) don't cover.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Fetch the next fuzz byte, moving the read position forward by one.
 * Returns `def` without changing any state once the input is used up. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t value;
  if (*size == 0) {
    return def;
  }
  value = *(*data);
  *data = *data + 1;
  *size = *size - 1;
  return value;
}
|
||||||
|
|
||||||
|
/* Produce one float from two fuzz bytes: the first (mod 8) selects a value
 * family, the second is the raw payload. Both bytes are always consumed,
 * even for the families that return a constant.
 *
 *   0: signed payload / 10      1: signed payload * 100
 *   2: signed payload / 1000    3: -1.0    4: 1.0    5: 0.0
 *   6: payload / 255 (0..1)     7: -payload / 255 (-1..0)
 */
static float fuzz_extreme_float(const uint8_t **data, size_t *size) {
  const uint8_t selector = fuzz_byte(data, size, 0) % 8;
  const uint8_t payload = fuzz_byte(data, size, 0);
  const float as_signed = (float)((int8_t)payload);
  const float as_unsigned = (float)payload;
  float value = 0.0f;

  if (selector == 0) {
    value = as_signed / 10.0f;
  } else if (selector == 1) {
    value = as_signed * 100.0f;
  } else if (selector == 2) {
    value = as_signed / 1000.0f;
  } else if (selector == 3) {
    value = -1.0f;
  } else if (selector == 4) {
    value = 1.0f;
  } else if (selector == 5) {
    value = 0.0f;
  } else if (selector == 6) {
    value = as_unsigned / 255.0f;
  } else if (selector == 7) {
    value = -as_unsigned / 255.0f;
  }
  return value;
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 40) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* Test both distance metrics with int8 quantizer */
|
||||||
|
uint8_t metric_choice = fuzz_byte(&data, &size, 0) % 2;
|
||||||
|
const char *metric = metric_choice ? "cosine" : "L2";
|
||||||
|
|
||||||
|
int dims = 8 + (fuzz_byte(&data, &size, 0) % 3) * 8; /* 8, 16, or 24 */
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] distance_metric=%s "
|
||||||
|
"INDEXED BY diskann(neighbor_quantizer=int8, n_neighbors=8, search_list_size=16))",
|
||||||
|
dims, metric);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtInsert = NULL, *stmtKnn = NULL, *stmtDelete = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtKnn || !stmtDelete) goto cleanup;
|
||||||
|
|
||||||
|
/* Insert vectors with extreme float values to stress quantization */
|
||||||
|
float *vec = malloc(dims * sizeof(float));
|
||||||
|
if (!vec) goto cleanup;
|
||||||
|
|
||||||
|
for (int i = 1; i <= 16; i++) {
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_extreme_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, i);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Fuzz-driven operations */
|
||||||
|
while (size >= 2) {
|
||||||
|
uint8_t op = fuzz_byte(&data, &size, 0) % 4;
|
||||||
|
uint8_t param = fuzz_byte(&data, &size, 0);
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: { /* KNN with extreme query values */
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_extreme_float(&data, &size);
|
||||||
|
}
|
||||||
|
int k = (param % 10) + 1;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, k);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: { /* Insert with extreme values */
|
||||||
|
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||||
|
for (int j = 0; j < dims; j++) {
|
||||||
|
vec[j] = fuzz_extreme_float(&data, &size);
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: { /* Delete */
|
||||||
|
int64_t rowid = (int64_t)(param % 32) + 1;
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: { /* KNN with all-zero or all-same-value query */
|
||||||
|
float val = (param % 3 == 0) ? 0.0f :
|
||||||
|
(param % 3 == 1) ? 1.0f : -1.0f;
|
||||||
|
for (int j = 0; j < dims; j++) vec[j] = val;
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int(stmtKnn, 2, 5);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
free(vec);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
100
tests/fuzz/diskann-operations.c
Normal file
100
tests/fuzz/diskann-operations.c
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN insert/delete/query operation sequences.
|
||||||
|
* Uses fuzz bytes to drive random operations on a DiskANN-indexed table.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 6) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_stmt *stmtDelete = NULL;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_stmt *stmtScan = NULL;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 3",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i + 2 <= size) {
|
||||||
|
uint8_t op = data[i++] % 4;
|
||||||
|
uint8_t rowid_byte = data[i++];
|
||||||
|
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: {
|
||||||
|
/* INSERT: consume 32 bytes for 8 floats, or use what's left */
|
||||||
|
float vec[8] = {0};
|
||||||
|
for (int j = 0; j < 8 && i < size; j++, i++) {
|
||||||
|
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
/* DELETE */
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
/* KNN query */
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
/* Full scan */
|
||||||
|
sqlite3_reset(stmtScan);
|
||||||
|
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Final operations -- must not crash regardless of prior state */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_finalize(stmtScan);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
131
tests/fuzz/diskann-prune-direct.c
Normal file
131
tests/fuzz/diskann-prune-direct.c
Normal file
|
|
@ -0,0 +1,131 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target for DiskANN RobustPrune algorithm (diskann_prune_select).
|
||||||
|
*
|
||||||
|
* diskann_prune_select is exposed for testing and takes:
|
||||||
|
* - inter_distances: flattened NxN matrix of inter-candidate distances
|
||||||
|
* - p_distances: N distances from node p to each candidate
|
||||||
|
* - num_candidates, alpha, max_neighbors
|
||||||
|
*
|
||||||
|
* This is a pure function that doesn't need a database, so we can
|
||||||
|
* call it directly with fuzz-controlled inputs. This gives the fuzzer
|
||||||
|
* maximum speed (no SQLite overhead) to explore:
|
||||||
|
*
|
||||||
|
* - alpha boundary: alpha=0 (prunes nothing), alpha=very large (prunes all)
|
||||||
|
* - max_neighbors = 0, 1, num_candidates, > num_candidates
|
||||||
|
* - num_candidates = 0, 1, large
|
||||||
|
* - Distance matrices with: all zeros, all same, negative values, NaN, Inf
|
||||||
|
* - Non-symmetric distance matrices (should still work)
|
||||||
|
* - Memory: large num_candidates to stress malloc
|
||||||
|
*
|
||||||
|
* Key code paths:
|
||||||
|
* - diskann_prune_select alpha-pruning loop
|
||||||
|
* - Boundary: selectedCount reaches max_neighbors exactly
|
||||||
|
* - All candidates pruned before max_neighbors reached
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Declare the test-exposed function.
|
||||||
|
* diskann_prune_select is not static -- it's a public symbol. */
|
||||||
|
extern int diskann_prune_select(
|
||||||
|
const float *inter_distances, const float *p_distances,
|
||||||
|
int num_candidates, float alpha, int max_neighbors,
|
||||||
|
int *outSelected, int *outCount);
|
||||||
|
|
||||||
|
/* Reads one byte from the fuzz stream and advances it. If the stream is
 * already empty, nothing is modified and the caller-supplied default `def`
 * is returned. */
static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
  uint8_t value = def;
  if (*size != 0) {
    value = (*data)[0];
    *data += 1;
    *size -= 1;
  }
  return value;
}
|
||||||
|
|
||||||
|
/**
 * libFuzzer entry point that calls diskann_prune_select() directly.
 *
 * The first three bytes choose the candidate count (0..32), max_neighbors
 * (0..16) and an alpha from a fixed table. Subsequent bytes fill the
 * candidate distances (sorted ascending before the call) and the NxN
 * inter-candidate matrix (diagonal forced to zero). Once input runs out,
 * deterministic index-derived defaults are used. Inputs shorter than 8 bytes
 * are ignored. Always returns 0.
 */
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  if (size < 8) return 0;

  /* Parameters, consumed in this order from the fuzz data. */
  const int n = fuzz_byte(&data, &size, 0) % 33;             /* 0..32 */
  const int max_neighbors = fuzz_byte(&data, &size, 0) % 17; /* 0..16 */

  /* Alpha is picked from a table of interesting values. */
  float alpha_values[] = {0.0f, 0.5f, 1.0f, 1.2f, 1.5f, 2.0f, 10.0f, 100.0f};
  const float alpha = alpha_values[fuzz_byte(&data, &size, 0) % 8];

  int outCount = -1;
  int rc;

  if (n == 0) {
    /* Empty candidate set: call with NULL arrays and expect zero selected. */
    rc = diskann_prune_select(NULL, NULL, 0, alpha, max_neighbors,
                              NULL, &outCount);
    assert(rc == 0 /* SQLITE_OK */);
    assert(outCount == 0);
    return 0;
  }

  float *inter_distances = malloc(n * n * sizeof(float));
  float *p_distances = malloc(n * sizeof(float));
  int *outSelected = malloc(n * sizeof(int));
  if (inter_distances == NULL || p_distances == NULL || outSelected == NULL) {
    free(inter_distances);
    free(p_distances);
    free(outSelected);
    return 0;
  }

  /* Candidate distances: one byte each, scaled by 1/10. */
  for (int i = 0; i < n; i++) {
    const uint8_t raw = fuzz_byte(&data, &size, (uint8_t)(i * 10));
    p_distances[i] = (float)raw / 10.0f;
  }

  /* Insertion sort into ascending order (prune_select expects sorted input). */
  for (int i = 1; i < n; i++) {
    const float key = p_distances[i];
    int slot = i;
    for (; slot > 0 && p_distances[slot - 1] > key; slot--) {
      p_distances[slot] = p_distances[slot - 1];
    }
    p_distances[slot] = key;
  }

  /* Inter-candidate matrix, row-major. A byte is consumed for every cell,
   * including diagonal cells, whose value is then forced to zero. */
  for (int row = 0; row < n; row++) {
    for (int col = 0; col < n; col++) {
      const int cell = row * n + col;
      const uint8_t raw = fuzz_byte(&data, &size, (uint8_t)(cell % 256));
      inter_distances[cell] = (row == col) ? 0.0f : (float)raw / 10.0f;
    }
  }

  rc = diskann_prune_select(inter_distances, p_distances,
                            n, alpha, max_neighbors,
                            outSelected, &outCount);

  /* Basic sanity: the call succeeds and the count is within bounds. */
  assert(rc == 0);
  assert(outCount >= 0);
  assert(outCount <= max_neighbors || max_neighbors == 0);
  assert(outCount <= n);

  /* The number of set flags in outSelected must match outCount. */
  int flagCount = 0;
  for (int i = 0; i < n; i++) {
    flagCount += outSelected[i] ? 1 : 0;
  }
  assert(flagCount == outCount);

  free(inter_distances);
  free(p_distances);
  free(outSelected);
  return 0;
}
|
||||||
10
tests/fuzz/diskann.dict
Normal file
10
tests/fuzz/diskann.dict
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
"neighbor_quantizer"
|
||||||
|
"binary"
|
||||||
|
"int8"
|
||||||
|
"n_neighbors"
|
||||||
|
"search_list_size"
|
||||||
|
"search_list_size_search"
|
||||||
|
"search_list_size_insert"
|
||||||
|
"alpha"
|
||||||
|
"="
|
||||||
|
","
|
||||||
192
tests/fuzz/ivf-cell-overflow.c
Normal file
192
tests/fuzz/ivf-cell-overflow.c
Normal file
|
|
@ -0,0 +1,192 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF cell overflow and boundary conditions.
|
||||||
|
*
|
||||||
|
* Pushes cells past VEC0_IVF_CELL_MAX_VECTORS (64) to trigger cell
|
||||||
|
* splitting, then exercises blob I/O at slot boundaries.
|
||||||
|
*
|
||||||
|
* Targets:
|
||||||
|
* - Cell splitting when n_vectors reaches cap (64)
|
||||||
|
* - Blob offset arithmetic: slot * vecSize, slot / 8, slot % 8
|
||||||
|
* - Validity bitmap at byte boundaries (slot 7->8, 15->16, etc.)
|
||||||
|
* - Insert into full cell -> create new cell path
|
||||||
|
* - Delete from various slot positions (first, last, middle)
|
||||||
|
* - Multiple cells per centroid
|
||||||
|
* - assign-vectors command with multi-cell centroids
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 8) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Use small dimensions for speed but enough vectors to overflow cells
|
||||||
|
int dim = (data[0] % 8) + 2; // 2..9
|
||||||
|
int nlist = (data[1] % 4) + 1; // 1..4
|
||||||
|
// We need >64 vectors to overflow a cell
|
||||||
|
int num_vecs = (data[2] % 64) + 65; // 65..128
|
||||||
|
int delete_pattern = data[3]; // Controls which vectors to delete
|
||||||
|
|
||||||
|
const uint8_t *payload = data + 4;
|
||||||
|
size_t payload_size = size - 4;
|
||||||
|
|
||||||
|
char sql[256];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))",
|
||||||
|
dim, nlist, nlist);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert enough vectors to overflow at least one cell
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f;
|
||||||
|
} else {
|
||||||
|
// Cluster vectors near specific centroids to ensure some cells overflow
|
||||||
|
int cluster = i % nlist;
|
||||||
|
vec[d] = (float)cluster + (float)(i % 10) * 0.01f + d * 0.001f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Train to assign vectors to centroids (triggers cell building)
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Delete vectors at boundary positions based on fuzz data
|
||||||
|
// This tests validity bitmap manipulation at different slot positions
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
int byte_idx = i / 8;
|
||||||
|
if (byte_idx < (int)payload_size && (payload[byte_idx] & (1 << (i % 8)))) {
|
||||||
|
// Use delete_pattern to thin deletions
|
||||||
|
if ((delete_pattern + i) % 3 == 0) {
|
||||||
|
char delsql[64];
|
||||||
|
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i + 1);
|
||||||
|
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert more vectors after deletions (into cells with holes)
|
||||||
|
{
|
||||||
|
sqlite3_stmt *si = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
|
if (si) {
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++)
|
||||||
|
vec[d] = (float)(i + 200) * 0.01f;
|
||||||
|
sqlite3_reset(si);
|
||||||
|
sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1));
|
||||||
|
sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(si);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(si);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// KNN query that must scan multiple cells per centroid
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 0.0f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 20");
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test assign-vectors with multi-cell state
|
||||||
|
// First clear centroids
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Set centroids manually, then assign
|
||||||
|
for (int c = 0; c < nlist; c++) {
|
||||||
|
float *cvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!cvec) break;
|
||||||
|
for (int d = 0; d < dim; d++) cvec[d] = (float)c + d * 0.1f;
|
||||||
|
|
||||||
|
char cmd[128];
|
||||||
|
snprintf(cmd, sizeof(cmd),
|
||||||
|
"INSERT INTO v(v, emb) VALUES ('set-centroid:%d', ?)", c);
|
||||||
|
sqlite3_stmt *sc = NULL;
|
||||||
|
sqlite3_prepare_v2(db, cmd, -1, &sc, NULL);
|
||||||
|
if (sc) {
|
||||||
|
sqlite3_bind_blob(sc, 1, cvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(sc);
|
||||||
|
sqlite3_finalize(sc);
|
||||||
|
}
|
||||||
|
sqlite3_free(cvec);
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('assign-vectors')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Final query after assign-vectors
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 1.0f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||||
|
-1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full scan
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
36
tests/fuzz/ivf-create.c
Normal file
36
tests/fuzz/ivf-create.c
Normal file
|
|
@ -0,0 +1,36 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmt;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||||
|
assert(s);
|
||||||
|
sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(");
|
||||||
|
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||||
|
sqlite3_str_appendall(s, "))");
|
||||||
|
const char *zSql = sqlite3_str_finish(s);
|
||||||
|
assert(zSql);
|
||||||
|
|
||||||
|
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free((void *)zSql);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
16
tests/fuzz/ivf-create.dict
Normal file
16
tests/fuzz/ivf-create.dict
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
"nlist"
|
||||||
|
"nprobe"
|
||||||
|
"quantizer"
|
||||||
|
"oversample"
|
||||||
|
"binary"
|
||||||
|
"int8"
|
||||||
|
"none"
|
||||||
|
"="
|
||||||
|
","
|
||||||
|
"("
|
||||||
|
")"
|
||||||
|
"0"
|
||||||
|
"1"
|
||||||
|
"128"
|
||||||
|
"65536"
|
||||||
|
"65537"
|
||||||
180
tests/fuzz/ivf-kmeans.c
Normal file
180
tests/fuzz/ivf-kmeans.c
Normal file
|
|
@ -0,0 +1,180 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF k-means clustering.
|
||||||
|
*
|
||||||
|
* Builds a table, inserts fuzz-controlled vectors, then runs
|
||||||
|
* compute-centroids with fuzz-controlled parameters (nlist, max_iter, seed).
|
||||||
|
* Targets:
|
||||||
|
* - kmeans with N < k (clamping), N == 1, k == 1
|
||||||
|
* - kmeans with duplicate/identical vectors (all distances zero)
|
||||||
|
* - kmeans with NaN/Inf vectors
|
||||||
|
* - Empty cluster reassignment path (farthest-point heuristic)
|
||||||
|
* - Large nlist relative to N
|
||||||
|
* - The compute-centroids:{json} command parsing
|
||||||
|
* - clear-centroids followed by compute-centroids (round-trip)
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 10) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Parse fuzz header
|
||||||
|
// Byte 0-1: dimension (1..128)
|
||||||
|
// Byte 2: nlist for CREATE (1..64)
|
||||||
|
// Byte 3: nlist override for compute-centroids (0 = use default)
|
||||||
|
// Byte 4: max_iter (1..50)
|
||||||
|
// Byte 5-8: seed
|
||||||
|
// Byte 9: num_vectors (1..64)
|
||||||
|
// Remaining: vector float data
|
||||||
|
|
||||||
|
int dim = (data[0] | (data[1] << 8)) % 128 + 1;
|
||||||
|
int nlist_create = (data[2] % 64) + 1;
|
||||||
|
int nlist_override = data[3] % 65; // 0 means use table default
|
||||||
|
int max_iter = (data[4] % 50) + 1;
|
||||||
|
uint32_t seed = (uint32_t)data[5] | ((uint32_t)data[6] << 8) |
|
||||||
|
((uint32_t)data[7] << 16) | ((uint32_t)data[8] << 24);
|
||||||
|
int num_vecs = (data[9] % 64) + 1;
|
||||||
|
|
||||||
|
const uint8_t *payload = data + 10;
|
||||||
|
size_t payload_size = size - 10;
|
||||||
|
|
||||||
|
char sql[256];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d))",
|
||||||
|
dim, nlist_create, nlist_create);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert vectors
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset + 4 <= payload_size) {
|
||||||
|
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||||
|
offset += 4;
|
||||||
|
} else if (offset < payload_size) {
|
||||||
|
// Scale to interesting range including values > 1, < -1
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 5.0f;
|
||||||
|
} else {
|
||||||
|
// Reuse earlier bytes to fill remaining dimensions
|
||||||
|
vec[d] = (float)(i * dim + d) * 0.01f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Exercise compute-centroids with JSON options
|
||||||
|
{
|
||||||
|
char cmd[256];
|
||||||
|
snprintf(cmd, sizeof(cmd),
|
||||||
|
"INSERT INTO v(rowid) VALUES "
|
||||||
|
"('compute-centroids:{\"nlist\":%d,\"max_iterations\":%d,\"seed\":%u}')",
|
||||||
|
nlist_override, max_iter, seed);
|
||||||
|
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
// KNN query after training
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
qvec[d] = (d < 3) ? 1.0f : 0.0f;
|
||||||
|
}
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
if (stmtKnn) {
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear centroids and re-compute to test round-trip
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Insert a few more vectors in untrained state
|
||||||
|
{
|
||||||
|
sqlite3_stmt *si = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
|
if (si) {
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++) vec[d] = (float)(i + 100) * 0.1f;
|
||||||
|
sqlite3_reset(si);
|
||||||
|
sqlite3_bind_int64(si, 1, (int64_t)(num_vecs + i + 1));
|
||||||
|
sqlite3_bind_blob(si, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(si);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(si);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-train
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Delete some rows after training, then query
|
||||||
|
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||||
|
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 2", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Query after deletes
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 10",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
if (stmtKnn) {
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
199
tests/fuzz/ivf-knn-deep.c
Normal file
199
tests/fuzz/ivf-knn-deep.c
Normal file
|
|
@ -0,0 +1,199 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF KNN search deep paths.
|
||||||
|
*
|
||||||
|
* Exercises the full KNN pipeline with fuzz-controlled:
|
||||||
|
* - nprobe values (including > nlist, =1, =nlist)
|
||||||
|
* - Query vectors (including adversarial floats)
|
||||||
|
* - Mix of trained/untrained state
|
||||||
|
* - Oversample + rescore path (quantizer=int8 with oversample>1)
|
||||||
|
* - Multiple interleaved KNN queries
|
||||||
|
* - Candidate array realloc path (many vectors in probed cells)
|
||||||
|
*
|
||||||
|
* Targets:
|
||||||
|
* - ivf_scan_cells_from_stmt: candidate realloc, distance computation
|
||||||
|
* - ivf_query_knn: centroid sorting, nprobe selection
|
||||||
|
* - Oversample rescore: re-ranking with full-precision vectors
|
||||||
|
* - qsort with NaN distances
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Decodes a little-endian 16-bit unsigned integer from the two bytes at p. */
static uint16_t read_u16(const uint8_t *p) {
  const int low = p[0];
  const int high = p[1];
  return (uint16_t)((high << 8) | low);
}
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 16) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Header
|
||||||
|
int dim = (data[0] % 32) + 2; // 2..33
|
||||||
|
int nlist = (data[1] % 16) + 1; // 1..16
|
||||||
|
int nprobe_initial = (data[2] % 20) + 1; // 1..20 (can be > nlist)
|
||||||
|
int quantizer_type = data[3] % 3; // 0=none, 1=int8, 2=binary
|
||||||
|
int oversample = (data[4] % 4) + 1; // 1..4
|
||||||
|
int num_vecs = (data[5] % 80) + 4; // 4..83
|
||||||
|
int num_queries = (data[6] % 8) + 1; // 1..8
|
||||||
|
int k_limit = (data[7] % 20) + 1; // 1..20
|
||||||
|
|
||||||
|
const uint8_t *payload = data + 8;
|
||||||
|
size_t payload_size = size - 8;
|
||||||
|
|
||||||
|
// For binary quantizer, dimension must be multiple of 8
|
||||||
|
if (quantizer_type == 2) {
|
||||||
|
dim = ((dim + 7) / 8) * 8;
|
||||||
|
if (dim == 0) dim = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *qname;
|
||||||
|
switch (quantizer_type) {
|
||||||
|
case 1: qname = "int8"; break;
|
||||||
|
case 2: qname = "binary"; break;
|
||||||
|
default: qname = "none"; break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Oversample only valid with quantization
|
||||||
|
if (quantizer_type == 0) oversample = 1;
|
||||||
|
|
||||||
|
// Cap nprobe to nlist for CREATE (parser rejects nprobe > nlist)
|
||||||
|
int nprobe_create = nprobe_initial <= nlist ? nprobe_initial : nlist;
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s%s))",
|
||||||
|
dim, nlist, nprobe_create, qname,
|
||||||
|
oversample > 1 ? ", oversample=2" : "");
|
||||||
|
|
||||||
|
// If that fails (e.g. oversample with none), try without oversample
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))",
|
||||||
|
dim, nlist, nprobe_create, qname);
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert vectors
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 20.0f;
|
||||||
|
} else {
|
||||||
|
vec[d] = (float)((i * dim + d) % 256 - 128) / 128.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Query BEFORE training (flat scan path)
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Train
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Change nprobe at runtime (can exceed nlist -- tests clamping in query)
|
||||||
|
{
|
||||||
|
char cmd[64];
|
||||||
|
snprintf(cmd, sizeof(cmd),
|
||||||
|
"INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe_initial);
|
||||||
|
sqlite3_exec(db, cmd, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple KNN queries with different fuzz-derived query vectors
|
||||||
|
for (int q = 0; q < num_queries; q++) {
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!qvec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||||
|
} else {
|
||||||
|
qvec[d] = (q == 0) ? 1.0f : 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete half the vectors then query again
|
||||||
|
for (int i = 1; i <= num_vecs / 2; i++) {
|
||||||
|
char delsql[64];
|
||||||
|
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i);
|
||||||
|
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query after mass deletion
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = -0.5f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
121
tests/fuzz/ivf-operations.c
Normal file
121
tests/fuzz/ivf-operations.c
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 6) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_stmt *stmtDelete = NULL;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_stmt *stmtScan = NULL;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0(emb float[4] indexed by ivf(nlist=4, nprobe=4))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 3",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i + 2 <= size) {
|
||||||
|
uint8_t op = data[i++] % 7;
|
||||||
|
uint8_t rowid_byte = data[i++];
|
||||||
|
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: {
|
||||||
|
// INSERT: consume 16 bytes for 4 floats, or use what's left
|
||||||
|
float vec[4] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
for (int j = 0; j < 4 && i < size; j++, i++) {
|
||||||
|
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
// DELETE
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
// KNN query with a fixed query vector
|
||||||
|
float qvec[4] = {1.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
// Full scan
|
||||||
|
sqlite3_reset(stmtScan);
|
||||||
|
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: {
|
||||||
|
// compute-centroids command
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: {
|
||||||
|
// clear-centroids command
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 6: {
|
||||||
|
// nprobe=N command
|
||||||
|
if (i < size) {
|
||||||
|
uint8_t n = data[i++];
|
||||||
|
int nprobe = (n % 4) + 1;
|
||||||
|
char buf[64];
|
||||||
|
snprintf(buf, sizeof(buf),
|
||||||
|
"INSERT INTO v(v) VALUES ('nprobe=%d')", nprobe);
|
||||||
|
sqlite3_exec(db, buf, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final operations — must not crash regardless of prior state
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_finalize(stmtScan);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
129
tests/fuzz/ivf-quantize.c
Normal file
129
tests/fuzz/ivf-quantize.c
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF quantization functions.
|
||||||
|
*
|
||||||
|
* Directly exercises ivf_quantize_int8 and ivf_quantize_binary with
|
||||||
|
* fuzz-controlled dimensions and float data. Targets:
|
||||||
|
* - ivf_quantize_int8: clamping, int8 overflow boundary
|
||||||
|
* - ivf_quantize_binary: D not divisible by 8, memset(D/8) undercount
|
||||||
|
* - Round-trip through CREATE TABLE + INSERT with quantized IVF
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 8) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Byte 0: quantizer type (0=int8, 1=binary)
|
||||||
|
// Byte 1: dimension (1..64, but we test edge cases)
|
||||||
|
// Byte 2: nlist (1..8)
|
||||||
|
// Byte 3: num_vectors to insert (1..32)
|
||||||
|
// Remaining: float data
|
||||||
|
int qtype = data[0] % 2;
|
||||||
|
int dim = (data[1] % 64) + 1;
|
||||||
|
int nlist = (data[2] % 8) + 1;
|
||||||
|
int num_vecs = (data[3] % 32) + 1;
|
||||||
|
const uint8_t *payload = data + 4;
|
||||||
|
size_t payload_size = size - 4;
|
||||||
|
|
||||||
|
// For binary quantizer, D must be multiple of 8 to avoid the D/8 bug
|
||||||
|
// in production. But we explicitly want to test non-multiples too to
|
||||||
|
// find the bug. Use dim as-is.
|
||||||
|
const char *quantizer = qtype ? "binary" : "int8";
|
||||||
|
|
||||||
|
// Binary quantizer needs D multiple of 8 in current code, but let's
|
||||||
|
// test both valid and invalid dimensions to see what happens.
|
||||||
|
// For binary with non-multiple-of-8, the code does memset(dst, 0, D/8)
|
||||||
|
// which underallocates when D%8 != 0.
|
||||||
|
char sql[256];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s))",
|
||||||
|
dim, nlist, nlist, quantizer);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert vectors with fuzz-controlled float values
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs && offset < payload_size; i++) {
|
||||||
|
// Build float vector from fuzz data
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset + 4 <= payload_size) {
|
||||||
|
// Use raw bytes as float -- can produce NaN, Inf, denormals
|
||||||
|
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||||
|
offset += 4;
|
||||||
|
} else if (offset < payload_size) {
|
||||||
|
// Partial: use byte as scaled value
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 50.0f;
|
||||||
|
} else {
|
||||||
|
vec[d] = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Trigger compute-centroids to exercise kmeans + quantization together
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// KNN query with fuzz-derived query vector
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||||
|
} else {
|
||||||
|
qvec[d] = 1.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
if (stmtKnn) {
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full scan
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
182
tests/fuzz/ivf-rescore.c
Normal file
182
tests/fuzz/ivf-rescore.c
Normal file
|
|
@ -0,0 +1,182 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF oversample + rescore path.
|
||||||
|
*
|
||||||
|
* Specifically targets the code path where quantizer != none AND
|
||||||
|
* oversample > 1, which triggers:
|
||||||
|
* 1. Quantized KNN scan to collect oversample*k candidates
|
||||||
|
* 2. Full-precision vector lookup from _ivf_vectors table
|
||||||
|
* 3. Re-scoring with float32 distances
|
||||||
|
* 4. Re-sort and truncation
|
||||||
|
*
|
||||||
|
* This path has the most complex memory management in the KNN query:
|
||||||
|
* - Two separate distance computations (quantized + float)
|
||||||
|
* - Cross-table lookups (cells + vectors KV store)
|
||||||
|
* - Candidate array resizing
|
||||||
|
* - qsort over partially re-scored arrays
|
||||||
|
*
|
||||||
|
* Also tests the int8 + binary quantization round-trip fidelity
|
||||||
|
* under adversarial float inputs.
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 12) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Header
|
||||||
|
int quantizer_type = (data[0] % 2) + 1; // 1=int8, 2=binary (never none)
|
||||||
|
int dim = (data[1] % 32) + 8; // 8..39
|
||||||
|
int nlist = (data[2] % 8) + 1; // 1..8
|
||||||
|
int oversample = (data[3] % 4) + 2; // 2..5 (always > 1)
|
||||||
|
int num_vecs = (data[4] % 60) + 8; // 8..67
|
||||||
|
int k_limit = (data[5] % 15) + 1; // 1..15
|
||||||
|
|
||||||
|
const uint8_t *payload = data + 6;
|
||||||
|
size_t payload_size = size - 6;
|
||||||
|
|
||||||
|
// Binary quantizer needs D multiple of 8
|
||||||
|
if (quantizer_type == 2) {
|
||||||
|
dim = ((dim + 7) / 8) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *qname = (quantizer_type == 1) ? "int8" : "binary";
|
||||||
|
|
||||||
|
char sql[512];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[%d] indexed by ivf(nlist=%d, nprobe=%d, quantizer=%s, oversample=%d))",
|
||||||
|
dim, nlist, nlist, qname, oversample);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert vectors with diverse values
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
if (!stmtInsert) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int i = 0; i < num_vecs; i++) {
|
||||||
|
float *vec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!vec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset + 4 <= payload_size) {
|
||||||
|
// Use raw bytes as float for adversarial values
|
||||||
|
memcpy(&vec[d], payload + offset, sizeof(float));
|
||||||
|
offset += 4;
|
||||||
|
// Sanitize: replace NaN/Inf with bounded values to avoid
|
||||||
|
// poisoning the entire computation. We want edge values,
|
||||||
|
// not complete nonsense.
|
||||||
|
if (isnan(vec[d]) || isinf(vec[d])) {
|
||||||
|
vec[d] = (vec[d] > 0) ? 1e6f : -1e6f;
|
||||||
|
if (isnan(vec[d])) vec[d] = 0.0f;
|
||||||
|
}
|
||||||
|
} else if (offset < payload_size) {
|
||||||
|
vec[d] = ((float)(int8_t)payload[offset++]) / 30.0f;
|
||||||
|
} else {
|
||||||
|
vec[d] = (float)(i * dim + d) * 0.001f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, (int64_t)(i + 1));
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
sqlite3_free(vec);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
|
||||||
|
// Train
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Multiple KNN queries to exercise rescore path
|
||||||
|
for (int q = 0; q < 4; q++) {
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (!qvec) break;
|
||||||
|
for (int d = 0; d < dim; d++) {
|
||||||
|
if (offset < payload_size) {
|
||||||
|
qvec[d] = ((float)(int8_t)payload[offset++]) / 10.0f;
|
||||||
|
} else {
|
||||||
|
qvec[d] = (q == 0) ? 1.0f : (q == 1) ? -1.0f : 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete some vectors, then query again (rescore with missing _ivf_vectors rows)
|
||||||
|
for (int i = 1; i <= num_vecs / 3; i++) {
|
||||||
|
char delsql[64];
|
||||||
|
snprintf(delsql, sizeof(delsql), "DELETE FROM v WHERE rowid = %d", i);
|
||||||
|
sqlite3_exec(db, delsql, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = 0.5f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrain after deletions
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Query after retrain
|
||||||
|
{
|
||||||
|
float *qvec = sqlite3_malloc(dim * sizeof(float));
|
||||||
|
if (qvec) {
|
||||||
|
for (int d = 0; d < dim; d++) qvec[d] = -0.3f;
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT %d", k_limit);
|
||||||
|
sqlite3_prepare_v2(db, sql, -1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, dim * sizeof(float), SQLITE_TRANSIENT);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
sqlite3_free(qvec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
228
tests/fuzz/ivf-shadow-corrupt.c
Normal file
228
tests/fuzz/ivf-shadow-corrupt.c
Normal file
|
|
@ -0,0 +1,228 @@
|
||||||
|
/**
|
||||||
|
* Fuzz target: IVF shadow table corruption.
|
||||||
|
*
|
||||||
|
* Creates a trained IVF table, then corrupts IVF shadow table blobs
|
||||||
|
* (centroids, cells validity/rowids/vectors, rowid_map) with fuzz data.
|
||||||
|
* Then exercises all read/write paths. Must not crash.
|
||||||
|
*
|
||||||
|
* Targets:
|
||||||
|
* - Cell validity bitmap with wrong size
|
||||||
|
* - Cell rowids blob with wrong size/alignment
|
||||||
|
* - Cell vectors blob with wrong size
|
||||||
|
* - Centroid blob with wrong size
|
||||||
|
* - n_vectors inconsistent with validity bitmap
|
||||||
|
* - Missing rowid_map entries
|
||||||
|
* - KNN scan over corrupted cells
|
||||||
|
* - Insert/delete with corrupted rowid_map
|
||||||
|
*/
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 4) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
// Create IVF table and insert enough vectors to train
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] indexed by ivf(nlist=2, nprobe=2))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
// Insert 10 vectors
|
||||||
|
{
|
||||||
|
sqlite3_stmt *si = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
|
if (!si) { sqlite3_close(db); return 0; }
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
float vec[8];
|
||||||
|
for (int d = 0; d < 8; d++) {
|
||||||
|
vec[d] = (float)(i * 8 + d) * 0.1f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(si);
|
||||||
|
sqlite3_bind_int64(si, 1, i + 1);
|
||||||
|
sqlite3_bind_blob(si, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(si);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(si);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Train
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Now corrupt shadow tables based on fuzz input
|
||||||
|
uint8_t target = data[0] % 10;
|
||||||
|
const uint8_t *payload = data + 1;
|
||||||
|
int payload_size = (int)(size - 1);
|
||||||
|
|
||||||
|
// Limit payload to avoid huge allocations
|
||||||
|
if (payload_size > 4096) payload_size = 4096;
|
||||||
|
|
||||||
|
sqlite3_stmt *stmt = NULL;
|
||||||
|
|
||||||
|
switch (target) {
|
||||||
|
case 0: {
|
||||||
|
// Corrupt cell validity blob
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_ivf_cells00 SET validity = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
// Corrupt cell rowids blob
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_ivf_cells00 SET rowids = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
// Corrupt cell vectors blob
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_ivf_cells00 SET vectors = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
// Corrupt centroid blob
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_ivf_centroids00 SET centroid = ? WHERE centroid_id = 0",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt); sqlite3_finalize(stmt);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: {
|
||||||
|
// Set n_vectors to a bogus value (larger than cell capacity)
|
||||||
|
int bogus_n = 99999;
|
||||||
|
if (payload_size >= 4) {
|
||||||
|
memcpy(&bogus_n, payload, 4);
|
||||||
|
bogus_n = abs(bogus_n) % 100000;
|
||||||
|
}
|
||||||
|
char sql[128];
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"UPDATE v_ivf_cells00 SET n_vectors = %d WHERE rowid = 1", bogus_n);
|
||||||
|
sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: {
|
||||||
|
// Delete rowid_map entries (orphan vectors)
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"DELETE FROM v_ivf_rowid_map00 WHERE rowid IN (1, 2, 3)",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 6: {
|
||||||
|
// Corrupt rowid_map slot values
|
||||||
|
char sql[128];
|
||||||
|
int bogus_slot = payload_size > 0 ? (int)payload[0] * 1000 : 99999;
|
||||||
|
snprintf(sql, sizeof(sql),
|
||||||
|
"UPDATE v_ivf_rowid_map00 SET slot = %d WHERE rowid = 1", bogus_slot);
|
||||||
|
sqlite3_exec(db, sql, NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 7: {
|
||||||
|
// Corrupt rowid_map cell_id values
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"UPDATE v_ivf_rowid_map00 SET cell_id = 99999 WHERE rowid = 1",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 8: {
|
||||||
|
// Delete all centroids (make trained but no centroids)
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"DELETE FROM v_ivf_centroids00",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 9: {
|
||||||
|
// Set validity to NULL
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"UPDATE v_ivf_cells00 SET validity = NULL WHERE rowid = 1",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exercise all read paths over corrupted state — must not crash
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
|
||||||
|
// KNN query
|
||||||
|
{
|
||||||
|
sqlite3_stmt *sk = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? LIMIT 5",
|
||||||
|
-1, &sk, NULL);
|
||||||
|
if (sk) {
|
||||||
|
sqlite3_bind_blob(sk, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(sk) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(sk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full scan
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Point query
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 5", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 3", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// Insert after corruption
|
||||||
|
{
|
||||||
|
float newvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
|
||||||
|
sqlite3_stmt *si = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(v, emb) VALUES (?, ?)", -1, &si, NULL);
|
||||||
|
if (si) {
|
||||||
|
sqlite3_bind_int64(si, 1, 100);
|
||||||
|
sqlite3_bind_blob(si, 2, newvec, sizeof(newvec), SQLITE_STATIC);
|
||||||
|
sqlite3_step(si);
|
||||||
|
sqlite3_finalize(si);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute-centroids over corrupted state
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('compute-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
// clear-centroids
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"INSERT INTO v(v) VALUES ('clear-centroids')",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
@ -8,9 +7,6 @@
|
||||||
#include "sqlite3.h"
|
#include "sqlite3.h"
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
extern int sqlite3_vec_numpy_init(sqlite3 *db, char **pzErrMsg,
|
|
||||||
const sqlite3_api_routines *pApi);
|
|
||||||
|
|
||||||
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
int rc = SQLITE_OK;
|
int rc = SQLITE_OK;
|
||||||
sqlite3 *db;
|
sqlite3 *db;
|
||||||
|
|
@ -20,17 +16,20 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
assert(rc == SQLITE_OK);
|
assert(rc == SQLITE_OK);
|
||||||
rc = sqlite3_vec_init(db, NULL, NULL);
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
assert(rc == SQLITE_OK);
|
assert(rc == SQLITE_OK);
|
||||||
rc = sqlite3_vec_numpy_init(db, NULL, NULL);
|
|
||||||
assert(rc == SQLITE_OK);
|
|
||||||
|
|
||||||
rc = sqlite3_prepare_v2(db, "select * from vec_npy_each(?)", -1, &stmt, NULL);
|
sqlite3_str *s = sqlite3_str_new(NULL);
|
||||||
assert(rc == SQLITE_OK);
|
assert(s);
|
||||||
sqlite3_bind_blob(stmt, 1, data, size, SQLITE_STATIC);
|
sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[128] indexed by rescore(");
|
||||||
rc = sqlite3_step(stmt);
|
sqlite3_str_appendf(s, "%.*s", (int)size, data);
|
||||||
while (rc == SQLITE_ROW) {
|
sqlite3_str_appendall(s, "))");
|
||||||
rc = sqlite3_step(stmt);
|
const char *zSql = sqlite3_str_finish(s);
|
||||||
|
assert(zSql);
|
||||||
|
|
||||||
|
rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL);
|
||||||
|
sqlite3_free((void *)zSql);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_step(stmt);
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlite3_finalize(stmt);
|
sqlite3_finalize(stmt);
|
||||||
sqlite3_close(db);
|
sqlite3_close(db);
|
||||||
return 0;
|
return 0;
|
||||||
20
tests/fuzz/rescore-create.dict
Normal file
20
tests/fuzz/rescore-create.dict
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
"rescore"
|
||||||
|
"quantizer"
|
||||||
|
"bit"
|
||||||
|
"int8"
|
||||||
|
"oversample"
|
||||||
|
"indexed"
|
||||||
|
"by"
|
||||||
|
"float"
|
||||||
|
"("
|
||||||
|
")"
|
||||||
|
","
|
||||||
|
"="
|
||||||
|
"["
|
||||||
|
"]"
|
||||||
|
"1"
|
||||||
|
"8"
|
||||||
|
"16"
|
||||||
|
"128"
|
||||||
|
"256"
|
||||||
|
"1024"
|
||||||
151
tests/fuzz/rescore-interleave.c
Normal file
151
tests/fuzz/rescore-interleave.c
Normal file
|
|
@ -0,0 +1,151 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fuzz target: interleaved insert/update/delete/KNN operations on rescore
|
||||||
|
* tables with BOTH quantizer types, exercising the int8 quantizer path
|
||||||
|
* and the update code path that the existing rescore-operations.c misses.
|
||||||
|
*
|
||||||
|
* Key differences from rescore-operations.c:
|
||||||
|
* - Tests BOTH bit and int8 quantizers (the existing target only tests bit)
|
||||||
|
* - Fuzz-controlled query vectors (not fixed [1,0,0,...])
|
||||||
|
* - Exercises the UPDATE path (line 9080+ in sqlite-vec.c)
|
||||||
|
* - Tests with 16 dimensions (more realistic, exercises more of the
|
||||||
|
* quantization loop)
|
||||||
|
* - Interleaves KNN between mutations to stress the blob_reopen path
|
||||||
|
* when _rescore_vectors rows have been deleted/modified
|
||||||
|
*/
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 8) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_stmt *stmtUpdate = NULL;
|
||||||
|
sqlite3_stmt *stmtDelete = NULL;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* Use first byte to pick quantizer */
|
||||||
|
int use_int8 = data[0] & 1;
|
||||||
|
data++; size--;
|
||||||
|
|
||||||
|
const char *create_sql = use_int8
|
||||||
|
? "CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[16] indexed by rescore(quantizer=int8))"
|
||||||
|
: "CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[16] indexed by rescore(quantizer=bit))";
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v SET emb = ? WHERE rowid = ?", -1, &stmtUpdate, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||||
|
"ORDER BY distance LIMIT 5", -1, &stmtKnn, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtUpdate || !stmtDelete || !stmtKnn)
|
||||||
|
goto cleanup;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i + 2 <= size) {
|
||||||
|
uint8_t op = data[i++] % 5; /* 5 operations now */
|
||||||
|
uint8_t rowid_byte = data[i++];
|
||||||
|
int64_t rowid = (int64_t)(rowid_byte % 24) + 1;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: {
|
||||||
|
/* INSERT: consume bytes for 16 floats */
|
||||||
|
float vec[16] = {0};
|
||||||
|
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||||
|
vec[j] = (float)((int8_t)data[i]) / 8.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
/* DELETE */
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
/* KNN with fuzz-controlled query vector */
|
||||||
|
float qvec[16] = {0};
|
||||||
|
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||||
|
qvec[j] = (float)((int8_t)data[i]) / 4.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {
|
||||||
|
(void)sqlite3_column_int64(stmtKnn, 0);
|
||||||
|
(void)sqlite3_column_double(stmtKnn, 1);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
/* UPDATE: modify an existing vector (exercises rescore update path) */
|
||||||
|
float vec[16] = {0};
|
||||||
|
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||||
|
vec[j] = (float)((int8_t)data[i]) / 6.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtUpdate);
|
||||||
|
sqlite3_bind_blob(stmtUpdate, 1, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int64(stmtUpdate, 2, rowid);
|
||||||
|
sqlite3_step(stmtUpdate);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: {
|
||||||
|
/* INSERT then immediately UPDATE same row (stresses blob lifecycle) */
|
||||||
|
float vec1[16] = {0};
|
||||||
|
float vec2[16] = {0};
|
||||||
|
for (int j = 0; j < 16 && i < size; j++, i++) {
|
||||||
|
vec1[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||||
|
vec2[j] = -vec1[j]; /* opposite direction */
|
||||||
|
}
|
||||||
|
/* Insert */
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec1, sizeof(vec1), SQLITE_TRANSIENT);
|
||||||
|
if (sqlite3_step(stmtInsert) == SQLITE_DONE) {
|
||||||
|
/* Only update if insert succeeded (rowid might already exist) */
|
||||||
|
sqlite3_reset(stmtUpdate);
|
||||||
|
sqlite3_bind_blob(stmtUpdate, 1, vec2, sizeof(vec2), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_bind_int64(stmtUpdate, 2, rowid);
|
||||||
|
sqlite3_step(stmtUpdate);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Final consistency check: full scan must not crash */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtUpdate);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
178
tests/fuzz/rescore-knn-deep.c
Normal file
178
tests/fuzz/rescore-knn-deep.c
Normal file
|
|
@ -0,0 +1,178 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fuzz target: deep exercise of rescore KNN with fuzz-controlled query vectors
|
||||||
|
* and both quantizer types (bit + int8), multiple distance metrics.
|
||||||
|
*
|
||||||
|
* The existing rescore-operations.c only tests bit quantizer with a fixed
|
||||||
|
* query vector. This target:
|
||||||
|
* - Tests both bit and int8 quantizers
|
||||||
|
* - Uses fuzz-controlled query vectors (hits NaN/Inf/denormal paths)
|
||||||
|
* - Tests all distance metrics with int8 (L2, cosine, L1)
|
||||||
|
* - Exercises large LIMIT values (oversample multiplication)
|
||||||
|
* - Tests KNN with rowid IN constraints
|
||||||
|
* - Exercises the insert->query->update->query->delete->query cycle
|
||||||
|
*/
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 20) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* Use first 4 bytes for configuration */
|
||||||
|
uint8_t config = data[0];
|
||||||
|
uint8_t num_inserts = (data[1] % 20) + 3; /* 3..22 inserts */
|
||||||
|
uint8_t limit_val = (data[2] % 50) + 1; /* 1..50 for LIMIT */
|
||||||
|
uint8_t metric_choice = data[3] % 3;
|
||||||
|
data += 4;
|
||||||
|
size -= 4;
|
||||||
|
|
||||||
|
int use_int8 = config & 1;
|
||||||
|
const char *metric_str;
|
||||||
|
switch (metric_choice) {
|
||||||
|
case 0: metric_str = ""; break; /* default L2 */
|
||||||
|
case 1: metric_str = " distance_metric=cosine"; break;
|
||||||
|
case 2: metric_str = " distance_metric=l1"; break;
|
||||||
|
default: metric_str = ""; break;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Build CREATE TABLE statement */
|
||||||
|
char create_sql[256];
|
||||||
|
if (use_int8) {
|
||||||
|
snprintf(create_sql, sizeof(create_sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[16] indexed by rescore(quantizer=int8)%s)", metric_str);
|
||||||
|
} else {
|
||||||
|
/* bit quantizer ignores distance_metric for the coarse pass (always hamming),
|
||||||
|
but the float rescore phase uses the specified metric */
|
||||||
|
snprintf(create_sql, sizeof(create_sql),
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[16] indexed by rescore(quantizer=bit)%s)", metric_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
/* Insert vectors using fuzz data */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *ins = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL);
|
||||||
|
if (!ins) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
size_t cursor = 0;
|
||||||
|
for (int i = 0; i < num_inserts && cursor + 1 < size; i++) {
|
||||||
|
float vec[16];
|
||||||
|
for (int j = 0; j < 16; j++) {
|
||||||
|
if (cursor < size) {
|
||||||
|
/* Map fuzz byte to float -- includes potential for
|
||||||
|
interesting float values via reinterpretation */
|
||||||
|
int8_t sb = (int8_t)data[cursor++];
|
||||||
|
vec[j] = (float)sb / 5.0f;
|
||||||
|
} else {
|
||||||
|
vec[j] = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sqlite3_reset(ins);
|
||||||
|
sqlite3_bind_int64(ins, 1, (sqlite3_int64)(i + 1));
|
||||||
|
sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(ins);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(ins);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Build a fuzz-controlled query vector from remaining data */
|
||||||
|
float qvec[16] = {0};
|
||||||
|
{
|
||||||
|
size_t cursor = 0;
|
||||||
|
for (int j = 0; j < 16 && cursor < size; j++) {
|
||||||
|
int8_t sb = (int8_t)data[cursor++];
|
||||||
|
qvec[j] = (float)sb / 3.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* KNN query with fuzz-controlled vector and LIMIT */
|
||||||
|
{
|
||||||
|
char knn_sql[256];
|
||||||
|
snprintf(knn_sql, sizeof(knn_sql),
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||||
|
"ORDER BY distance LIMIT %d", (int)limit_val);
|
||||||
|
|
||||||
|
sqlite3_stmt *knn = NULL;
|
||||||
|
sqlite3_prepare_v2(db, knn_sql, -1, &knn, NULL);
|
||||||
|
if (knn) {
|
||||||
|
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knn) == SQLITE_ROW) {
|
||||||
|
/* Read results to ensure distance computation didn't produce garbage
|
||||||
|
that crashes the cursor iteration */
|
||||||
|
(void)sqlite3_column_int64(knn, 0);
|
||||||
|
(void)sqlite3_column_double(knn, 1);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(knn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Update some vectors, then query again */
|
||||||
|
{
|
||||||
|
float uvec[16];
|
||||||
|
for (int j = 0; j < 16; j++) uvec[j] = qvec[15 - j]; /* reverse of query */
|
||||||
|
sqlite3_stmt *upd = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL);
|
||||||
|
if (upd) {
|
||||||
|
sqlite3_bind_blob(upd, 1, uvec, sizeof(uvec), SQLITE_STATIC);
|
||||||
|
sqlite3_step(upd);
|
||||||
|
sqlite3_finalize(upd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Second KNN after update */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *knn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||||
|
"ORDER BY distance LIMIT 10", -1, &knn, NULL);
|
||||||
|
if (knn) {
|
||||||
|
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Delete half the rows, then KNN again */
|
||||||
|
for (int i = 1; i <= num_inserts; i += 2) {
|
||||||
|
char del_sql[64];
|
||||||
|
snprintf(del_sql, sizeof(del_sql),
|
||||||
|
"DELETE FROM v WHERE rowid = %d", i);
|
||||||
|
sqlite3_exec(db, del_sql, NULL, NULL, NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Third KNN after deletes -- exercises distance computation over
|
||||||
|
zeroed-out slots in the quantized chunk */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *knn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||||
|
"ORDER BY distance LIMIT 5", -1, &knn, NULL);
|
||||||
|
if (knn) {
|
||||||
|
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
96
tests/fuzz/rescore-operations.c
Normal file
96
tests/fuzz/rescore-operations.c
Normal file
|
|
@ -0,0 +1,96 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 6) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
sqlite3_stmt *stmtInsert = NULL;
|
||||||
|
sqlite3_stmt *stmtDelete = NULL;
|
||||||
|
sqlite3_stmt *stmtKnn = NULL;
|
||||||
|
sqlite3_stmt *stmtScan = NULL;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db,
|
||||||
|
"CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[8] indexed by rescore(quantizer=bit))",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? ORDER BY distance LIMIT 3",
|
||||||
|
-1, &stmtKnn, NULL);
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid FROM v", -1, &stmtScan, NULL);
|
||||||
|
|
||||||
|
if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
while (i + 2 <= size) {
|
||||||
|
uint8_t op = data[i++] % 4;
|
||||||
|
uint8_t rowid_byte = data[i++];
|
||||||
|
int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case 0: {
|
||||||
|
// INSERT: consume 32 bytes for 8 floats, or use what's left
|
||||||
|
float vec[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
for (int j = 0; j < 8 && i < size; j++, i++) {
|
||||||
|
vec[j] = (float)((int8_t)data[i]) / 10.0f;
|
||||||
|
}
|
||||||
|
sqlite3_reset(stmtInsert);
|
||||||
|
sqlite3_bind_int64(stmtInsert, 1, rowid);
|
||||||
|
sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(stmtInsert);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
// DELETE
|
||||||
|
sqlite3_reset(stmtDelete);
|
||||||
|
sqlite3_bind_int64(stmtDelete, 1, rowid);
|
||||||
|
sqlite3_step(stmtDelete);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
// KNN query with a fixed query vector
|
||||||
|
float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
sqlite3_reset(stmtKnn);
|
||||||
|
sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
// Full scan
|
||||||
|
sqlite3_reset(stmtScan);
|
||||||
|
while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final operations -- must not crash regardless of prior state
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
cleanup:
|
||||||
|
sqlite3_finalize(stmtInsert);
|
||||||
|
sqlite3_finalize(stmtDelete);
|
||||||
|
sqlite3_finalize(stmtKnn);
|
||||||
|
sqlite3_finalize(stmtScan);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
177
tests/fuzz/rescore-quantize-edge.c
Normal file
177
tests/fuzz/rescore-quantize-edge.c
Normal file
|
|
@ -0,0 +1,177 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* Test wrappers from sqlite-vec-rescore.c (SQLITE_VEC_TEST build) */
|
||||||
|
extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
|
||||||
|
extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
|
||||||
|
extern size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
|
||||||
|
extern size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
|
||||||
|
|
||||||
|
/**
 * Fuzz target: edge cases in the rescore quantization test wrappers.
 *
 * The first input byte selects a mode (byte % 5); the rest is the payload:
 *
 *   0  call both quantized-byte-size wrappers with a size_t dimension taken
 *      verbatim from the payload (0, 1, 7, huge values, ...)
 *   1  bit-quantize floats reinterpreted from raw payload bytes (NaN, Inf,
 *      denormals, -0.0f) and check every output bit against `src >= 0.0f`
 *   2  int8-quantize floats reinterpreted from raw payload bytes
 *   3  int8-quantize a vector whose elements are all the same value and
 *      expect an all-zero result
 *   4  int8-quantize a 16-element vector seeded with two finite
 *      fuzz-controlled floats plus scaled payload bytes
 *
 * Payload bytes are always copied into properly aligned float storage before
 * being read as floats: after the mode byte is skipped, `data` is generally
 * not aligned for `float`, so casting it directly would be undefined
 * behaviour.
 */

/*
 * Returns a malloc'd array holding `count` floats copied bit-for-bit from
 * `data`, or NULL if the allocation fails. The caller must free() it and must
 * guarantee that `data` holds at least count * sizeof(float) bytes.
 */
static float *quantize_edge_copy_floats(const uint8_t *data, size_t count) {
  float *out = (float *)malloc(count * sizeof(float));
  if (out) {
    memcpy(out, data, count * sizeof(float));
  }
  return out;
}

int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  if (size < 8) return 0;

  uint8_t mode = data[0] % 5;
  data++;
  size--;

  switch (mode) {
  case 0: {
    /* Byte-size helpers with a fuzz-controlled dimension count. */
    if (size < sizeof(size_t)) return 0;
    size_t dim;
    memcpy(&dim, data, sizeof(dim));

    /* These should never crash, just return values */
    size_t bit_size = _test_rescore_quantized_byte_size_bit(dim);
    size_t int8_size = _test_rescore_quantized_byte_size_int8(dim);

    (void)bit_size;
    (void)int8_size;
    break;
  }

  case 1: {
    /* Bit quantize with raw reinterpreted floats (NaN, Inf, denormal). */
    size_t num_floats = size / sizeof(float);
    /* Round down to a multiple of 8 for the bit quantizer */
    size_t dim = (num_floats / 8) * 8;
    if (dim == 0) return 0;

    float *src = quantize_edge_copy_floats(data, dim);
    uint8_t *dst = (uint8_t *)malloc(dim / 8);
    if (!src || !dst) {
      free(src);
      free(dst);
      return 0;
    }

    _test_rescore_quantize_float_to_bit(src, dst, dim);

    /* Bit i lives in byte i/8 at position i%8. It must be set exactly when
       src[i] >= 0.0f; NaN compares false both ways and is not checked. */
    for (size_t i = 0; i < dim; i++) {
      int bit_set = (dst[i / 8] >> (i % 8)) & 1;
      if (src[i] >= 0.0f) {
        assert(bit_set == 1);
      } else if (src[i] < 0.0f) {
        /* Definitely negative -- bit must be 0 */
        assert(bit_set == 0);
      }
    }

    free(src);
    free(dst);
    break;
  }

  case 2: {
    /* Int8 quantize with raw reinterpreted floats: identical values,
       NaN in min/max tracking, extreme ranges, denormals. */
    size_t num_floats = size / sizeof(float);
    if (num_floats == 0) return 0;

    float *src = quantize_edge_copy_floats(data, num_floats);
    int8_t *dst = (int8_t *)malloc(num_floats);
    if (!src || !dst) {
      free(src);
      free(dst);
      return 0;
    }

    _test_rescore_quantize_float_to_int8(src, dst, num_floats);

    /* Trivially true for int8_t; kept so every output element is read. */
    for (size_t i = 0; i < num_floats; i++) {
      assert(dst[i] >= -128 && dst[i] <= 127);
    }

    free(src);
    free(dst);
    break;
  }

  case 3: {
    /* Int8 quantize stress: all-same values (range=0 branch) */
    size_t dim = (size < 64) ? size : 64;
    if (dim == 0) return 0;

    float *src = (float *)malloc(dim * sizeof(float));
    int8_t *dst = (int8_t *)malloc(dim);
    if (!src || !dst) {
      free(src);
      free(dst);
      return 0;
    }

    /* Fill with a single value derived from fuzz data */
    float val = 0.0f;
    memcpy(&val, data, sizeof(float) < size ? sizeof(float) : size);
    for (size_t i = 0; i < dim; i++) src[i] = val;

    _test_rescore_quantize_float_to_int8(src, dst, dim);

    /* All outputs should be 0 when range == 0 */
    for (size_t i = 0; i < dim; i++) {
      assert(dst[i] == 0);
    }

    free(src);
    free(dst);
    break;
  }

  case 4: {
    /* Int8 quantize with an extreme range: two fuzz-controlled finite
       floats plus payload bytes scaled by 1000. */
    if (size < 2 * sizeof(float)) return 0;

    float extreme[2];
    memcpy(extreme, data, 2 * sizeof(float));

    /* Only test if both are finite (NaN/Inf tested in case 2) */
    if (!isfinite(extreme[0]) || !isfinite(extreme[1])) return 0;

    size_t dim = 16;
    float src[16];
    src[0] = extreme[0];
    src[1] = extreme[1];
    for (size_t i = 2; i < dim; i++) {
      /* Payload offset of the byte feeding element i. */
      size_t off = 2 * sizeof(float) + (i - 2);
      src[i] = (off < size) ? (float)((int8_t)data[off]) * 1000.0f : 0.0f;
    }

    int8_t dst[16];
    _test_rescore_quantize_float_to_int8(src, dst, dim);

    for (size_t i = 0; i < dim; i++) {
      assert(dst[i] >= -128 && dst[i] <= 127);
    }
    break;
  }
  }

  return 0;
}
|
||||||
54
tests/fuzz/rescore-quantize.c
Normal file
54
tests/fuzz/rescore-quantize.c
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/* These are SQLITE_VEC_TEST wrappers defined in sqlite-vec-rescore.c */
|
||||||
|
extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
|
||||||
|
extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
|
||||||
|
|
||||||
|
/*
 * Fuzz target: feed the raw input, reinterpreted as floats, through both the
 * bit and the int8 rescore quantizer test wrappers.
 *
 * The number of dimensions is the count of whole floats in the input rounded
 * down to a multiple of 8 (required by the bit quantizer); inputs holding
 * fewer than 8 floats are ignored.
 */
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
  size_t num_floats = size / sizeof(float);

  /* Round down to multiple of 8 for bit quantizer compatibility */
  size_t dim = (num_floats / 8) * 8;
  if (dim == 0) return 0;

  /* Copy into float-aligned storage instead of casting `data`: the fuzzer
     entry point makes no alignment promise for the input buffer. */
  float *src = (float *)malloc(dim * sizeof(float));
  uint8_t *bit_dst = (uint8_t *)malloc(dim / 8);
  int8_t *int8_dst = (int8_t *)malloc(dim);
  if (!src || !bit_dst || !int8_dst) {
    free(src);
    free(bit_dst);
    free(int8_dst);
    return 0;
  }
  memcpy(src, data, dim * sizeof(float));

  /* Test bit quantization */
  _test_rescore_quantize_float_to_bit(src, bit_dst, dim);

  /* Test int8 quantization */
  _test_rescore_quantize_float_to_int8(src, int8_dst, dim);

  /* Trivially true for int8_t; kept so every output element is read. */
  for (size_t i = 0; i < dim; i++) {
    assert(int8_dst[i] >= -128 && int8_dst[i] <= 127);
  }

  free(src);
  free(bit_dst);
  free(int8_dst);
  return 0;
}
|
||||||
230
tests/fuzz/rescore-shadow-corrupt.c
Normal file
230
tests/fuzz/rescore-shadow-corrupt.c
Normal file
|
|
@ -0,0 +1,230 @@
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "sqlite-vec.h"
|
||||||
|
#include "sqlite3.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fuzz target: corrupt rescore shadow tables then exercise KNN/read/write.
|
||||||
|
*
|
||||||
|
* This targets the dangerous code paths in rescore_knn (Phase 1 + 2):
|
||||||
|
* - sqlite3_blob_read into baseVectors with potentially wrong-sized blobs
|
||||||
|
* - distance computation on corrupted/partial quantized data
|
||||||
|
* - blob_reopen on _rescore_vectors with missing/corrupted rows
|
||||||
|
* - insert/delete after corruption (blob_write to wrong offsets)
|
||||||
|
*
|
||||||
|
* The existing shadow-corrupt.c only tests vec0 without rescore.
|
||||||
|
*/
|
||||||
|
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
|
||||||
|
if (size < 4) return 0;
|
||||||
|
|
||||||
|
int rc;
|
||||||
|
sqlite3 *db;
|
||||||
|
|
||||||
|
rc = sqlite3_open(":memory:", &db);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
rc = sqlite3_vec_init(db, NULL, NULL);
|
||||||
|
assert(rc == SQLITE_OK);
|
||||||
|
|
||||||
|
/* Pick quantizer type from first byte */
|
||||||
|
int use_int8 = data[0] & 1;
|
||||||
|
int target = (data[1] % 8);
|
||||||
|
const uint8_t *payload = data + 2;
|
||||||
|
int payload_size = (int)(size - 2);
|
||||||
|
|
||||||
|
const char *create_sql = use_int8
|
||||||
|
? "CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[16] indexed by rescore(quantizer=int8))"
|
||||||
|
: "CREATE VIRTUAL TABLE v USING vec0("
|
||||||
|
"emb float[16] indexed by rescore(quantizer=bit))";
|
||||||
|
|
||||||
|
rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL);
|
||||||
|
if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
/* Insert several vectors so there's a full chunk to corrupt */
|
||||||
|
{
|
||||||
|
sqlite3_stmt *ins = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL);
|
||||||
|
if (!ins) { sqlite3_close(db); return 0; }
|
||||||
|
|
||||||
|
for (int i = 1; i <= 8; i++) {
|
||||||
|
float vec[16];
|
||||||
|
for (int j = 0; j < 16; j++) vec[j] = (float)(i * 10 + j) / 100.0f;
|
||||||
|
sqlite3_reset(ins);
|
||||||
|
sqlite3_bind_int64(ins, 1, i);
|
||||||
|
sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
|
||||||
|
sqlite3_step(ins);
|
||||||
|
}
|
||||||
|
sqlite3_finalize(ins);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Now corrupt rescore shadow tables based on fuzz input */
|
||||||
|
sqlite3_stmt *stmt = NULL;
|
||||||
|
|
||||||
|
switch (target) {
|
||||||
|
case 0: {
|
||||||
|
/* Corrupt _rescore_chunks00 vectors blob with fuzz data */
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
stmt = NULL;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 1: {
|
||||||
|
/* Corrupt _rescore_vectors00 vector blob for a specific row */
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 3",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
stmt = NULL;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 2: {
|
||||||
|
/* Truncate the quantized chunk blob to wrong size */
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_rescore_chunks00 SET vectors = X'DEADBEEF' WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
stmt = NULL;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 3: {
|
||||||
|
/* Delete rows from _rescore_vectors (orphan the float vectors) */
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"DELETE FROM v_rescore_vectors00 WHERE rowid IN (2, 4, 6)",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 4: {
|
||||||
|
/* Delete the chunk row entirely from _rescore_chunks */
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"DELETE FROM v_rescore_chunks00 WHERE rowid = 1",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 5: {
|
||||||
|
/* Set vectors to NULL in _rescore_chunks */
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"UPDATE v_rescore_chunks00 SET vectors = NULL WHERE rowid = 1",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 6: {
|
||||||
|
/* Set vector to NULL in _rescore_vectors */
|
||||||
|
sqlite3_exec(db,
|
||||||
|
"UPDATE v_rescore_vectors00 SET vector = NULL WHERE rowid = 3",
|
||||||
|
NULL, NULL, NULL);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 7: {
|
||||||
|
/* Corrupt BOTH tables with fuzz data */
|
||||||
|
int half = payload_size / 2;
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload, half, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
stmt = NULL;
|
||||||
|
}
|
||||||
|
rc = sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 1",
|
||||||
|
-1, &stmt, NULL);
|
||||||
|
if (rc == SQLITE_OK) {
|
||||||
|
sqlite3_bind_blob(stmt, 1, payload + half,
|
||||||
|
payload_size - half, SQLITE_STATIC);
|
||||||
|
sqlite3_step(stmt);
|
||||||
|
sqlite3_finalize(stmt);
|
||||||
|
stmt = NULL;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Exercise ALL read/write paths -- NONE should crash */
|
||||||
|
|
||||||
|
/* KNN query (triggers rescore_knn Phase 1 + Phase 2) */
|
||||||
|
{
|
||||||
|
float qvec[16] = {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1};
|
||||||
|
sqlite3_stmt *knn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||||
|
"ORDER BY distance LIMIT 5", -1, &knn, NULL);
|
||||||
|
if (knn) {
|
||||||
|
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Full scan (triggers reading from _rescore_vectors) */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
/* Point lookups */
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL);
|
||||||
|
sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 3", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
/* Insert after corruption */
|
||||||
|
{
|
||||||
|
float vec[16] = {0};
|
||||||
|
sqlite3_stmt *ins = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"INSERT INTO v(rowid, emb) VALUES (99, ?)", -1, &ins, NULL);
|
||||||
|
if (ins) {
|
||||||
|
sqlite3_bind_blob(ins, 1, vec, sizeof(vec), SQLITE_STATIC);
|
||||||
|
sqlite3_step(ins);
|
||||||
|
sqlite3_finalize(ins);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Delete after corruption */
|
||||||
|
sqlite3_exec(db, "DELETE FROM v WHERE rowid = 5", NULL, NULL, NULL);
|
||||||
|
|
||||||
|
/* Update after corruption */
|
||||||
|
{
|
||||||
|
float vec[16] = {1,1,1,1, 1,1,1,1, 1,1,1,1, 1,1,1,1};
|
||||||
|
sqlite3_stmt *upd = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL);
|
||||||
|
if (upd) {
|
||||||
|
sqlite3_bind_blob(upd, 1, vec, sizeof(vec), SQLITE_STATIC);
|
||||||
|
sqlite3_step(upd);
|
||||||
|
sqlite3_finalize(upd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* KNN again after modifications to corrupted state */
|
||||||
|
{
|
||||||
|
float qvec[16] = {0,0,0,0, 0,0,0,0, 1,1,1,1, 1,1,1,1};
|
||||||
|
sqlite3_stmt *knn = NULL;
|
||||||
|
sqlite3_prepare_v2(db,
|
||||||
|
"SELECT rowid, distance FROM v WHERE emb MATCH ? "
|
||||||
|
"ORDER BY distance LIMIT 3", -1, &knn, NULL);
|
||||||
|
if (knn) {
|
||||||
|
sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC);
|
||||||
|
while (sqlite3_step(knn) == SQLITE_ROW) {}
|
||||||
|
sqlite3_finalize(knn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlite3_exec(db, "DROP TABLE v", NULL, NULL, NULL);
|
||||||
|
sqlite3_close(db);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
81
tests/generate_legacy_db.py
Normal file
81
tests/generate_legacy_db.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.10"
|
||||||
|
# dependencies = ["sqlite-vec==0.1.6"]
|
||||||
|
# ///
|
||||||
|
"""Generate a legacy sqlite-vec database for backwards-compat testing.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run --script generate_legacy_db.py
|
||||||
|
|
||||||
|
Creates tests/fixtures/legacy-v0.1.6.db with a vec0 table containing
|
||||||
|
test data that can be read by the current version of sqlite-vec.
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
import sqlite_vec
|
||||||
|
import struct
|
||||||
|
import os
|
||||||
|
|
||||||
|
FIXTURE_DIR = os.path.join(os.path.dirname(__file__), "fixtures")
|
||||||
|
DB_PATH = os.path.join(FIXTURE_DIR, "legacy-v0.1.6.db")
|
||||||
|
|
||||||
|
DIMS = 4
|
||||||
|
N_ROWS = 50
|
||||||
|
|
||||||
|
|
||||||
|
def _f32(vals):
|
||||||
|
return struct.pack(f"{len(vals)}f", *vals)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Rebuild the legacy fixture database at DB_PATH from scratch.

    Creates two vec0 tables with the sqlite-vec build loaded by
    ``sqlite_vec.load``:

    * ``legacy_vectors`` -- N_ROWS rows whose first vector element equals the
      rowid, followed by a KNN sanity check.
    * ``emb`` -- a table whose name equals its vector column name (the
      name-conflict case that old versions allowed).

    Raises AssertionError when the KNN sanity check fails. The connection is
    closed even when a step fails.
    """
    os.makedirs(FIXTURE_DIR, exist_ok=True)
    if os.path.exists(DB_PATH):
        os.remove(DB_PATH)

    def unit_vector(first):
        # DIMS-long vector: `first` in slot 0, zeros elsewhere. Derived from
        # DIMS so the data always matches the declared column width.
        return [float(first)] + [0.0] * (DIMS - 1)

    db = sqlite3.connect(DB_PATH)
    try:
        db.enable_load_extension(True)
        sqlite_vec.load(db)

        # Print version for verification
        version = db.execute("SELECT vec_version()").fetchone()[0]
        print(f"sqlite-vec version: {version}")

        # Create a basic vec0 table — flat index, no fancy features
        db.execute(f"CREATE VIRTUAL TABLE legacy_vectors USING vec0(emb float[{DIMS}])")

        # Insert test data: vectors where element[0] == rowid for easy verification
        for i in range(1, N_ROWS + 1):
            db.execute(
                "INSERT INTO legacy_vectors(rowid, emb) VALUES (?, ?)",
                [i, _f32(unit_vector(i))],
            )

        db.commit()

        # Verify
        count = db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
        print(f"Inserted {count} rows")

        # Test KNN works
        query = _f32(unit_vector(1))
        rows = db.execute(
            "SELECT rowid, distance FROM legacy_vectors WHERE emb MATCH ? AND k = 5",
            [query],
        ).fetchall()
        print(f"KNN top 5: {[(r[0], round(r[1], 4)) for r in rows]}")
        # Check the length first so an empty result fails with a clear
        # assertion instead of an IndexError.
        assert len(rows) == 5
        assert rows[0][0] == 1  # closest to the query vector

        # Also create a table with name == column name (the conflict case)
        # This was allowed in old versions — new code must not break on reconnect
        db.execute(f"CREATE VIRTUAL TABLE emb USING vec0(emb float[{DIMS}])")
        for i in range(1, 11):
            db.execute(
                "INSERT INTO emb(rowid, emb) VALUES (?, ?)",
                [i, _f32(unit_vector(i))],
            )
        db.commit()

        count2 = db.execute("SELECT count(*) FROM emb").fetchone()[0]
        print(f"Table 'emb' with column 'emb': {count2} rows (name conflict case)")
    finally:
        db.close()

    print(f"\nGenerated: {DB_PATH}")


if __name__ == "__main__":
    main()
|
||||||
|
|
@ -3,6 +3,11 @@
|
||||||
|
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#ifndef SQLITE_VEC_ENABLE_IVF
|
||||||
|
#define SQLITE_VEC_ENABLE_IVF 1
|
||||||
|
#endif
|
||||||
|
|
||||||
int min_idx(
|
int min_idx(
|
||||||
const float *distances,
|
const float *distances,
|
||||||
|
|
@ -62,12 +67,81 @@ enum Vec0DistanceMetrics {
|
||||||
VEC0_DISTANCE_METRIC_L1 = 3,
|
VEC0_DISTANCE_METRIC_L1 = 3,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*
 * Index kind of a vec0 vector column (stored in
 * VectorColumnDefinition.index_type). The numeric values are fixed; value 2
 * only exists when rescore support is compiled in.
 */
enum Vec0IndexType {
  VEC0_INDEX_TYPE_FLAT = 1,
#ifdef SQLITE_VEC_ENABLE_RESCORE
  VEC0_INDEX_TYPE_RESCORE = 2,
#endif
  VEC0_INDEX_TYPE_IVF = 3,
  VEC0_INDEX_TYPE_DISKANN = 4,
};
|
||||||
|
|
||||||
|
/*
 * Rescore index configuration.
 *
 * These two types are declared exactly once and unconditionally. A second,
 * identical copy guarded by `#ifdef SQLITE_VEC_ENABLE_RESCORE` used to follow
 * the IVF section; with that macro defined the enum and struct tags were
 * defined twice in one translation unit, which compilers reject before C23.
 */
enum Vec0RescoreQuantizerType {
  VEC0_RESCORE_QUANTIZER_BIT = 1,
  VEC0_RESCORE_QUANTIZER_INT8 = 2,
};

struct Vec0RescoreConfig {
  enum Vec0RescoreQuantizerType quantizer_type;
  /* NOTE(review): presumably the candidate over-fetch factor used before
   * rescoring — confirm against the implementation. */
  int oversample;
};

/*
 * IVF index configuration. When IVF support is compiled out the struct is
 * kept as a one-byte placeholder so VectorColumnDefinition can still embed it.
 */
#if SQLITE_VEC_ENABLE_IVF
enum Vec0IvfQuantizer {
  VEC0_IVF_QUANTIZER_NONE = 0,
  VEC0_IVF_QUANTIZER_INT8 = 1,
  VEC0_IVF_QUANTIZER_BINARY = 2,
};

struct Vec0IvfConfig {
  int nlist;
  int nprobe;
  int quantizer;
  int oversample;
};
#else
struct Vec0IvfConfig { char _unused; };
#endif
|
||||||
|
|
||||||
|
/* Quantizer choices for the DiskANN index (values are fixed). */
enum Vec0DiskannQuantizerType {
  VEC0_DISKANN_QUANTIZER_BINARY = 1,
  VEC0_DISKANN_QUANTIZER_INT8 = 2,
};

/*
 * DiskANN index configuration embedded in VectorColumnDefinition.
 * NOTE(review): the per-field meanings are not visible in this header;
 * the names suggest graph degree, search-list sizes (general / search-time /
 * insert-time), a pruning factor and a buffering threshold — confirm against
 * the implementation.
 */
struct Vec0DiskannConfig {
  enum Vec0DiskannQuantizerType quantizer_type;
  int n_neighbors;
  int search_list_size;
  int search_list_size_search;
  int search_list_size_insert;
  float alpha;
  int buffer_threshold;
};
|
||||||
|
|
||||||
struct VectorColumnDefinition {
|
struct VectorColumnDefinition {
|
||||||
char *name;
|
char *name;
|
||||||
int name_length;
|
int name_length;
|
||||||
size_t dimensions;
|
size_t dimensions;
|
||||||
enum VectorElementType element_type;
|
enum VectorElementType element_type;
|
||||||
enum Vec0DistanceMetrics distance_metric;
|
enum Vec0DistanceMetrics distance_metric;
|
||||||
|
enum Vec0IndexType index_type;
|
||||||
|
#ifdef SQLITE_VEC_ENABLE_RESCORE
|
||||||
|
struct Vec0RescoreConfig rescore;
|
||||||
|
#endif
|
||||||
|
struct Vec0IvfConfig ivf;
|
||||||
|
struct Vec0DiskannConfig diskann;
|
||||||
};
|
};
|
||||||
|
|
||||||
int vec0_parse_vector_column(const char *source, int source_length,
|
int vec0_parse_vector_column(const char *source, int source_length,
|
||||||
|
|
@ -78,10 +152,90 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
|
||||||
int *out_column_name_length,
|
int *out_column_name_length,
|
||||||
int *out_column_type);
|
int *out_column_type);
|
||||||
|
|
||||||
|
/*
 * DiskANN helpers exposed for unit tests. Only the prototypes live here.
 * NOTE(review): the group descriptions below are inferred from the names and
 * signatures — confirm against the definitions.
 */

/* Byte-size calculators for a quantized vector and the per-node buffers. */
size_t diskann_quantized_vector_byte_size(
    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);

int diskann_validity_byte_size(int n_neighbors);
size_t diskann_neighbor_ids_byte_size(int n_neighbors);
size_t diskann_neighbor_qvecs_byte_size(
    int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
    size_t dimensions);

/* Node construction: hands back three buffers together with their sizes. */
int diskann_node_init(
    int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
    size_t dimensions,
    unsigned char **outValidity, int *outValiditySize,
    unsigned char **outNeighborIds, int *outNeighborIdsSize,
    unsigned char **outNeighborQvecs, int *outNeighborQvecsSize);

/* Accessors for the validity buffer, indexed by neighbor slot `i`. */
int diskann_validity_get(const unsigned char *validity, int i);
void diskann_validity_set(unsigned char *validity, int i, int value);
int diskann_validity_count(const unsigned char *validity, int n_neighbors);

/* Accessors for the neighbor rowid buffer, indexed by neighbor slot `i`. */
long long diskann_neighbor_id_get(const unsigned char *neighbor_ids, int i);
void diskann_neighbor_id_set(unsigned char *neighbor_ids, int i, long long rowid);

/* Accessors for the quantized neighbor vectors, indexed by slot `i`. */
const unsigned char *diskann_neighbor_qvec_get(
    const unsigned char *qvecs, int i,
    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
void diskann_neighbor_qvec_set(
    unsigned char *qvecs, int i, const unsigned char *src_qvec,
    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);

/* Whole-slot operations that take all three node buffers at once. */
void diskann_node_set_neighbor(
    unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i,
    long long neighbor_rowid, const unsigned char *neighbor_qvec,
    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
void diskann_node_clear_neighbor(
    unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i,
    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);

/* Quantize a float vector of `dimensions` elements into `out`. */
int diskann_quantize_vector(
    const float *src, size_t dimensions,
    enum Vec0DiskannQuantizerType quantizer_type,
    unsigned char *out);

/* Neighbor pruning: writes chosen candidate indices and their count. */
int diskann_prune_select(
    const float *inter_distances, const float *p_distances,
    int num_candidates, float alpha, int max_neighbors,
    int *outSelected, int *outCount);
|
||||||
|
|
||||||
#ifdef SQLITE_VEC_TEST
|
#ifdef SQLITE_VEC_TEST
|
||||||
float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims);
|
float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims);
|
||||||
float _test_distance_cosine_float(const float *a, const float *b, size_t dims);
|
float _test_distance_cosine_float(const float *a, const float *b, size_t dims);
|
||||||
float _test_distance_hamming(const unsigned char *a, const unsigned char *b, size_t dims);
|
float _test_distance_hamming(const unsigned char *a, const unsigned char *b, size_t dims);
|
||||||
|
|
||||||
|
/* Rescore quantizer test hooks (only when rescore support is compiled in). */
#ifdef SQLITE_VEC_ENABLE_RESCORE
void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim);
void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim);
size_t _test_rescore_quantized_byte_size_bit(size_t dimensions);
size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
#endif

/* IVF quantizers (only when IVF support is compiled in). */
#if SQLITE_VEC_ENABLE_IVF
void ivf_quantize_int8(const float *src, int8_t *dst, int D);
void ivf_quantize_binary(const float *src, uint8_t *dst, int D);
#endif

// DiskANN candidate list (opaque struct, use accessors)
struct DiskannCandidateList {
  void *items; // opaque
  int count;
  int capacity;
};

/* Lifecycle */
int _test_diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity);
void _test_diskann_candidate_list_free(struct DiskannCandidateList *list);

/* Mutation and traversal */
int _test_diskann_candidate_list_insert(struct DiskannCandidateList *list, long long rowid, float distance);
int _test_diskann_candidate_list_next_unvisited(const struct DiskannCandidateList *list);
void _test_diskann_candidate_list_set_visited(struct DiskannCandidateList *list, int i);

/* Read-only accessors */
int _test_diskann_candidate_list_count(const struct DiskannCandidateList *list);
long long _test_diskann_candidate_list_rowid(const struct DiskannCandidateList *list, int i);
float _test_diskann_candidate_list_distance(const struct DiskannCandidateList *list, int i);

// DiskANN visited set (opaque struct, use accessors)
struct DiskannVisitedSet {
  void *slots; // opaque
  int capacity;
  int count;
};

int _test_diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity);
void _test_diskann_visited_set_free(struct DiskannVisitedSet *set);
int _test_diskann_visited_set_contains(const struct DiskannVisitedSet *set, long long rowid);
int _test_diskann_visited_set_insert(struct DiskannVisitedSet *set, long long rowid);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif /* SQLITE_VEC_INTERNAL_H */
|
#endif /* SQLITE_VEC_INTERNAL_H */
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from helpers import exec, vec0_shadow_table_contents
|
import struct
|
||||||
|
import pytest
|
||||||
|
from helpers import exec, vec0_shadow_table_contents, _f32
|
||||||
|
|
||||||
|
|
||||||
def test_constructor_limit(db, snapshot):
|
def test_constructor_limit(db, snapshot):
|
||||||
|
|
@ -126,3 +128,198 @@ def test_knn(db, snapshot):
|
||||||
) == snapshot(name="illegal KNN w/ aux")
|
) == snapshot(name="illegal KNN w/ aux")
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# Auxiliary columns with non-flat indexes
|
||||||
|
# ======================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_aux_shadow_tables(db, snapshot):
    """A rescore-indexed table with aux columns creates the expected shadow tables."""
    create_sql = (
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text,"
        " +score float"
        ")"
    )
    db.execute(create_sql)

    # Snapshot the name and schema of every table whose name starts with "t"
    # plus at least one more character ("_" is a LIKE wildcard).
    shadow_tables = exec(
        db,
        "SELECT name, sql FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name",
    )
    assert shadow_tables == snapshot(name="rescore aux shadow tables")
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_aux_insert_knn(db, snapshot):
    """Rows inserted with an aux label come back, label included, from scans and KNN."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )
    import random

    random.seed(77)
    # The draw order (alpha, beta, gamma) is part of the fixture: the
    # snapshots depend on the exact random vectors.
    data = [
        (label, [random.gauss(0, 1) for _ in range(128)])
        for label in ("alpha", "beta", "gamma")
    ]
    insert_sql = "INSERT INTO t(emb, label) VALUES (?, ?)"
    for label, vec in data:
        db.execute(insert_sql, [_f32(vec), label])

    all_rows = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert all_rows == snapshot(name="rescore aux select all")

    shadow = vec0_shadow_table_contents(db, "t", skip_info=True)
    assert shadow == snapshot(name="rescore aux shadow contents")

    # KNN should include aux column, "alpha" closest to its own vector
    alpha_vec = data[0][1]
    rows = db.execute(
        "SELECT label, distance FROM t WHERE emb MATCH ? ORDER BY distance LIMIT 3",
        [_f32(alpha_vec)],
    ).fetchall()
    assert len(rows) == 3
    nearest_label = rows[0][0]
    assert nearest_label == "alpha"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_aux_update(db):
    """Updating an aux column on a rescore table keeps the row findable via KNN."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )
    import random

    random.seed(88)
    vec = [random.gauss(0, 1) for _ in range(128)]
    blob = _f32(vec)

    db.execute("INSERT INTO t(rowid, emb, label) VALUES (1, ?, 'original')", [blob])
    db.execute("UPDATE t SET label = 'updated' WHERE rowid = 1")

    (label,) = db.execute("SELECT label FROM t WHERE rowid = 1").fetchone()[:1]
    assert label == "updated"

    # KNN still works with updated aux
    rows = db.execute(
        "SELECT rowid, label FROM t WHERE emb MATCH ? ORDER BY distance LIMIT 1",
        [_f32(vec)],
    ).fetchall()
    top = rows[0]
    assert top[0] == 1
    assert top[1] == "updated"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rescore_aux_delete(db, snapshot):
    """Deleting a row also removes its entry from the auxiliary shadow table."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " emb float[128] indexed by rescore(quantizer=bit),"
        " +label text"
        ")"
    )
    import random

    random.seed(99)
    insert_sql = "INSERT INTO t(rowid, emb, label) VALUES (?, ?, ?)"
    for rowid in range(1, 6):
        vec = [random.gauss(0, 1) for _ in range(128)]
        db.execute(insert_sql, [rowid, _f32(vec), f"item-{rowid}"])

    db.execute("DELETE FROM t WHERE rowid = 3")

    remaining = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert remaining == snapshot(name="rescore aux after delete")

    aux_rows = exec(db, "SELECT rowid, value00 FROM t_auxiliary ORDER BY rowid")
    assert aux_rows == snapshot(name="rescore aux shadow after delete")
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_aux_shadow_tables(db, snapshot):
    """A DiskANN-indexed table with aux columns creates the expected shadow tables."""
    create_sql = """
        CREATE VIRTUAL TABLE t USING vec0(
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
            +label text,
            +score float
        )
    """
    db.execute(create_sql)

    shadow_tables = exec(
        db,
        "SELECT name, sql FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name",
    )
    assert shadow_tables == snapshot(name="diskann aux shadow tables")
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_aux_insert_knn(db, snapshot):
    """DiskANN table with an aux column: insert rows, then scan and KNN return the labels."""
    db.execute("""
        CREATE VIRTUAL TABLE t USING vec0(
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
            +label text
        )
    """)

    # One axis-aligned unit vector per colour.
    red = [1, 0, 0, 0, 0, 0, 0, 0]
    green = [0, 1, 0, 0, 0, 0, 0, 0]
    blue = [0, 0, 1, 0, 0, 0, 0, 0]
    for label, vec in (("red", red), ("green", green), ("blue", blue)):
        db.execute("INSERT INTO t(emb, label) VALUES (?, ?)", [_f32(vec), label])

    all_rows = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert all_rows == snapshot(name="diskann aux select all")

    shadow = vec0_shadow_table_contents(db, "t", skip_info=True)
    assert shadow == snapshot(name="diskann aux shadow contents")

    rows = db.execute(
        "SELECT label, distance FROM t WHERE emb MATCH ? AND k = 3",
        [_f32([1, 0, 0, 0, 0, 0, 0, 0])],
    ).fetchall()
    assert len(rows) >= 1
    nearest_label = rows[0][0]
    assert nearest_label == "red"
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_aux_update_and_delete(db, snapshot):
    """DiskANN table with an aux column: an aux UPDATE and a DELETE show up in table and shadow data."""
    db.execute("""
        CREATE VIRTUAL TABLE t USING vec0(
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
            +label text
        )
    """)

    insert_sql = "INSERT INTO t(rowid, emb, label) VALUES (?, ?, ?)"
    for idx in range(5):
        # Unit vector along axis `idx`.
        vec = [1.0 if axis == idx % 8 else 0.0 for axis in range(8)]
        db.execute(insert_sql, [idx + 1, _f32(vec), f"item-{idx+1}"])

    db.execute("UPDATE t SET label = 'UPDATED' WHERE rowid = 2")
    db.execute("DELETE FROM t WHERE rowid = 3")

    visible = exec(db, "SELECT rowid, label FROM t ORDER BY rowid")
    assert visible == snapshot(name="diskann aux after update+delete")

    aux_rows = exec(db, "SELECT rowid, value00 FROM t_auxiliary ORDER BY rowid")
    assert aux_rows == snapshot(name="diskann aux shadow after update+delete")
|
||||||
|
|
||||||
|
|
||||||
|
def test_diskann_aux_drop_cleans_all(db):
    """Dropping the table must also drop its auxiliary shadow table."""
    db.execute("""
        CREATE VIRTUAL TABLE t USING vec0(
            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8),
            +label text
        )
    """)
    all_ones = _f32([1] * 8)
    db.execute("INSERT INTO t(emb, label) VALUES (?, 'test')", [all_ones])
    db.execute("DROP TABLE t")

    leftovers = db.execute(
        "SELECT name FROM sqlite_master WHERE name LIKE 't_%'"
    ).fetchall()
    tables = []
    for row in leftovers:
        tables.append(row[0])
    assert "t_auxiliary" not in tables
|
||||||
|
|
||||||
|
|
|
||||||
1331
tests/test-diskann.py
Normal file
1331
tests/test-diskann.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -27,3 +27,15 @@ def test_info(db, snapshot):
|
||||||
assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot()
|
assert exec(db, "select key, typeof(value) from v_info order by 1") == snapshot()
|
||||||
|
|
||||||
|
|
||||||
|
def test_command_column_name_conflict(db):
    """Creating a vec0 table whose name equals one of its column names must fail."""
    # The hidden command column is named after the table, so here it would
    # collide with the vector column 'embeddings'.
    clashing_sql = "create virtual table embeddings using vec0(embeddings float[4])"
    with pytest.raises(sqlite3.OperationalError, match="conflicts with table name"):
        db.execute(clashing_sql)

    # The same column under a different table name is accepted.
    db.execute("create virtual table t using vec0(embeddings float[4])")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -483,3 +483,171 @@ def test_delete_one_chunk_of_two_shrinks_pages(tmp_path):
|
||||||
row = db.execute("select emb from v where rowid = ?", [i]).fetchone()
|
row = db.execute("select emb from v where rowid = ?", [i]).fetchone()
|
||||||
assert row[0] == _f32([float(i)] * dims)
|
assert row[0] == _f32([float(i)] * dims)
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wal_concurrent_reader_during_write(tmp_path):
    """In WAL mode, a reader should see a consistent snapshot while a writer inserts.

    Checks that the journal mode really switched to WAL (otherwise the rest of
    the test would not be exercising snapshot isolation). Both connections are
    closed even when an assertion fails, so no open handle is left on the
    database file under tmp_path.
    """
    dims = 4
    db_path = str(tmp_path / "test.db")

    def connect():
        # Separate connection to the same file, with the extension loaded.
        conn = sqlite3.connect(db_path)
        try:
            conn.enable_load_extension(True)
            conn.load_extension("dist/vec0")
        except Exception:
            conn.close()
            raise
        return conn

    # Writer: create table, insert initial rows, enable WAL
    writer = connect()
    try:
        # The pragma returns the journal mode now in effect.
        journal_mode = writer.execute("PRAGMA journal_mode=WAL").fetchone()[0]
        assert journal_mode == "wal"

        writer.execute(
            f"CREATE VIRTUAL TABLE v USING vec0(emb float[{dims}])"
        )
        for i in range(1, 11):
            writer.execute(
                "INSERT INTO v(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * dims)]
            )
        writer.commit()

        # Reader: open separate connection, start read
        reader = connect()
        try:
            # Reader sees 10 rows
            count_before = reader.execute("SELECT count(*) FROM v").fetchone()[0]
            assert count_before == 10

            # Writer inserts more rows (not yet committed)
            writer.execute("BEGIN")
            for i in range(11, 21):
                writer.execute(
                    "INSERT INTO v(rowid, emb) VALUES (?, ?)",
                    [i, _f32([float(i)] * dims)],
                )

            # Reader still sees 10 (WAL snapshot isolation)
            count_during = reader.execute("SELECT count(*) FROM v").fetchone()[0]
            assert count_during == 10

            # KNN during writer's transaction should work on reader's snapshot
            rows = reader.execute(
                "SELECT rowid FROM v WHERE emb MATCH ? AND k = 5",
                [_f32([1.0] * dims)],
            ).fetchall()
            assert len(rows) == 5
            assert all(r[0] <= 10 for r in rows)  # only original rows

            # Writer commits
            writer.commit()

            # Reader sees new rows after re-query (new snapshot)
            count_after = reader.execute("SELECT count(*) FROM v").fetchone()[0]
            assert count_after == 20
        finally:
            reader.close()
    finally:
        writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_integer_pk(db):
    """INSERT OR REPLACE on an existing rowid overwrites the stored vector."""
    db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")

    original = _f32([1.0, 2.0, 3.0, 4.0])
    db.execute("insert into v(rowid, emb) values (1, ?)", [original])

    # Replace with new vector
    db.execute(
        "insert or replace into v(rowid, emb) values (1, ?)",
        [_f32([10.0, 20.0, 30.0, 40.0])],
    )

    # Should still have exactly 1 row
    (count,) = db.execute("select count(*) from v").fetchone()[:1]
    assert count == 1

    # Vector should be the replaced value
    stored = db.execute("select emb from v where rowid = 1").fetchone()
    assert stored[0] == _f32([10.0, 20.0, 30.0, 40.0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_new_row(db):
    """INSERT OR REPLACE with an unused rowid behaves like a plain insert."""
    db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")

    db.execute(
        "insert or replace into v(rowid, emb) values (1, ?)",
        [_f32([1.0, 2.0, 3.0, 4.0])],
    )

    (count,) = db.execute("select count(*) from v").fetchone()[:1]
    assert count == 1

    stored = db.execute("select emb from v where rowid = 1").fetchone()
    assert stored[0] == _f32([1.0, 2.0, 3.0, 4.0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_text_pk(db):
    """INSERT OR REPLACE also replaces rows keyed by a text primary key.

    The key 'doc_a' is inserted and then replaced; exactly one row must
    remain and it must carry the replacement vector.
    """
    first_blob = _f32([1.0, 2.0, 3.0, 4.0])
    replacement_blob = _f32([10.0, 20.0, 30.0, 40.0])

    db.execute(
        "create virtual table v using vec0(id text primary key, emb float[4], chunk_size=8)"
    )
    db.execute("insert into v(id, emb) values ('doc_a', ?)", [first_blob])
    db.execute(
        "insert or replace into v(id, emb) values ('doc_a', ?)", [replacement_blob]
    )

    (n_rows,) = db.execute("select count(*) from v").fetchone()
    assert n_rows == 1

    (stored,) = db.execute("select emb from v where id = 'doc_a'").fetchone()
    assert stored == replacement_blob
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_with_auxiliary(db):
    """INSERT OR REPLACE must replace the auxiliary column along with the vector.

    Row 1 is created with label 'old' and replaced with label 'new'; the
    surviving row must expose both the new vector and the new label.
    """
    first_blob = _f32([1.0, 2.0, 3.0, 4.0])
    replacement_blob = _f32([10.0, 20.0, 30.0, 40.0])

    db.execute(
        "create virtual table v using vec0(emb float[4], +label text, chunk_size=8)"
    )
    db.execute(
        "insert into v(rowid, emb, label) values (1, ?, 'old')", [first_blob]
    )
    db.execute(
        "insert or replace into v(rowid, emb, label) values (1, ?, 'new')",
        [replacement_blob],
    )

    (n_rows,) = db.execute("select count(*) from v").fetchone()
    assert n_rows == 1

    stored = db.execute("select emb, label from v where rowid = 1").fetchone()
    assert stored[0] == replacement_blob
    assert stored[1] == "new"
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_or_replace_knn_uses_new_vector(db):
    """KNN must rank by the replaced vector, not the one it overwrote.

    Row 1 starts on the first axis and row 2 on the second. Row 1 is then
    replaced by a vector 0.1 away from row 2, so querying with row 2's exact
    vector should return row 2 first and row 1 as a close second.
    """
    second_axis = _f32([0.0, 1.0, 0.0, 0.0])

    db.execute("create virtual table v using vec0(emb float[4], chunk_size=8)")
    db.execute(
        "insert into v(rowid, emb) values (1, ?)", [_f32([1.0, 0.0, 0.0, 0.0])]
    )
    db.execute("insert into v(rowid, emb) values (2, ?)", [second_axis])

    # Move row 1 right next to row 2.
    db.execute(
        "insert or replace into v(rowid, emb) values (1, ?)",
        [_f32([0.0, 0.9, 0.0, 0.0])],
    )

    hits = db.execute(
        "select rowid, distance from v where emb match ? and k = 2",
        [second_axis],
    ).fetchall()
    best, runner_up = hits[0], hits[1]
    assert best[0] == 2  # exact match
    assert runner_up[0] == 1  # the replaced row
    assert runner_up[1] < 0.11  # L2 distance is about 0.1
|
||||||
|
|
|
||||||
589
tests/test-ivf-mutations.py
Normal file
589
tests/test-ivf-mutations.py
Normal file
|
|
@ -0,0 +1,589 @@
|
||||||
|
"""
|
||||||
|
Thorough IVF mutation tests: insert, delete, update, KNN correctness,
|
||||||
|
error cases, edge cases, and cell overflow scenarios.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import math
|
||||||
|
from helpers import _f32, exec
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows are returned as ``sqlite3.Row``. Extension loading is switched on
    only for the ``dist/vec0`` load and switched off again afterwards. The
    connection is closed once the test using the fixture has finished, even
    if setup or the test itself fails.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        # Release the connection instead of leaving it to garbage collection.
        conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_total_vectors(db, table="t", col=0):
    """Return the sum of ``n_vectors`` over all rows of the IVF cells table.

    The table queried is ``{table}_ivf_cells{col:02d}``. The SUM is wrapped
    in COALESCE, so a table with no rows yields 0.
    """
    cells_table = f"{table}_ivf_cells{col:02d}"
    sql = f"SELECT COALESCE(SUM(n_vectors), 0) FROM {cells_table}"
    (total,) = db.execute(sql).fetchone()
    return total
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_unassigned_count(db, table="t", col=0):
    """Return the sum of ``n_vectors`` over cells whose ``centroid_id`` is -1.

    Reads ``{table}_ivf_cells{col:02d}``; yields 0 when no such cell exists.
    """
    cells_table = f"{table}_ivf_cells{col:02d}"
    sql = f"SELECT COALESCE(SUM(n_vectors), 0) FROM {cells_table} WHERE centroid_id = -1"
    (unassigned,) = db.execute(sql).fetchone()
    return unassigned
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_assigned_count(db, table="t", col=0):
    """Return the sum of ``n_vectors`` over cells whose ``centroid_id`` is >= 0.

    Reads ``{table}_ivf_cells{col:02d}``; yields 0 when no such cell exists.
    """
    cells_table = f"{table}_ivf_cells{col:02d}"
    sql = f"SELECT COALESCE(SUM(n_vectors), 0) FROM {cells_table} WHERE centroid_id >= 0"
    (assigned,) = db.execute(sql).fetchone()
    return assigned
|
||||||
|
|
||||||
|
|
||||||
|
def knn(db, query, k, table="t", col="v"):
    """Run a KNN MATCH query and return its hits as ``(rowid, distance)`` tuples.

    ``query`` is passed through ``_f32`` before binding; ``k`` is bound as
    the ``k = ?`` constraint. ``table`` and ``col`` are interpolated into
    the SQL text.
    """
    sql = f"SELECT rowid, distance FROM {table} WHERE {col} MATCH ? AND k = ?"
    fetched = db.execute(sql, [_f32(query), k]).fetchall()
    return [(hit[0], hit[1]) for hit in fetched]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Single row insert + KNN
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_single_row_knn(db):
    """A single inserted row is the only KNN hit, at near-zero distance."""
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
    db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])

    # k is larger than the table, so only the one stored row can come back.
    hits = knn(db, [1, 0, 0, 0], 5)
    assert len(hits) == 1
    only_hit = hits[0]
    assert only_hit[0] == 1
    assert only_hit[1] < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Batch insert + KNN recall
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_insert_knn_recall(db):
    """200 vectors on a line, trained with nprobe == nlist, give exact neighbours.

    After inserting rowids 0..199 at positions (i, 0, 0, 0) and computing
    centroids, every vector must be counted in an assigned cell and the ten
    nearest neighbours of position 100 must all lie within rowids 95..105.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
    )
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([float(rid), 0, 0, 0])) for rid in range(200)],
    )
    assert ivf_total_vectors(db) == 200

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    assert ivf_assigned_count(db) == 200

    # Query at position 100: the nearest hit should be rowid 100 itself.
    hits = knn(db, [100.0, 0, 0, 0], 10)
    assert len(hits) == 10
    nearest = hits[0]
    assert nearest[0] == 100
    assert nearest[1] < 0.01

    # Every hit should sit in the immediate neighbourhood of 100.
    hit_rowids = {hit[0] for hit in hits}
    assert all(95 <= rid <= 105 for rid in hit_rowids)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Delete rows, verify they're gone from KNN
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_rows_gone_from_knn(db):
    """A deleted rowid must no longer be returned by KNN.

    Twenty vectors are inserted and centroids computed; rowid 10 is deleted
    and a query centred on its former position must not return it.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([float(rid), 0, 0, 0])) for rid in range(20)],
    )
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    db.execute("DELETE FROM t WHERE rowid = 10")

    hit_rowids = {hit[0] for hit in knn(db, [10.0, 0, 0, 0], 20)}
    assert 10 not in hit_rowids
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_rows_empty_results(db):
    """Deleting every row leaves zero vectors in the cells and an empty KNN result."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([float(rid), 0, 0, 0])) for rid in range(10)],
    )
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # Remove the rows one statement at a time, as individual deletes.
    db.executemany("DELETE FROM t WHERE rowid = ?", [(rid,) for rid in range(10)])

    assert ivf_total_vectors(db) == 0
    hits = knn(db, [5.0, 0, 0, 0], 10)
    assert len(hits) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Insert after delete (reuse rowids)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_after_delete_reuse_rowid(db):
    """A rowid freed by DELETE can be inserted again with a different vector.

    Rowid 5 is deleted after training and re-inserted far away at 999; a
    query at 999 must return rowid 5 as the nearest hit.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([float(rid), 0, 0, 0])) for rid in range(10)],
    )
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    far_away = [999.0, 0, 0, 0]
    db.execute("DELETE FROM t WHERE rowid = 5")
    db.execute("INSERT INTO t(rowid, v) VALUES (5, ?)", [_f32(far_away)])

    hits = knn(db, far_away, 1)
    assert len(hits) >= 1
    assert hits[0][0] == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Update vectors (delete + re-insert), verify KNN reflects new values
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_vector_via_delete_insert(db):
    """Emulate a vector update on an IVF table with DELETE followed by INSERT.

    Rowid 3 is removed after training and re-inserted at position 100; the
    nearest neighbour of position 100 must then be rowid 3.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([float(rid), 0, 0, 0])) for rid in range(10)],
    )
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # The "update": drop the old vector, then store the new one under the same rowid.
    moved_to = [100.0, 0, 0, 0]
    db.execute("DELETE FROM t WHERE rowid = 3")
    db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32(moved_to)])

    hits = knn(db, moved_to, 1)
    assert hits[0][0] == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# IVF combined with auxiliary/metadata/partition key columns (supported and error cases)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_with_auxiliary_column(db):
    """An IVF-indexed table accepts an auxiliary column.

    Creation must succeed and a ``t_auxiliary`` entry must appear in
    ``sqlite_master``.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), +extra text)"
    )
    listing = db.execute(
        "SELECT name FROM sqlite_master WHERE name LIKE 't_%' ORDER BY 1"
    ).fetchall()
    shadow_names = [entry[0] for entry in listing]
    assert "t_auxiliary" in shadow_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_ivf_with_metadata_column(db):
    """Creating an IVF table with a metadata column reports a 'metadata' error."""
    create_sql = (
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), genre text)"
    )
    outcome = exec(db, create_sql)
    assert "error" in outcome
    message = outcome.get("message", "")
    assert "metadata" in message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_ivf_with_partition_key(db):
    """Creating an IVF table with a partition key reports a 'partition' error."""
    create_sql = "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(), user_id integer partition key)"
    outcome = exec(db, create_sql)
    assert "error" in outcome
    message = outcome.get("message", "")
    assert "partition" in message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_flat_with_auxiliary_still_works(db):
    """Regression guard: a table without an IVF index keeps supporting auxiliary columns."""
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4], +extra text)")
    db.execute(
        "INSERT INTO t(rowid, v, extra) VALUES (1, ?, 'hello')",
        [_f32([1, 0, 0, 0])],
    )
    (extra_value,) = db.execute("SELECT extra FROM t WHERE rowid = 1").fetchone()
    assert extra_value == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flat_with_metadata_still_works(db):
    """Regression guard: a table without an IVF index keeps supporting metadata columns."""
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4], genre text)")
    db.execute(
        "INSERT INTO t(rowid, v, genre) VALUES (1, ?, 'rock')",
        [_f32([1, 0, 0, 0])],
    )
    (genre_value,) = db.execute("SELECT genre FROM t WHERE rowid = 1").fetchone()
    assert genre_value == "rock"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flat_with_partition_key_still_works(db):
    """Regression guard: a table without an IVF index keeps supporting partition keys."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4], user_id integer partition key)"
    )
    db.execute(
        "INSERT INTO t(rowid, v, user_id) VALUES (1, ?, 42)",
        [_f32([1, 0, 0, 0])],
    )
    (user_id_value,) = db.execute("SELECT user_id FROM t WHERE rowid = 1").fetchone()
    assert user_id_value == 42
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Edge cases
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_vectors(db):
    """All-zero vectors can be stored and queried; every hit is at distance ~0."""
    origin = [0, 0, 0, 0]
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32(origin)) for rid in range(5)],
    )

    hits = knn(db, origin, 5)
    assert len(hits) == 5
    # Query and stored vectors coincide, so each distance should be zero.
    assert all(distance < 0.001 for _, distance in hits)
|
||||||
|
|
||||||
|
|
||||||
|
def test_large_values(db):
    """Vectors with very large, very small and large negative components are handled.

    Querying with the large positive vector must return its own row first.
    """
    db.execute("CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())")

    component_by_rowid = {1: 1e6, 2: 1e-6, 3: -1e6}
    for rowid, component in component_by_rowid.items():
        db.execute(
            f"INSERT INTO t(rowid, v) VALUES ({rowid}, ?)",
            [_f32([component] * 4)],
        )

    hits = knn(db, [1e6, 1e6, 1e6, 1e6], 3)
    assert hits[0][0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_row_compute_centroids(db):
    """compute-centroids works on a one-row table with nlist=1.

    The single vector must end up counted in an assigned cell and remain
    reachable through KNN.
    """
    lone_vector = [1, 2, 3, 4]
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=1))"
    )
    db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32(lone_vector)])
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    assert ivf_assigned_count(db) == 1

    hits = knn(db, lone_vector, 1)
    assert len(hits) == 1
    assert hits[0][0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Cell overflow (many vectors in one cell)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_cell_overflow_many_vectors(db):
    """More vectors than one cell holds, all under one centroid, spill into extra cells.

    100 nearly identical vectors are inserted, a single centroid is set by
    hand and the vectors are assigned to it. The test expects at least two
    cell rows for centroid 0 (it assumes a cell capacity of 64) and all 100
    vectors to remain queryable.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
    )
    # 100 vectors that differ only slightly along the first axis.
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([1.0 + rid * 0.001, 0, 0, 0])) for rid in range(100)],
    )

    # One manually placed centroid, so every vector is assigned to it.
    db.execute(
        "INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)",
        [_f32([1.0, 0, 0, 0])],
    )
    db.execute("INSERT INTO t(t) VALUES ('assign-vectors')")

    assert ivf_assigned_count(db) == 100

    # 100 vectors at 64 per cell need two cells.
    (cells_for_centroid,) = db.execute(
        "SELECT count(*) FROM t_ivf_cells00 WHERE centroid_id = 0"
    ).fetchone()
    assert cells_for_centroid >= 2

    hits = knn(db, [1.0, 0, 0, 0], 100)
    assert len(hits) == 100
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Large batch with training
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_large_batch_with_training(db):
    """Inserts made before and after training are all counted and searchable.

    500 vectors are inserted, centroids are computed, then 500 more are
    inserted. The cells must hold 1000 vectors in total and a query at
    position 750 (inserted after training) must return rowid 750 first.
    """
    insert_sql = "INSERT INTO t(rowid, v) VALUES (?, ?)"

    def line_rows(start, stop):
        # Rows (rowid, blob) placed at (rowid, 0, 0, 0).
        return [(rid, _f32([float(rid), 0, 0, 0])) for rid in range(start, stop)]

    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=16, nprobe=16))"
    )
    db.executemany(insert_sql, line_rows(0, 500))
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
    db.executemany(insert_sql, line_rows(500, 1000))

    assert ivf_total_vectors(db) == 1000

    hits = knn(db, [750.0, 0, 0, 0], 5)
    assert len(hits) == 5
    assert hits[0][0] == 750
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# KNN after interleaved insert/delete
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_knn_after_interleaved_insert_delete(db):
    """After deleting rowids 0-9, KNN returns exactly the ten surviving rows.

    Twenty vectors are inserted and trained; rowids 0..9, which surround the
    query point at 5.0, are deleted. A k=10 query must return ten hits, all
    with rowid >= 10.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
    )
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([float(rid), 0, 0, 0])) for rid in range(20)],
    )
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    db.executemany("DELETE FROM t WHERE rowid = ?", [(rid,) for rid in range(10)])

    hits = knn(db, [5.0, 0, 0, 0], 10)
    hit_rowids = {hit[0] for hit in hits}
    # No deleted rowid may come back.
    assert all(rid >= 10 for rid in hit_rowids)
    assert len(hits) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_knn_empty_centroids_after_deletes(db):
    """KNN must not crash when deletes may have emptied some centroids.

    The 40 inserted vectors fall on ten distinct points (0, 10, ..., 90 on
    the first axis), four copies of each. Deleting rowids 0..29 leaves one
    copy of every point, i.e. ten rows, so a k=20 query can return at most
    ten hits.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=2))"
    )
    # Four copies each of ten points: rowid i sits at ((i % 10) * 10, 0, 0, 0).
    for i in range(40):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [i, _f32([float(i % 10) * 10, 0, 0, 0])],
        )

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # Delete three of the four copies of every point; some centroids may end up empty.
    for i in range(30):
        db.execute("DELETE FROM t WHERE rowid = ?", [i])

    # The query itself is the check: it must complete without raising.
    results = knn(db, [50.0, 0, 0, 0], 20)
    assert len(results) <= 10  # only rowids 30..39 remain
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# KNN returns correct distances
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_knn_correct_distances(db):
    """Reported KNN distances are the L2 distances to the query vector.

    With the query at the origin, the rows at (0,0,0,0), (3,0,0,0) and
    (0,4,0,0) must come back with distances 0, 3 and 4 respectively.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    vector_by_rowid = {1: [0, 0, 0, 0], 2: [3, 0, 0, 0], 3: [0, 4, 0, 0]}
    for rowid, vector in vector_by_rowid.items():
        db.execute(f"INSERT INTO t(rowid, v) VALUES ({rowid}, ?)", [_f32(vector)])

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    distance_by_rowid = dict(knn(db, [0, 0, 0, 0], 3))

    # Each expected value is the Euclidean norm of the stored vector.
    assert abs(distance_by_rowid[1] - 0.0) < 0.01
    assert abs(distance_by_rowid[2] - 3.0) < 0.01  # sqrt(3^2) = 3
    assert abs(distance_by_rowid[3] - 4.0) < 0.01  # sqrt(4^2) = 4
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Delete in flat mode leaves no orphan rowid_map entries
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_flat_mode_rowid_map_count(db):
    """Deletes before any training leave no orphan rowid-map rows.

    Five vectors are inserted without computing centroids; rowids 0, 2 and
    4 are deleted. Both the rowid map table and the unassigned-cell vector
    count must then report the two remaining rows.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
    )
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([float(rid), 0, 0, 0])) for rid in range(5)],
    )

    db.execute("DELETE FROM t WHERE rowid = 0")
    db.execute("DELETE FROM t WHERE rowid = 2")
    db.execute("DELETE FROM t WHERE rowid = 4")

    (mapped_rows,) = db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()
    assert mapped_rows == 2
    assert ivf_unassigned_count(db) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Duplicate rowid insert
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_reinsert_as_update(db):
    """Delete + insert on the same rowid behaves like an update, without training.

    Rowid 1 is stored, deleted and stored again with a new vector; a query
    with the new vector must return rowid 1 as the only hit, at ~0 distance.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])

    updated_vector = [99, 0, 0, 0]
    db.execute("DELETE FROM t WHERE rowid = 1")
    db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32(updated_vector)])

    hits = knn(db, updated_vector, 1)
    assert len(hits) == 1
    only_hit = hits[0]
    assert only_hit[0] == 1
    assert only_hit[1] < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_rowid_insert_fails(db):
    """A plain INSERT that reuses an existing rowid must report an error."""
    insert_sql = "INSERT INTO t(rowid, v) VALUES (1, ?)"
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
    )
    db.execute(insert_sql, [_f32([1, 0, 0, 0])])

    # Second insert on rowid 1, issued through the error-capturing helper.
    outcome = exec(db, insert_sql, [_f32([99, 0, 0, 0])])
    assert "error" in outcome
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Interleaved insert/delete with KNN correctness
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_interleaved_ops_correctness(db):
    """After a mixed sequence of inserts and deletes, KNN returns only live rows.

    Phases: insert rowids 0..49 and train; delete the even rowids; insert
    rowids 50..74; delete rowids 60..69. That leaves 40 live rows (25 odd
    originals + 25 new - 10 deleted new). The test checks that no deleted
    rowid is returned and that every returned rowid is one of the live
    ones; it does not assert how many rows come back.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
    )

    # Phase 1: insert 50 vectors and train.
    for i in range(50):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [i, _f32([float(i), 0, 0, 0])],
        )

    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    # Phase 2: delete even-numbered rowids.
    for i in range(0, 50, 2):
        db.execute("DELETE FROM t WHERE rowid = ?", [i])

    # Phase 3: insert 25 new vectors with higher rowids (after training).
    for i in range(50, 75):
        db.execute(
            "INSERT INTO t(rowid, v) VALUES (?, ?)",
            [i, _f32([float(i), 0, 0, 0])],
        )

    # Phase 4: delete ten of the new ones.
    for i in range(60, 70):
        db.execute("DELETE FROM t WHERE rowid = ?", [i])

    # k exceeds the 40 live rows, so any deleted row that leaked would show up.
    results = knn(db, [25.0, 0, 0, 0], 50)
    rowids = {r[0] for r in results}

    # No deleted rowid may appear.
    deleted = set(range(0, 50, 2)) | set(range(60, 70))
    assert len(rowids & deleted) == 0

    # Every returned rowid must be live: 25 odd originals + 10 from 50..59 + 5 from 70..74.
    expected_alive = set(range(1, 50, 2)) | set(range(50, 60)) | set(range(70, 75))
    assert rowids.issubset(expected_alive)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_update_vector_blocked(db):
    """UPDATE of an IVF-indexed vector column is rejected with an OperationalError.

    The error message must match "UPDATE on vector column.*not supported for IVF".
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(emb float[4] indexed by ivf(nlist=2))"
    )
    db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
    db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0])])

    expected_error = pytest.raises(
        sqlite3.OperationalError,
        match="UPDATE on vector column.*not supported for IVF",
    )
    with expected_error:
        db.execute("UPDATE t SET emb = ? WHERE rowid = 1", [_f32([0, 0, 1, 0])])
|
||||||
272
tests/test-ivf-quantization.py
Normal file
272
tests/test-ivf-quantization.py
Normal file
|
|
@ -0,0 +1,272 @@
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
from helpers import _f32, exec
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def db():
    """Yield an in-memory SQLite connection with the vec0 extension loaded.

    Rows are returned as ``sqlite3.Row``. Extension loading is switched on
    only for the ``dist/vec0`` load and switched off again afterwards. The
    connection is closed once the test using the fixture has finished, even
    if setup or the test itself fails.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.enable_load_extension(True)
        conn.load_extension("dist/vec0")
        conn.enable_load_extension(False)
        yield conn
    finally:
        # Release the connection instead of leaving it to garbage collection.
        conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Parser tests — quantizer and oversample options
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantizer_binary(db):
    """``quantizer=binary`` with ``oversample`` parses and creates the IVF shadow tables."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[768] indexed by ivf(nlist=64, quantizer=binary, oversample=10))"
    )
    listing = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
    ).fetchall()
    table_names = [entry[0] for entry in listing]
    assert "t_ivf_centroids00" in table_names
    assert "t_ivf_cells00" in table_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantizer_int8(db):
    """``quantizer=int8`` parses and creates the IVF centroids table."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[768] indexed by ivf(nlist=64, quantizer=int8))"
    )
    listing = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
    ).fetchall()
    table_names = [entry[0] for entry in listing]
    assert "t_ivf_centroids00" in table_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantizer_none_explicit(db):
    """An explicit ``quantizer=none`` is accepted and creates the IVF centroids table."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[768] indexed by ivf(quantizer=none))"
    )
    # Expected to behave like a declaration that omits the quantizer option.
    listing = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
    ).fetchall()
    table_names = [entry[0] for entry in listing]
    assert "t_ivf_centroids00" in table_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantizer_all_params(db):
    """nlist, nprobe, quantizer and oversample can be combined with a distance metric."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[768] distance_metric=cosine indexed by ivf(nlist=128, nprobe=16, quantizer=int8, oversample=4))"
    )
    listing = db.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
    ).fetchall()
    table_names = [entry[0] for entry in listing]
    assert "t_ivf_centroids00" in table_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_error_oversample_without_quantizer(db):
    """Declaring ``oversample`` above 1 with no quantizer must report an error."""
    create_sql = "CREATE VIRTUAL TABLE t USING vec0(v float[768] indexed by ivf(oversample=10))"
    outcome = exec(db, create_sql)
    assert "error" in outcome
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_error_unknown_quantizer(db):
    """The quantizer name ``pq`` must be rejected with an error."""
    create_sql = "CREATE VIRTUAL TABLE t USING vec0(v float[768] indexed by ivf(quantizer=pq))"
    outcome = exec(db, create_sql)
    assert "error" in outcome
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_oversample_1_without_quantizer_ok(db):
    """An IVF declaration with neither quantizer nor oversample is accepted.

    The test passes as long as table creation does not raise; oversample
    is left at its default of 1.
    """
    create_sql = "CREATE VIRTUAL TABLE t USING vec0(v float[768] indexed by ivf(nlist=64))"
    db.execute(create_sql)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Functional tests — insert, train, query with quantized IVF
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_int8_insert_and_query(db):
    """End-to-end check of int8-quantized IVF: insert, train, query.

    Twenty vectors on a line are stored and trained. A k=5 query at
    position 10 must return five hits whose best rowid lies in 8..12, and
    the ``t_ivf_vectors00`` table must hold one row per inserted vector.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=4))"
    )
    db.executemany(
        "INSERT INTO t(rowid, v) VALUES (?, ?)",
        [(rid, _f32([rid, 0, 0, 0])) for rid in range(20)],
    )
    db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")

    hits = db.execute(
        "SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
        [_f32([10.0, 0, 0, 0])],
    ).fetchall()
    assert len(hits) == 5
    # Quantization is lossy, so the best hit only has to be close to 10.
    best_rowid = hits[0][0]
    assert best_rowid in range(8, 13)

    # The full-precision vectors are kept in the _ivf_vectors table.
    (full_vector_rows,) = db.execute(
        "SELECT count(*) FROM t_ivf_vectors00"
    ).fetchone()
    assert full_vector_rows == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_binary_insert_and_query(db):
|
||||||
|
"""Binary quantized IVF: insert, train, query."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[8] indexed by ivf(nlist=2, quantizer=binary, oversample=4))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
# Vectors with varying sign patterns
|
||||||
|
v = [(i * 0.1 - 1.0) + j * 0.3 for j in range(8)]
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32(v)]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
[_f32([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
|
||||||
|
# Full vectors stored
|
||||||
|
fv_count = db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0]
|
||||||
|
assert fv_count == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_int8_cell_sizes_smaller(db):
|
||||||
|
"""Cell blobs should be smaller with int8 quantization."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(nlist=2, quantizer=int8, oversample=1))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(x) / 768 for x in range(768)])],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cell vectors blob: 10 vectors at int8 = 10 * 768 = 7680 bytes
|
||||||
|
# vs float32 = 10 * 768 * 4 = 30720 bytes
|
||||||
|
# But cells have capacity 64, so blob = 64 * 768 = 49152 (int8) vs 64*768*4=196608 (float32)
|
||||||
|
blob_size = db.execute(
|
||||||
|
"SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1"
|
||||||
|
).fetchone()[0]
|
||||||
|
# int8: 64 slots * 768 bytes = 49152
|
||||||
|
assert blob_size == 64 * 768
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_binary_cell_sizes_smaller(db):
|
||||||
|
"""Cell blobs should be much smaller with binary quantization."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[768] indexed by ivf(nlist=2, quantizer=binary, oversample=1))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([float(x) / 768 for x in range(768)])],
|
||||||
|
)
|
||||||
|
|
||||||
|
blob_size = db.execute(
|
||||||
|
"SELECT length(vectors) FROM t_ivf_cells00 LIMIT 1"
|
||||||
|
).fetchone()[0]
|
||||||
|
# binary: 64 slots * 768/8 bytes = 6144
|
||||||
|
assert blob_size == 64 * (768 // 8)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_int8_oversample_improves_recall(db):
|
||||||
|
"""Oversample re-ranking should improve recall over oversample=1."""
|
||||||
|
# Create two tables: one with oversample=1, one with oversample=10
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t1 USING vec0("
|
||||||
|
"v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=1))"
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t2 USING vec0("
|
||||||
|
"v float[4] indexed by ivf(nlist=4, quantizer=int8, oversample=10))"
|
||||||
|
)
|
||||||
|
for i in range(100):
|
||||||
|
v = _f32([i * 0.1, (i * 0.1) % 3, (i * 0.3) % 5, i * 0.01])
|
||||||
|
db.execute("INSERT INTO t1(rowid, v) VALUES (?, ?)", [i, v])
|
||||||
|
db.execute("INSERT INTO t2(rowid, v) VALUES (?, ?)", [i, v])
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t1(t1) VALUES ('compute-centroids')")
|
||||||
|
db.execute("INSERT INTO t2(t2) VALUES ('compute-centroids')")
|
||||||
|
db.execute("INSERT INTO t1(t1) VALUES ('nprobe=4')")
|
||||||
|
db.execute("INSERT INTO t2(t2) VALUES ('nprobe=4')")
|
||||||
|
|
||||||
|
query = _f32([5.0, 1.5, 2.5, 0.5])
|
||||||
|
r1 = db.execute("SELECT rowid FROM t1 WHERE v MATCH ? AND k=10", [query]).fetchall()
|
||||||
|
r2 = db.execute("SELECT rowid FROM t2 WHERE v MATCH ? AND k=10", [query]).fetchall()
|
||||||
|
|
||||||
|
# Both should return 10 results
|
||||||
|
assert len(r1) == 10
|
||||||
|
assert len(r2) == 10
|
||||||
|
# oversample=10 should have at least as good recall (same or better ordering)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_quantized_delete(db):
|
||||||
|
"""Delete should remove from both cells and _ivf_vectors."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
"v float[4] indexed by ivf(nlist=2, quantizer=int8, oversample=1))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 10
|
||||||
|
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
# _ivf_vectors should have 9 rows
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_vectors00").fetchone()[0] == 9
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_binary_rejects_non_multiple_of_8_dims(db):
|
||||||
|
"""Binary quantizer requires dimensions divisible by 8."""
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0("
|
||||||
|
" v float[12] indexed by ivf(quantizer=binary)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dimensions divisible by 8 should work
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t2 USING vec0("
|
||||||
|
" v float[16] indexed by ivf(quantizer=binary)"
|
||||||
|
")"
|
||||||
|
)
|
||||||
426
tests/test-ivf.py
Normal file
426
tests/test-ivf.py
Normal file
|
|
@ -0,0 +1,426 @@
|
||||||
|
import pytest
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import math
|
||||||
|
from helpers import _f32, exec
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def db():
|
||||||
|
db = sqlite3.connect(":memory:")
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension("dist/vec0")
|
||||||
|
db.enable_load_extension(False)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_total_vectors(db, table="t", col=0):
|
||||||
|
"""Count total vectors across all IVF cells."""
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d}"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_unassigned_count(db, table="t", col=0):
|
||||||
|
"""Count vectors in unassigned cells (centroid_id=-1)."""
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id = -1"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def ivf_assigned_count(db, table="t", col=0):
|
||||||
|
"""Count vectors in trained cells (centroid_id >= 0)."""
|
||||||
|
return db.execute(
|
||||||
|
f"SELECT COALESCE(SUM(n_vectors), 0) FROM {table}_ivf_cells{col:02d} WHERE centroid_id >= 0"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Parser tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_defaults(db):
|
||||||
|
"""ivf() with no args uses defaults."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf())"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
assert "t_ivf_cells00" in tables
|
||||||
|
assert "t_ivf_rowid_map00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_custom_params(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=64, nprobe=8))"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
assert "t_ivf_cells00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_with_distance_metric(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] distance_metric=cosine indexed by ivf(nlist=16))"
|
||||||
|
)
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY 1"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert "t_ivf_centroids00" in tables
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_error_unknown_key(db):
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(bogus=1))",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_create_error_nprobe_gt_nlist(db):
|
||||||
|
result = exec(
|
||||||
|
db,
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=10))",
|
||||||
|
)
|
||||||
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Shadow table tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_shadow_tables_created(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8))"
|
||||||
|
)
|
||||||
|
trained = db.execute(
|
||||||
|
"SELECT value FROM t_info WHERE key = 'ivf_trained_0'"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert str(trained) == "0"
|
||||||
|
|
||||||
|
# No cells yet (created lazily on first insert)
|
||||||
|
count = db.execute(
|
||||||
|
"SELECT count(*) FROM t_ivf_cells00"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_drop_cleans_up(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
db.execute("DROP TABLE t")
|
||||||
|
tables = [
|
||||||
|
r[0]
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
assert not any("ivf" in t for t in tables)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Insert tests (flat mode)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_insert_flat_mode(db):
|
||||||
|
"""Before training, vectors go to unassigned cell."""
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])])
|
||||||
|
|
||||||
|
assert ivf_unassigned_count(db) == 2
|
||||||
|
assert ivf_assigned_count(db) == 0
|
||||||
|
|
||||||
|
# rowid_map should have 2 entries
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_delete_flat_mode(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([5, 6, 7, 8])])
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 1")
|
||||||
|
|
||||||
|
assert ivf_unassigned_count(db) == 1
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# KNN flat mode tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_knn_flat_mode(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (1, ?)", [_f32([1, 0, 0, 0])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (2, ?)", [_f32([2, 0, 0, 0])])
|
||||||
|
db.execute("INSERT INTO t(rowid, v) VALUES (3, ?)", [_f32([9, 0, 0, 0])])
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 2",
|
||||||
|
[_f32([1.5, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 2
|
||||||
|
rowids = {r[0] for r in rows}
|
||||||
|
assert rowids == {1, 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_knn_flat_empty(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
[_f32([1, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# compute-centroids tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_centroids(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4))"
|
||||||
|
)
|
||||||
|
for i in range(40):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)",
|
||||||
|
[i, _f32([i % 10, i // 10, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ivf_unassigned_count(db) == 40
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
# After training: unassigned cell should be gone (or empty), vectors in trained cells
|
||||||
|
assert ivf_unassigned_count(db) == 0
|
||||||
|
assert ivf_assigned_count(db) == 40
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 4
|
||||||
|
trained = db.execute(
|
||||||
|
"SELECT value FROM t_info WHERE key='ivf_trained_0'"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert str(trained) == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_centroids_recompute(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
assert ivf_assigned_count(db) == 20
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Insert after training (assigned mode)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_insert_after_training(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (100, ?)", [_f32([5, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should be in a trained cell, not unassigned
|
||||||
|
row = db.execute(
|
||||||
|
"SELECT m.cell_id, c.centroid_id FROM t_ivf_rowid_map00 m "
|
||||||
|
"JOIN t_ivf_cells00 c ON c.rowid = m.cell_id "
|
||||||
|
"WHERE m.rowid = 100"
|
||||||
|
).fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row[1] >= 0 # centroid_id >= 0 means trained cell
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# KNN after training (IVF probe mode)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_knn_after_training(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=4, nprobe=4))"
|
||||||
|
)
|
||||||
|
for i in range(100):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid, distance FROM t WHERE v MATCH ? AND k = 5",
|
||||||
|
[_f32([50.0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
assert rows[0][0] == 50
|
||||||
|
assert rows[0][1] < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_knn_k_larger_than_n(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2, nprobe=2))"
|
||||||
|
)
|
||||||
|
for i in range(5):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 100",
|
||||||
|
[_f32([0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Manual centroid import (set-centroid, assign-vectors)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_centroid_and_assign(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=0))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(t, v) VALUES ('set-centroid:0', ?)",
|
||||||
|
[_f32([5, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(t, v) VALUES ('set-centroid:1', ?)",
|
||||||
|
[_f32([15, 0, 0, 0])],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('assign-vectors')")
|
||||||
|
|
||||||
|
assert ivf_unassigned_count(db) == 0
|
||||||
|
assert ivf_assigned_count(db) == 20
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# clear-centroids
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_centroids(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 2
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('clear-centroids')")
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_centroids00").fetchone()[0] == 0
|
||||||
|
assert ivf_unassigned_count(db) == 20
|
||||||
|
trained = db.execute(
|
||||||
|
"SELECT value FROM t_info WHERE key='ivf_trained_0'"
|
||||||
|
).fetchone()[0]
|
||||||
|
assert str(trained) == "0"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Delete after training
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_delete_after_training(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=2))"
|
||||||
|
)
|
||||||
|
for i in range(10):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
assert ivf_assigned_count(db) == 10
|
||||||
|
|
||||||
|
db.execute("DELETE FROM t WHERE rowid = 5")
|
||||||
|
assert ivf_assigned_count(db) == 9
|
||||||
|
assert db.execute("SELECT count(*) FROM t_ivf_rowid_map00").fetchone()[0] == 9
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Recall test
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_ivf_recall_nprobe_equals_nlist(db):
|
||||||
|
db.execute(
|
||||||
|
"CREATE VIRTUAL TABLE t USING vec0(v float[4] indexed by ivf(nlist=8, nprobe=8))"
|
||||||
|
)
|
||||||
|
for i in range(100):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO t(rowid, v) VALUES (?, ?)", [i, _f32([i, 0, 0, 0])]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.execute("INSERT INTO t(t) VALUES ('compute-centroids')")
|
||||||
|
|
||||||
|
rows = db.execute(
|
||||||
|
"SELECT rowid FROM t WHERE v MATCH ? AND k = 10",
|
||||||
|
[_f32([50.0, 0, 0, 0])],
|
||||||
|
).fetchall()
|
||||||
|
rowids = {r[0] for r in rows}
|
||||||
|
|
||||||
|
# 45 and 55 are equidistant from 50, so either may appear in top 10
|
||||||
|
assert 50 in rowids
|
||||||
|
assert len(rowids) == 10
|
||||||
|
assert all(45 <= r <= 55 for r in rowids)
|
||||||
138
tests/test-legacy-compat.py
Normal file
138
tests/test-legacy-compat.py
Normal file
|
|
@ -0,0 +1,138 @@
|
||||||
|
"""Backwards compatibility tests: current sqlite-vec reading legacy databases.
|
||||||
|
|
||||||
|
The fixture file tests/fixtures/legacy-v0.1.6.db was generated by
|
||||||
|
tests/generate_legacy_db.py using sqlite-vec v0.1.6. These tests verify
|
||||||
|
that the current version can fully read, query, insert into, and delete
|
||||||
|
from tables created by older versions.
|
||||||
|
"""
|
||||||
|
import sqlite3
|
||||||
|
import struct
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures", "legacy-v0.1.6.db")
|
||||||
|
|
||||||
|
|
||||||
|
def _f32(vals):
|
||||||
|
return struct.pack(f"{len(vals)}f", *vals)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def legacy_db(tmp_path):
|
||||||
|
"""Copy the legacy fixture to a temp dir so tests can modify it."""
|
||||||
|
if not os.path.exists(FIXTURE_PATH):
|
||||||
|
pytest.skip("Legacy fixture not found — run: uv run --script tests/generate_legacy_db.py")
|
||||||
|
db_path = str(tmp_path / "legacy.db")
|
||||||
|
shutil.copy2(FIXTURE_PATH, db_path)
|
||||||
|
db = sqlite3.connect(db_path)
|
||||||
|
db.row_factory = sqlite3.Row
|
||||||
|
db.enable_load_extension(True)
|
||||||
|
db.load_extension("dist/vec0")
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_select_count(legacy_db):
|
||||||
|
"""Basic SELECT count should return all rows."""
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_point_query(legacy_db):
|
||||||
|
"""Point query by rowid should return correct vector."""
|
||||||
|
row = legacy_db.execute(
|
||||||
|
"SELECT rowid, emb FROM legacy_vectors WHERE rowid = 1"
|
||||||
|
).fetchone()
|
||||||
|
assert row["rowid"] == 1
|
||||||
|
vec = struct.unpack("4f", row["emb"])
|
||||||
|
assert vec[0] == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_knn(legacy_db):
|
||||||
|
"""KNN query on legacy table should return correct results."""
|
||||||
|
query = _f32([1.0, 0.0, 0.0, 0.0])
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid, distance FROM legacy_vectors "
|
||||||
|
"WHERE emb MATCH ? AND k = 5",
|
||||||
|
[query],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 5
|
||||||
|
assert rows[0]["rowid"] == 1
|
||||||
|
assert rows[0]["distance"] == pytest.approx(0.0)
|
||||||
|
for i in range(len(rows) - 1):
|
||||||
|
assert rows[i]["distance"] <= rows[i + 1]["distance"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_insert(legacy_db):
|
||||||
|
"""INSERT into legacy table should work."""
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO legacy_vectors(rowid, emb) VALUES (100, ?)",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
)
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 51
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 1",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert rows[0]["rowid"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_delete(legacy_db):
|
||||||
|
"""DELETE from legacy table should work."""
|
||||||
|
legacy_db.execute("DELETE FROM legacy_vectors WHERE rowid = 1")
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM legacy_vectors").fetchone()[0]
|
||||||
|
assert count == 49
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors WHERE emb MATCH ? AND k = 5",
|
||||||
|
[_f32([1.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert 1 not in [r["rowid"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_fullscan(legacy_db):
|
||||||
|
"""Full scan should work."""
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid FROM legacy_vectors ORDER BY rowid LIMIT 5"
|
||||||
|
).fetchall()
|
||||||
|
assert [r["rowid"] for r in rows] == [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_name_conflict_table(legacy_db):
|
||||||
|
"""Legacy table where column name == table name should work.
|
||||||
|
|
||||||
|
The v0.1.6 DB has: CREATE VIRTUAL TABLE emb USING vec0(emb float[4])
|
||||||
|
Current code should NOT add the command column for this table
|
||||||
|
(detected via _info version check), avoiding the name conflict.
|
||||||
|
"""
|
||||||
|
count = legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0]
|
||||||
|
assert count == 10
|
||||||
|
|
||||||
|
rows = legacy_db.execute(
|
||||||
|
"SELECT rowid, distance FROM emb WHERE emb MATCH ? AND k = 3",
|
||||||
|
[_f32([1.0, 0.0, 0.0, 0.0])],
|
||||||
|
).fetchall()
|
||||||
|
assert len(rows) == 3
|
||||||
|
assert rows[0]["rowid"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_name_conflict_insert_delete(legacy_db):
|
||||||
|
"""INSERT and DELETE on legacy name-conflict table."""
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO emb(rowid, emb) VALUES (100, ?)",
|
||||||
|
[_f32([100.0, 0.0, 0.0, 0.0])],
|
||||||
|
)
|
||||||
|
assert legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] == 11
|
||||||
|
|
||||||
|
legacy_db.execute("DELETE FROM emb WHERE rowid = 5")
|
||||||
|
assert legacy_db.execute("SELECT count(*) FROM emb").fetchone()[0] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_no_command_column(legacy_db):
|
||||||
|
"""Legacy tables should NOT have the command column."""
|
||||||
|
with pytest.raises(sqlite3.OperationalError):
|
||||||
|
legacy_db.execute(
|
||||||
|
"INSERT INTO legacy_vectors(legacy_vectors) VALUES ('some_command')"
|
||||||
|
)
|
||||||
|
|
@ -119,151 +119,9 @@ FUNCTIONS = [
|
||||||
MODULES = [
|
MODULES = [
|
||||||
"vec0",
|
"vec0",
|
||||||
"vec_each",
|
"vec_each",
|
||||||
# "vec_static_blob_entries",
|
|
||||||
# "vec_static_blobs",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def register_numpy(db, name: str, array):
|
|
||||||
ptr = array.__array_interface__["data"][0]
|
|
||||||
nvectors, dimensions = array.__array_interface__["shape"]
|
|
||||||
element_type = array.__array_interface__["typestr"]
|
|
||||||
|
|
||||||
assert element_type == "<f4"
|
|
||||||
|
|
||||||
name_escaped = db.execute("select printf('%w', ?)", [name]).fetchone()[0]
|
|
||||||
|
|
||||||
db.execute(
|
|
||||||
"""
|
|
||||||
insert into temp.vec_static_blobs(name, data)
|
|
||||||
select ?, vec_static_blob_from_raw(?, ?, ?, ?)
|
|
||||||
""",
|
|
||||||
[name, ptr, element_type, dimensions, nvectors],
|
|
||||||
)
|
|
||||||
|
|
||||||
db.execute(
|
|
||||||
f'create virtual table "{name_escaped}" using vec_static_blob_entries({name_escaped})'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_vec_static_blob_entries():
|
|
||||||
db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_static_blobs_init")
|
|
||||||
|
|
||||||
x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32)
|
|
||||||
y = np.array([[0.2, 0.3], [0.9, 0.8], [0.6, 0.5]], dtype=np.float32)
|
|
||||||
z = np.array(
|
|
||||||
[
|
|
||||||
[0.1, 0.1, 0.1, 0.1],
|
|
||||||
[0.2, 0.2, 0.2, 0.2],
|
|
||||||
[0.3, 0.3, 0.3, 0.3],
|
|
||||||
[0.4, 0.4, 0.4, 0.4],
|
|
||||||
[0.5, 0.5, 0.5, 0.5],
|
|
||||||
],
|
|
||||||
dtype=np.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
register_numpy(db, "x", x)
|
|
||||||
register_numpy(db, "y", y)
|
|
||||||
register_numpy(db, "z", z)
|
|
||||||
assert execute_all(
|
|
||||||
db, "select *, dimensions, count from temp.vec_static_blobs;"
|
|
||||||
) == [
|
|
||||||
{
|
|
||||||
"count": 2,
|
|
||||||
"data": None,
|
|
||||||
"dimensions": 4,
|
|
||||||
"name": "x",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"count": 3,
|
|
||||||
"data": None,
|
|
||||||
"dimensions": 2,
|
|
||||||
"name": "y",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"count": 5,
|
|
||||||
"data": None,
|
|
||||||
"dimensions": 4,
|
|
||||||
"name": "z",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
assert execute_all(db, "select vec_to_json(vector) from x;") == [
|
|
||||||
{
|
|
||||||
"vec_to_json(vector)": "[0.100000,0.200000,0.300000,0.400000]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"vec_to_json(vector)": "[0.900000,0.800000,0.700000,0.600000]",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert execute_all(db, "select (vector) from y limit 2;") == [
|
|
||||||
{
|
|
||||||
"vector": b"\xcd\xccL>\x9a\x99\x99>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"vector": b"fff?\xcd\xccL?",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert execute_all(db, "select rowid, (vector) from z") == [
|
|
||||||
{
|
|
||||||
"rowid": 0,
|
|
||||||
"vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 1,
|
|
||||||
"vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 2,
|
|
||||||
"vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 3,
|
|
||||||
"vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 4,
|
|
||||||
"vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert execute_all(
|
|
||||||
db,
|
|
||||||
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
|
|
||||||
[np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)],
|
|
||||||
) == [
|
|
||||||
{
|
|
||||||
"rowid": 2,
|
|
||||||
"v": "[0.300000,0.300000,0.300000,0.300000]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 3,
|
|
||||||
"v": "[0.400000,0.400000,0.400000,0.400000]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 1,
|
|
||||||
"v": "[0.200000,0.200000,0.200000,0.200000]",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert execute_all(
|
|
||||||
db,
|
|
||||||
"select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;",
|
|
||||||
[np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)],
|
|
||||||
) == [
|
|
||||||
{
|
|
||||||
"rowid": 4,
|
|
||||||
"v": "[0.500000,0.500000,0.500000,0.500000]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 3,
|
|
||||||
"v": "[0.400000,0.400000,0.400000,0.400000]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 2,
|
|
||||||
"v": "[0.300000,0.300000,0.300000,0.300000]",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_limits():
|
def test_limits():
|
||||||
db = connect(EXT_PATH)
|
db = connect(EXT_PATH)
|
||||||
with _raises(
|
with _raises(
|
||||||
|
|
@ -507,6 +365,34 @@ def test_vec_distance_l1():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vec_reject_nan_inf():
|
||||||
|
"""NaN and Inf in float32 vectors should be rejected."""
|
||||||
|
import struct, math
|
||||||
|
|
||||||
|
# NaN via blob
|
||||||
|
nan_blob = struct.pack("4f", 1.0, float("nan"), 3.0, 4.0)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="NaN"):
|
||||||
|
db.execute("SELECT vec_length(?)", [nan_blob])
|
||||||
|
|
||||||
|
# Inf via blob
|
||||||
|
inf_blob = struct.pack("4f", 1.0, float("inf"), 3.0, 4.0)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="Inf"):
|
||||||
|
db.execute("SELECT vec_length(?)", [inf_blob])
|
||||||
|
|
||||||
|
# -Inf via blob
|
||||||
|
ninf_blob = struct.pack("4f", 1.0, float("-inf"), 3.0, 4.0)
|
||||||
|
with pytest.raises(sqlite3.OperationalError, match="Inf"):
|
||||||
|
db.execute("SELECT vec_length(?)", [ninf_blob])
|
||||||
|
|
||||||
|
# NaN via JSON
|
||||||
|
# Note: JSON doesn't have NaN literal, but strtod may parse "NaN"
|
||||||
|
# This tests the blob path which is the primary input method
|
||||||
|
|
||||||
|
# Valid vectors still work
|
||||||
|
ok_blob = struct.pack("4f", 1.0, 2.0, 3.0, 4.0)
|
||||||
|
assert db.execute("SELECT vec_length(?)", [ok_blob]).fetchone()[0] == 4
|
||||||
|
|
||||||
|
|
||||||
def test_vec_distance_l2():
|
def test_vec_distance_l2():
|
||||||
vec_distance_l2 = lambda *args, a="?", b="?": db.execute(
|
vec_distance_l2 = lambda *args, a="?", b="?": db.execute(
|
||||||
f"select vec_distance_l2({a}, {b})", args
|
f"select vec_distance_l2({a}, {b})", args
|
||||||
|
|
@ -523,11 +409,17 @@ def test_vec_distance_l2():
|
||||||
|
|
||||||
x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform)
|
x = vec_distance_l2(a_sql_t, b_sql_t, a=transform, b=transform)
|
||||||
y = npy_l2(np.array(a), np.array(b))
|
y = npy_l2(np.array(a), np.array(b))
|
||||||
assert isclose(x, y, abs_tol=1e-6)
|
assert isclose(x, y, rel_tol=1e-5, abs_tol=1e-6)
|
||||||
|
|
||||||
check([1.2, 0.1], [0.4, -0.4])
|
check([1.2, 0.1], [0.4, -0.4])
|
||||||
check([-1.2, -0.1], [-0.4, 0.4])
|
check([-1.2, -0.1], [-0.4, 0.4])
|
||||||
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
|
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
|
||||||
|
# Extreme int8 values: diff=255, squared=65025 which overflows i16
|
||||||
|
# This tests the NEON widening multiply fix (slight float rounding expected)
|
||||||
|
check([-128] * 8, [127] * 8, dtype=np.int8)
|
||||||
|
check([-128] * 16, [127] * 16, dtype=np.int8)
|
||||||
|
check([-128, 127, -128, 127, -128, 127, -128, 127],
|
||||||
|
[127, -128, 127, -128, 127, -128, 127, -128], dtype=np.int8)
|
||||||
|
|
||||||
|
|
||||||
def test_vec_length():
|
def test_vec_length():
|
||||||
|
|
@ -1618,231 +1510,6 @@ def test_vec_each():
|
||||||
vec_each_f32(None)
|
vec_each_f32(None)
|
||||||
|
|
||||||
|
|
||||||
import io
|
|
||||||
|
|
||||||
|
|
||||||
def to_npy(arr):
|
|
||||||
buf = io.BytesIO()
|
|
||||||
np.save(buf, arr)
|
|
||||||
buf.seek(0)
|
|
||||||
return buf.read()
|
|
||||||
|
|
||||||
|
|
||||||
def test_vec_npy_each():
|
|
||||||
db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init")
|
|
||||||
vec_npy_each = lambda *args: execute_all(
|
|
||||||
db, "select rowid, * from vec_npy_each(?)", args
|
|
||||||
)
|
|
||||||
assert vec_npy_each(to_npy(np.array([1.1, 2.2, 3.3], dtype=np.float32))) == [
|
|
||||||
{
|
|
||||||
"rowid": 0,
|
|
||||||
"vector": _f32([1.1, 2.2, 3.3]),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert vec_npy_each(to_npy(np.array([[1.1, 2.2, 3.3]], dtype=np.float32))) == [
|
|
||||||
{
|
|
||||||
"rowid": 0,
|
|
||||||
"vector": _f32([1.1, 2.2, 3.3]),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert vec_npy_each(
|
|
||||||
to_npy(np.array([[1.1, 2.2, 3.3], [9.9, 8.8, 7.7]], dtype=np.float32))
|
|
||||||
) == [
|
|
||||||
{
|
|
||||||
"rowid": 0,
|
|
||||||
"vector": _f32([1.1, 2.2, 3.3]),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 1,
|
|
||||||
"vector": _f32([9.9, 8.8, 7.7]),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
assert vec_npy_each(to_npy(np.array([], dtype=np.float32))) == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_vec_npy_each_errors():
|
|
||||||
db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init")
|
|
||||||
vec_npy_each = lambda *args: execute_all(
|
|
||||||
db, "select rowid, * from vec_npy_each(?)", args
|
|
||||||
)
|
|
||||||
|
|
||||||
full = b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
|
|
||||||
# EVIDENCE-OF: V03312_20150 numpy validation too short
|
|
||||||
with _raises("numpy array too short"):
|
|
||||||
vec_npy_each(b"")
|
|
||||||
# EVIDENCE-OF: V11954_28792 numpy validate magic
|
|
||||||
with _raises("numpy array does not contain the 'magic' header"):
|
|
||||||
vec_npy_each(b"\x93NUMPX\x01\x00v\x00")
|
|
||||||
|
|
||||||
with _raises("numpy array header length is invalid"):
|
|
||||||
vec_npy_each(b"\x93NUMPY\x01\x00v\x00")
|
|
||||||
|
|
||||||
with _raises("numpy header did not start with '{'"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00c'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("expected key in numpy header"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{ \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("expected a string as key in numpy header"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{False: '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("expected a ':' after key in numpy header"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr' \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
with _raises("expected a ':' after key in numpy header"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr' False \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("expected a string value after 'descr' key"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr': \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("Only '<f4' values are supported in sqlite-vec numpy functions"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr': '=f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises(
|
|
||||||
"Only fortran_order = False is supported in sqlite-vec numpy functions"
|
|
||||||
):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': True, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises(
|
|
||||||
"Error parsing numpy array: Expected left parenthesis '(' after shape key"
|
|
||||||
):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'shape': 2, 'descr': '<f4', 'fortran_order': False, } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises(
|
|
||||||
"Error parsing numpy array: Expected an initial number in shape value"
|
|
||||||
):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'shape': (, 'descr': '<f4', 'fortran_order': False, } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("Error parsing numpy array: Expected comma after first shape value"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'shape': (2), 'descr': '<f4', 'fortran_order': False, } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises(
|
|
||||||
"Error parsing numpy array: unexpected header EOF while parsing shape"
|
|
||||||
):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'shape': (2, \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("Error parsing numpy array: unknown type in shape value"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'shape': (2, 'nope' \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises(
|
|
||||||
"Error parsing numpy array: expected right parenthesis after shape value"
|
|
||||||
):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'shape': (2,4 ( \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("Error parsing numpy array: unknown key in numpy header"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'no': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("Error parsing numpy array: unknown extra token after value"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr': '<f4' 'asdf', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("numpy array error: Expected a data size of 32, found 31"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3"
|
|
||||||
)
|
|
||||||
|
|
||||||
# with _raises("XXX"):
|
|
||||||
# vec_npy_each(b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@")
|
|
||||||
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
|
|
||||||
def test_vec_npy_each_errors_files():
|
|
||||||
db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init")
|
|
||||||
|
|
||||||
def vec_npy_each(data):
|
|
||||||
with tempfile.NamedTemporaryFile(delete_on_close=False) as f:
|
|
||||||
f.write(data)
|
|
||||||
f.close()
|
|
||||||
try:
|
|
||||||
return execute_all(
|
|
||||||
db, "select rowid, * from vec_npy_each(vec_npy_file(?))", [f.name]
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
with _raises("Could not open numpy file"):
|
|
||||||
db.execute('select * from vec_npy_each(vec_npy_file("not exist"))')
|
|
||||||
|
|
||||||
with _raises("numpy array file too short"):
|
|
||||||
vec_npy_each(b"\x93NUMPY\x01\x00v")
|
|
||||||
|
|
||||||
with _raises("numpy array file does not contain the 'magic' header"):
|
|
||||||
vec_npy_each(b"\x93XUMPY\x01\x00v\x00")
|
|
||||||
|
|
||||||
with _raises("numpy array file header length is invalid"):
|
|
||||||
vec_npy_each(b"\x93NUMPY\x01\x00v\x00")
|
|
||||||
|
|
||||||
with _raises(
|
|
||||||
"Error parsing numpy array: Only fortran_order = False is supported in sqlite-vec numpy functions"
|
|
||||||
):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': True, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3@"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _raises("numpy array file error: Expected a data size of 32, found 31"):
|
|
||||||
vec_npy_each(
|
|
||||||
b"\x93NUMPY\x01\x00v\x00{'descr': '<f4', 'fortran_order': False, 'shape': (2, 4), } \n\xcd\xcc\x8c?\xcd\xcc\x0c@33S@\xcd\xcc\x8c@ff\x1eA\xcd\xcc\x0cAff\xf6@33\xd3"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert vec_npy_each(to_npy(np.array([1.1, 2.2, 3.3], dtype=np.float32))) == [
|
|
||||||
{
|
|
||||||
"rowid": 0,
|
|
||||||
"vector": _f32([1.1, 2.2, 3.3]),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert vec_npy_each(
|
|
||||||
to_npy(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32))
|
|
||||||
) == [
|
|
||||||
{
|
|
||||||
"rowid": 0,
|
|
||||||
"vector": _f32([1.1, 2.2, 3.3]),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"rowid": 1,
|
|
||||||
"vector": _f32([4.4, 5.5, 6.6]),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert vec_npy_each(to_npy(np.array([], dtype=np.float32))) == []
|
|
||||||
x1025 = vec_npy_each(to_npy(np.array([[0.1, 0.2, 0.3]] * 1025, dtype=np.float32)))
|
|
||||||
assert len(x1025) == 1025
|
|
||||||
|
|
||||||
# np.array([[.1, .2, 3]] * 99, dtype=np.float32).shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_vec0_constructor():
|
def test_vec0_constructor():
|
||||||
vec_constructor_error_prefix = "vec0 constructor error: {}"
|
vec_constructor_error_prefix = "vec0 constructor error: {}"
|
||||||
vec_col_error_prefix = "vec0 constructor error: could not parse vector column '{}'"
|
vec_col_error_prefix = "vec0 constructor error: could not parse vector column '{}'"
|
||||||
|
|
@ -1923,6 +1590,54 @@ def test_vec0_constructor():
|
||||||
db.execute("create virtual table v using vec0(4)")
|
db.execute("create virtual table v using vec0(4)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_vec0_indexed_by_flat():
|
||||||
|
db.execute("drop table if exists t_ibf")
|
||||||
|
db.execute("drop table if exists t_ibf2")
|
||||||
|
db.execute("drop table if exists t_ibf3")
|
||||||
|
db.execute("drop table if exists t_ibf4")
|
||||||
|
|
||||||
|
# indexed by flat() should succeed and behave identically to no index clause
|
||||||
|
db.execute("create virtual table t_ibf using vec0(emb float[4] indexed by flat())")
|
||||||
|
db.execute(
|
||||||
|
"insert into t_ibf(rowid, emb) values (1, X'00000000000000000000000000000000')"
|
||||||
|
)
|
||||||
|
rows = db.execute("select rowid from t_ibf where emb match X'00000000000000000000000000000000' and k = 1").fetchall()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0][0] == 1
|
||||||
|
db.execute("drop table t_ibf")
|
||||||
|
|
||||||
|
# indexed by flat() with distance_metric
|
||||||
|
db.execute(
|
||||||
|
"create virtual table t_ibf2 using vec0(emb float[4] distance_metric=cosine indexed by flat())"
|
||||||
|
)
|
||||||
|
db.execute("drop table t_ibf2")
|
||||||
|
|
||||||
|
# indexed by flat() on int8
|
||||||
|
db.execute("create virtual table t_ibf3 using vec0(emb int8[4] indexed by flat())")
|
||||||
|
db.execute("drop table t_ibf3")
|
||||||
|
|
||||||
|
# indexed by flat() on bit
|
||||||
|
db.execute("create virtual table t_ibf4 using vec0(emb bit[8] indexed by flat())")
|
||||||
|
db.execute("drop table t_ibf4")
|
||||||
|
|
||||||
|
# Error: unknown index type
|
||||||
|
with _raises(
|
||||||
|
"vec0 constructor error: could not parse vector column 'emb float[4] indexed by unknown()'",
|
||||||
|
sqlite3.DatabaseError,
|
||||||
|
):
|
||||||
|
db.execute("create virtual table v using vec0(emb float[4] indexed by unknown())")
|
||||||
|
|
||||||
|
# Error: indexed by (missing type)
|
||||||
|
with _raises(
|
||||||
|
"vec0 constructor error: could not parse vector column 'emb float[4] indexed by'",
|
||||||
|
sqlite3.DatabaseError,
|
||||||
|
):
|
||||||
|
db.execute("create virtual table v using vec0(emb float[4] indexed by)")
|
||||||
|
|
||||||
|
if db.in_transaction:
|
||||||
|
db.rollback()
|
||||||
|
|
||||||
|
|
||||||
def test_vec0_create_errors():
|
def test_vec0_create_errors():
|
||||||
# EVIDENCE-OF: V17740_01811 vec0 create _chunks error handling
|
# EVIDENCE-OF: V17740_01811 vec0 create _chunks error handling
|
||||||
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_CREATE_TABLE, "t1_chunks"))
|
db.set_authorizer(authorizer_deny_on(sqlite3.SQLITE_CREATE_TABLE, "t1_chunks"))
|
||||||
|
|
|
||||||
|
|
@ -265,6 +265,35 @@ def test_deletes(db, snapshot):
|
||||||
assert vec0_shadow_table_contents(db, "v") == snapshot()
|
assert vec0_shadow_table_contents(db, "v") == snapshot()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_by_metadata_with_long_text(db):
|
||||||
|
"""Regression for https://github.com/asg017/sqlite-vec/issues/274.
|
||||||
|
|
||||||
|
ClearMetadata left rc=SQLITE_DONE after the long-text DELETE, which
|
||||||
|
propagated as an error and silently aborted the DELETE scan.
|
||||||
|
"""
|
||||||
|
db.execute(
|
||||||
|
"create virtual table v using vec0("
|
||||||
|
" tag text, embedding float[4], chunk_size=8"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
for i in range(6):
|
||||||
|
db.execute(
|
||||||
|
"insert into v(tag, embedding) values (?, zeroblob(16))",
|
||||||
|
[f"long_text_value_{i}"],
|
||||||
|
)
|
||||||
|
for i in range(4):
|
||||||
|
db.execute(
|
||||||
|
"insert into v(tag, embedding) values (?, zeroblob(16))",
|
||||||
|
[f"long_text_value_0"],
|
||||||
|
)
|
||||||
|
assert db.execute("select count(*) from v").fetchone()[0] == 10
|
||||||
|
|
||||||
|
# DELETE by metadata WHERE — the pattern from the issue
|
||||||
|
db.execute("delete from v where tag = 'long_text_value_0'")
|
||||||
|
assert db.execute("select count(*) from v where tag = 'long_text_value_0'").fetchone()[0] == 0
|
||||||
|
assert db.execute("select count(*) from v").fetchone()[0] == 5
|
||||||
|
|
||||||
|
|
||||||
def test_knn(db, snapshot):
|
def test_knn(db, snapshot):
|
||||||
db.execute(
|
db.execute(
|
||||||
"create virtual table v using vec0(vector float[1], name text, chunk_size=8)"
|
"create virtual table v using vec0(vector float[1], name text, chunk_size=8)"
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue