From bf2455f2bacb0aef976ee03e03e265f270d5afc9 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Sun, 29 Mar 2026 19:44:44 -0700 Subject: [PATCH 1/3] Add ANN search support for vec0 virtual table Add approximate nearest neighbor infrastructure to vec0: shared distance dispatch (vec0_distance_full), flat index type with parser, NEON-optimized cosine/Hamming for float32/int8, amalgamation script, and benchmark suite (benchmarks-ann/) with ground-truth generation and profiling tools. Remove unused vec_npy_each/vec_static_blobs code, fix missing stdint.h include. --- Makefile | 14 +- TODO.md | 73 + benchmarks-ann/.gitignore | 2 + benchmarks-ann/Makefile | 61 + benchmarks-ann/README.md | 81 + benchmarks-ann/bench.py | 488 ++++++ benchmarks-ann/ground_truth.py | 168 ++ benchmarks-ann/profile.py | 440 ++++++ benchmarks-ann/schema.sql | 35 + benchmarks-ann/seed/.gitignore | 2 + benchmarks-ann/seed/Makefile | 24 + benchmarks-ann/seed/build_base_db.py | 121 ++ benchmarks/exhaustive-memory/bench.py | 57 +- benchmarks/profiling/build-from-npy.sql | 7 - benchmarks/self-params/build.py | 14 +- bindings/go/ncruces/go-sqlite3.patch | 1 - bindings/python/extra_init.py | 31 - scripts/amalgamate.py | 119 ++ site/api-reference.md | 59 - site/compiling.md | 1 - sqlite-vec.c | 1863 +++++------------------ tests/correctness/test-correctness.py | 17 +- tests/fuzz/numpy.c | 37 - tests/sqlite-vec-internal.h | 6 + tests/test-loadable.py | 415 +---- tests/test-unit.c | 101 ++ tmp-static.py | 56 - 27 files changed, 2177 insertions(+), 2116 deletions(-) create mode 100644 TODO.md create mode 100644 benchmarks-ann/.gitignore create mode 100644 benchmarks-ann/Makefile create mode 100644 benchmarks-ann/README.md create mode 100644 benchmarks-ann/bench.py create mode 100644 benchmarks-ann/ground_truth.py create mode 100644 benchmarks-ann/profile.py create mode 100644 benchmarks-ann/schema.sql create mode 100644 benchmarks-ann/seed/.gitignore create mode 100644 benchmarks-ann/seed/Makefile create mode 100644 benchmarks-ann/seed/build_base_db.py create mode 100644 scripts/amalgamate.py delete mode 100644 tests/fuzz/numpy.c delete mode 100644 tmp-static.py diff --git a/Makefile b/Makefile index 1ebdbed..051590e 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,11 @@ ifndef OMIT_SIMD ifeq ($(shell uname -sm),Darwin arm64) CFLAGS += -mcpu=apple-m1 -DSQLITE_VEC_ENABLE_NEON endif + ifeq ($(shell uname -s),Linux) + ifneq ($(filter avx,$(shell grep -o 'avx[^ ]*' /proc/cpuinfo 2>/dev/null | head -1)),) + CFLAGS += -mavx -DSQLITE_VEC_ENABLE_AVX + endif + endif endif ifdef USE_BREW_SQLITE @@ -155,6 +160,13 @@ clean: rm -rf dist +TARGET_AMALGAMATION=$(prefix)/sqlite-vec.c + +amalgamation: $(TARGET_AMALGAMATION) + +$(TARGET_AMALGAMATION): sqlite-vec.c $(wildcard sqlite-vec-*.c) scripts/amalgamate.py $(prefix) + python3 scripts/amalgamate.py sqlite-vec.c > $@ + FORMAT_FILES=sqlite-vec.h sqlite-vec.c format: $(FORMAT_FILES) clang-format -i $(FORMAT_FILES) @@ -174,7 +186,7 @@ evidence-of: test: sqlite3 :memory: '.read test.sql' -.PHONY: version loadable static test clean gh-release evidence-of install uninstall +.PHONY: version loadable static test clean gh-release evidence-of install uninstall amalgamation publish-release: ./scripts/publish-release.sh diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..4c3cc19 --- /dev/null +++ b/TODO.md @@ -0,0 +1,73 @@ +# TODO: `ann` base branch + consolidated benchmarks + +## 1. Create `ann` branch with shared code + +### 1.1 Branch setup +- [x] `git checkout -B ann origin/main` +- [x] Cherry-pick `624f998` (vec0_distance_full shared distance dispatch) +- [x] Cherry-pick stdint.h fix for test header +- [ ] Pull NEON cosine optimization from ivf-yolo3 into shared code + - Currently only in ivf branch but is general-purpose (benefits all distance calcs) + - Lives in `distance_cosine_float()` — ~57 lines of ARM NEON vectorized cosine + +### 1.2 Benchmark infrastructure (`benchmarks-ann/`) +- [x] Seed data pipeline (`seed/Makefile`, `seed/build_base_db.py`) +- [x] Ground truth generator (`ground_truth.py`) +- [x] Results schema (`schema.sql`) +- [x] Benchmark runner with `INDEX_REGISTRY` extension point (`bench.py`) + - Baseline configs (float, int8-rescore, bit-rescore) implemented + - Index branches register their types via `INDEX_REGISTRY` dict +- [x] Makefile with baseline targets +- [x] README + +### 1.3 Rebase feature branches onto `ann` +- [x] Rebase `diskann-yolo2` onto `ann` (1 commit: DiskANN implementation) +- [x] Rebase `ivf-yolo3` onto `ann` (1 commit: IVF implementation) +- [x] Rebase `annoy-yolo2` onto `ann` (2 commits: Annoy implementation + schema fix) +- [x] Verify each branch has only its index-specific commits remaining +- [ ] Force-push all 4 branches to origin + +--- + +## 2. Per-branch: register index type in benchmarks + +Each index branch should add to `benchmarks-ann/` when rebased onto `ann`: + +### 2.1 Register in `bench.py` + +Add an `INDEX_REGISTRY` entry. Each entry provides: +- `defaults` — default param values +- `create_table_sql(params)` — CREATE VIRTUAL TABLE with INDEXED BY clause +- `insert_sql(params)` — custom insert SQL, or None for default +- `post_insert_hook(conn, params)` — training/building step, returns time +- `run_query(conn, params, query, k)` — custom query, or None for default MATCH +- `describe(params)` — one-line description for report output + +### 2.2 Add configs to `Makefile` + +Append index-specific config variables and targets. Example pattern: + +```makefile +DISKANN_CONFIGS = \ + "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \ + ... + +ALL_CONFIGS += $(DISKANN_CONFIGS) + +bench-diskann: seed + $(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS) + ... +``` + +### 2.3 Migrate existing benchmark results/docs + +- Move useful results docs (RESULTS.md, etc.) into `benchmarks-ann/results/` +- Delete redundant per-branch benchmark directories once consolidated infra is proven + +--- + +## 3. Future improvements + +- [ ] Reporting script (`report.py`) — query results.db, produce markdown comparison tables +- [ ] Profiling targets in Makefile (lift from ivf-yolo3's Instruments/perf wrappers) +- [ ] Pre-computed ground truth integration (use GT DB files instead of on-the-fly brute-force) diff --git a/benchmarks-ann/.gitignore b/benchmarks-ann/.gitignore new file mode 100644 index 0000000..c418b76 --- /dev/null +++ b/benchmarks-ann/.gitignore @@ -0,0 +1,2 @@ +*.db +runs/ diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile new file mode 100644 index 0000000..59e2dcd --- /dev/null +++ b/benchmarks-ann/Makefile @@ -0,0 +1,61 @@ +BENCH = python bench.py +BASE_DB = seed/base.db +EXT = ../dist/vec0 + +# --- Baseline (brute-force) configs --- +BASELINES = \ + "brute-float:type=baseline,variant=float" \ + "brute-int8:type=baseline,variant=int8" \ + "brute-bit:type=baseline,variant=bit" + +# --- Index-specific configs --- +# Each index branch should add its own configs here. Example: +# +# DISKANN_CONFIGS = \ +# "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \ +# "diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8" +# +# IVF_CONFIGS = \ +# "ivf-n128-p16:type=ivf,nlist=128,nprobe=16" +# +# ANNOY_CONFIGS = \ +# "annoy-t50:type=annoy,n_trees=50" + +ALL_CONFIGS = $(BASELINES) + +.PHONY: seed ground-truth bench-smoke bench-10k bench-50k bench-100k bench-all \ + report clean + +# --- Data preparation --- +seed: + $(MAKE) -C seed + +ground-truth: seed + python ground_truth.py --subset-size 10000 + python ground_truth.py --subset-size 50000 + python ground_truth.py --subset-size 100000 + +# --- Quick smoke test --- +bench-smoke: seed + $(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \ + $(BASELINES) + +# --- Standard sizes --- +bench-10k: seed + $(BENCH) --subset-size 10000 -k 10 -o runs/10k $(ALL_CONFIGS) + +bench-50k: seed + $(BENCH) --subset-size 50000 -k 10 -o runs/50k $(ALL_CONFIGS) + +bench-100k: seed + $(BENCH) --subset-size 100000 -k 10 -o runs/100k $(ALL_CONFIGS) + +bench-all: bench-10k bench-50k bench-100k + +# --- Report --- +report: + @echo "Use: sqlite3 runs//results.db 'SELECT * FROM bench_results ORDER BY recall DESC'" + +# --- Cleanup --- +clean: + rm -rf runs/ diff --git a/benchmarks-ann/README.md b/benchmarks-ann/README.md new file mode 100644 index 0000000..1f7fd5c --- /dev/null +++ b/benchmarks-ann/README.md @@ -0,0 +1,81 @@ +# KNN Benchmarks for sqlite-vec + +Benchmarking infrastructure for vec0 KNN configurations. Includes brute-force +baselines (float, int8, bit); index-specific branches add their own types +via the `INDEX_REGISTRY` in `bench.py`. + +## Prerequisites + +- Built `dist/vec0` extension (run `make` from repo root) +- Python 3.10+ +- `uv` (for seed data prep): `pip install uv` + +## Quick start + +```bash +# 1. Download dataset and build seed DB (~3 GB download, ~5 min) +make seed + +# 2. Run a quick smoke test (5k vectors, ~1 min) +make bench-smoke + +# 3. Run full benchmark at 10k +make bench-10k +``` + +## Usage + +### Direct invocation + +```bash +python bench.py --subset-size 10000 \ + "brute-float:type=baseline,variant=float" \ + "brute-int8:type=baseline,variant=int8" \ + "brute-bit:type=baseline,variant=bit" +``` + +### Config format + +`name:type=,key=val,key=val` + +| Index type | Keys | Branch | +|-----------|------|--------| +| `baseline` | `variant` (float/int8/bit), `oversample` | this branch | + +Index branches register additional types in `INDEX_REGISTRY`. See the +docstring in `bench.py` for the extension API. + +### Make targets + +| Target | Description | +|--------|-------------| +| `make seed` | Download COHERE 1M dataset | +| `make ground-truth` | Pre-compute ground truth for 10k/50k/100k | +| `make bench-smoke` | Quick 5k baseline test | +| `make bench-10k` | All configs at 10k vectors | +| `make bench-50k` | All configs at 50k vectors | +| `make bench-100k` | All configs at 100k vectors | +| `make bench-all` | 10k + 50k + 100k | + +## Adding an index type + +In your index branch, add an entry to `INDEX_REGISTRY` in `bench.py` and +append your configs to `ALL_CONFIGS` in the `Makefile`. See the existing +`baseline` entry and the comments in both files for the pattern. + +## Results + +Results are stored in `runs//results.db` using the schema in `schema.sql`. + +```bash +sqlite3 runs/10k/results.db " + SELECT config_name, recall, mean_ms, qps + FROM bench_results + ORDER BY recall DESC +" +``` + +## Dataset + +[Zilliz COHERE Medium 1M](https://zilliz.com/learn/datasets-for-vector-database-benchmarks): +768 dimensions, cosine distance, 1M train vectors + 10k query vectors with precomputed neighbors. diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py new file mode 100644 index 0000000..93f8f82 --- /dev/null +++ b/benchmarks-ann/bench.py @@ -0,0 +1,488 @@ +#!/usr/bin/env python3 +"""Benchmark runner for sqlite-vec KNN configurations. + +Measures insert time, build/train time, DB size, KNN latency, and recall +across different vec0 configurations. + +Config format: name:type=,key=val,key=val + + Baseline (brute-force) keys: + type=baseline, variant=float|int8|bit, oversample=8 + + Index-specific types can be registered via INDEX_REGISTRY (see below). + +Usage: + python bench.py --subset-size 10000 \ + "brute-float:type=baseline,variant=float" \ + "brute-int8:type=baseline,variant=int8" \ + "brute-bit:type=baseline,variant=bit" +""" +import argparse +import os +import sqlite3 +import statistics +import time + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +EXT_PATH = os.path.join(_SCRIPT_DIR, "..", "dist", "vec0") +BASE_DB = os.path.join(_SCRIPT_DIR, "seed", "base.db") +INSERT_BATCH_SIZE = 1000 + + +# ============================================================================ +# Index registry — extension point for ANN index branches +# ============================================================================ +# +# Each index type provides a dict with: +# "defaults": dict of default params +# "create_table_sql": fn(params) -> SQL string +# "insert_sql": fn(params) -> SQL string (or None for default) +# "post_insert_hook": fn(conn, params) -> train_time_s (or None) +# "run_query": fn(conn, params, query, k) -> [(id, distance), ...] (or None for default MATCH) +# "describe": fn(params) -> str (one-line description) +# +# To add a new index type, add an entry here. Example (in your branch): +# +# INDEX_REGISTRY["diskann"] = { +# "defaults": {"R": 72, "L": 128, "quantizer": "binary", "buffer_threshold": 0}, +# "create_table_sql": lambda p: f"CREATE VIRTUAL TABLE vec_items USING vec0(...)", +# "insert_sql": None, +# "post_insert_hook": None, +# "run_query": None, +# "describe": lambda p: f"diskann q={p['quantizer']} R={p['R']} L={p['L']}", +# } + +INDEX_REGISTRY = {} + + +# ============================================================================ +# Baseline implementation +# ============================================================================ + + +def _baseline_create_table_sql(params): + variant = params["variant"] + extra = "" + if variant == "int8": + extra = ", embedding_int8 int8[768]" + elif variant == "bit": + extra = ", embedding_bq bit[768]" + return ( + f"CREATE VIRTUAL TABLE vec_items USING vec0(" + f" chunk_size=256," + f" id integer primary key," + f" embedding float[768] distance_metric=cosine" + f" {extra})" + ) + + +def _baseline_insert_sql(params): + variant = params["variant"] + if variant == "int8": + return ( + "INSERT INTO vec_items(id, embedding, embedding_int8) " + "SELECT id, vector, vec_quantize_int8(vector, 'unit') " + "FROM base.train WHERE id >= :lo AND id < :hi" + ) + elif variant == "bit": + return ( + "INSERT INTO vec_items(id, embedding, embedding_bq) " + "SELECT id, vector, vec_quantize_binary(vector) " + "FROM base.train WHERE id >= :lo AND id < :hi" + ) + return None # use default + + +def _baseline_run_query(conn, params, query, k): + variant = params["variant"] + oversample = params.get("oversample", 8) + + if variant == "int8": + return conn.execute( + "WITH coarse AS (" + " SELECT id, embedding FROM vec_items" + " WHERE embedding_int8 MATCH vec_quantize_int8(:query, 'unit')" + " LIMIT :oversample_k" + ") " + "SELECT id, vec_distance_cosine(embedding, :query) as distance " + "FROM coarse ORDER BY 2 LIMIT :k", + {"query": query, "k": k, "oversample_k": k * oversample}, + ).fetchall() + elif variant == "bit": + return conn.execute( + "WITH coarse AS (" + " SELECT id, embedding FROM vec_items" + " WHERE embedding_bq MATCH vec_quantize_binary(:query)" + " LIMIT :oversample_k" + ") " + "SELECT id, vec_distance_cosine(embedding, :query) as distance " + "FROM coarse ORDER BY 2 LIMIT :k", + {"query": query, "k": k, "oversample_k": k * oversample}, + ).fetchall() + + return None # use default MATCH + + +def _baseline_describe(params): + v = params["variant"] + if v in ("int8", "bit"): + return f"baseline {v} (os={params['oversample']})" + return f"baseline {v}" + + +INDEX_REGISTRY["baseline"] = { + "defaults": {"variant": "float", "oversample": 8}, + "create_table_sql": _baseline_create_table_sql, + "insert_sql": _baseline_insert_sql, + "post_insert_hook": None, + "run_query": _baseline_run_query, + "describe": _baseline_describe, +} + + +# ============================================================================ +# Config parsing +# ============================================================================ + +INT_KEYS = { + "R", "L", "buffer_threshold", "nlist", "nprobe", "oversample", + "n_trees", "search_k", +} + + +def parse_config(spec): + """Parse 'name:type=baseline,key=val,...' into (name, params_dict).""" + if ":" in spec: + name, opts_str = spec.split(":", 1) + else: + name, opts_str = spec, "" + + raw = {} + if opts_str: + for kv in opts_str.split(","): + k, v = kv.split("=", 1) + raw[k.strip()] = v.strip() + + index_type = raw.pop("type", "baseline") + if index_type not in INDEX_REGISTRY: + raise ValueError( + f"Unknown index type: {index_type}. " + f"Available: {', '.join(sorted(INDEX_REGISTRY.keys()))}" + ) + + reg = INDEX_REGISTRY[index_type] + params = dict(reg["defaults"]) + for k, v in raw.items(): + if k in INT_KEYS: + params[k] = int(v) + else: + params[k] = v + params["index_type"] = index_type + + return name, params + + +# ============================================================================ +# Shared helpers +# ============================================================================ + + +def load_query_vectors(base_db_path, n): + conn = sqlite3.connect(base_db_path) + rows = conn.execute( + "SELECT id, vector FROM query_vectors ORDER BY id LIMIT :n", {"n": n} + ).fetchall() + conn.close() + return [(r[0], r[1]) for r in rows] + + +def insert_loop(conn, sql, subset_size, label=""): + t0 = time.perf_counter() + for lo in range(0, subset_size, INSERT_BATCH_SIZE): + hi = min(lo + INSERT_BATCH_SIZE, subset_size) + conn.execute(sql, {"lo": lo, "hi": hi}) + conn.commit() + done = hi + if done % 5000 == 0 or done == subset_size: + elapsed = time.perf_counter() - t0 + rate = done / elapsed if elapsed > 0 else 0 + print( + f" [{label}] {done:>8}/{subset_size} " + f"{elapsed:.1f}s {rate:.0f} rows/s", + flush=True, + ) + return time.perf_counter() - t0 + + +def open_bench_db(db_path, ext_path, base_db): + if os.path.exists(db_path): + os.remove(db_path) + conn = sqlite3.connect(db_path) + conn.enable_load_extension(True) + conn.load_extension(ext_path) + conn.execute("PRAGMA page_size=8192") + conn.execute(f"ATTACH DATABASE '{base_db}' AS base") + return conn + + +DEFAULT_INSERT_SQL = ( + "INSERT INTO vec_items(id, embedding) " + "SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi" +) + + +# ============================================================================ +# Build +# ============================================================================ + + +def build_index(base_db, ext_path, name, params, subset_size, out_dir): + db_path = os.path.join(out_dir, f"{name}.{subset_size}.db") + conn = open_bench_db(db_path, ext_path, base_db) + + reg = INDEX_REGISTRY[params["index_type"]] + + conn.execute(reg["create_table_sql"](params)) + + label = params["index_type"] + print(f" Inserting {subset_size} vectors...") + + sql_fn = reg.get("insert_sql") + sql = sql_fn(params) if sql_fn else None + if sql is None: + sql = DEFAULT_INSERT_SQL + + insert_time = insert_loop(conn, sql, subset_size, label) + + train_time = 0.0 + hook = reg.get("post_insert_hook") + if hook: + train_time = hook(conn, params) + + row_count = conn.execute("SELECT count(*) FROM vec_items").fetchone()[0] + conn.close() + file_size_mb = os.path.getsize(db_path) / (1024 * 1024) + + return { + "db_path": db_path, + "insert_time_s": round(insert_time, 3), + "train_time_s": round(train_time, 3), + "total_time_s": round(insert_time + train_time, 3), + "insert_per_vec_ms": round((insert_time / row_count) * 1000, 2) + if row_count + else 0, + "rows": row_count, + "file_size_mb": round(file_size_mb, 2), + } + + +# ============================================================================ +# KNN measurement +# ============================================================================ + + +def _default_match_query(conn, query, k): + return conn.execute( + "SELECT id, distance FROM vec_items " + "WHERE embedding MATCH :query AND k = :k", + {"query": query, "k": k}, + ).fetchall() + + +def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50): + conn = sqlite3.connect(db_path) + conn.enable_load_extension(True) + conn.load_extension(ext_path) + conn.execute(f"ATTACH DATABASE '{base_db}' AS base") + + query_vectors = load_query_vectors(base_db, n) + + reg = INDEX_REGISTRY[params["index_type"]] + query_fn = reg.get("run_query") + + times_ms = [] + recalls = [] + for qid, query in query_vectors: + t0 = time.perf_counter() + + results = None + if query_fn: + results = query_fn(conn, params, query, k) + if results is None: + results = _default_match_query(conn, query, k) + + elapsed_ms = (time.perf_counter() - t0) * 1000 + times_ms.append(elapsed_ms) + result_ids = set(r[0] for r in results) + + # Ground truth: use pre-computed neighbors table for full dataset, + # otherwise brute-force over the subset + if subset_size >= 1000000: + gt_rows = conn.execute( + "SELECT CAST(neighbors_id AS INTEGER) FROM base.neighbors " + "WHERE query_vector_id = :qid AND rank < :k", + {"qid": qid, "k": k}, + ).fetchall() + else: + gt_rows = conn.execute( + "SELECT id FROM (" + " SELECT id, vec_distance_cosine(vector, :query) as dist " + " FROM base.train WHERE id < :n ORDER BY dist LIMIT :k" + ")", + {"query": query, "k": k, "n": subset_size}, + ).fetchall() + gt_ids = set(r[0] for r in gt_rows) + + if gt_ids: + recalls.append(len(result_ids & gt_ids) / len(gt_ids)) + else: + recalls.append(0.0) + + conn.close() + + return { + "mean_ms": round(statistics.mean(times_ms), 2), + "median_ms": round(statistics.median(times_ms), 2), + "p99_ms": round(sorted(times_ms)[int(len(times_ms) * 0.99)], 2) + if len(times_ms) > 1 + else round(times_ms[0], 2), + "total_ms": round(sum(times_ms), 2), + "recall": round(statistics.mean(recalls), 4), + } + + +# ============================================================================ +# Results persistence +# ============================================================================ + + +def save_results(results_path, rows): + db = sqlite3.connect(results_path) + db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read()) + for r in rows: + db.execute( + "INSERT OR REPLACE INTO build_results " + "(config_name, index_type, subset_size, db_path, " + " insert_time_s, train_time_s, total_time_s, rows, file_size_mb) " + "VALUES (?,?,?,?,?,?,?,?,?)", + ( + r["name"], r["index_type"], r["n_vectors"], r["db_path"], + r["insert_time_s"], r["train_time_s"], r["total_time_s"], + r["rows"], r["file_size_mb"], + ), + ) + db.execute( + "INSERT OR REPLACE INTO bench_results " + "(config_name, index_type, subset_size, k, n, " + " mean_ms, median_ms, p99_ms, total_ms, qps, recall, db_path) " + "VALUES (?,?,?,?,?,?,?,?,?,?,?,?)", + ( + r["name"], r["index_type"], r["n_vectors"], r["k"], r["n_queries"], + r["mean_ms"], r["median_ms"], r["p99_ms"], r["total_ms"], + round(r["n_queries"] / (r["total_ms"] / 1000), 1) + if r["total_ms"] > 0 else 0, + r["recall"], r["db_path"], + ), + ) + db.commit() + db.close() + + +# ============================================================================ +# Reporting +# ============================================================================ + + +def print_report(all_results): + print( + f"\n{'name':>20} {'N':>7} {'type':>10} {'config':>28} " + f"{'ins(s)':>7} {'train':>6} {'MB':>7} " + f"{'qry(ms)':>8} {'recall':>7}" + ) + print("-" * 115) + for r in all_results: + train = f"{r['train_time_s']:.1f}" if r["train_time_s"] > 0 else "-" + print( + f"{r['name']:>20} {r['n_vectors']:>7} {r['index_type']:>10} " + f"{r['config_desc']:>28} " + f"{r['insert_time_s']:>7.1f} {train:>6} {r['file_size_mb']:>7.1f} " + f"{r['mean_ms']:>8.2f} {r['recall']:>7.4f}" + ) + + +# ============================================================================ +# Main +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark runner for sqlite-vec KNN configurations", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument("configs", nargs="+", help="config specs (name:type=X,key=val,...)") + parser.add_argument("--subset-size", type=int, required=True) + parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)") + parser.add_argument("-n", type=int, default=50, help="number of queries (default 50)") + parser.add_argument("--base-db", default=BASE_DB) + parser.add_argument("--ext", default=EXT_PATH) + parser.add_argument("-o", "--out-dir", default="runs") + parser.add_argument("--results-db", default=None, + help="path to results DB (default: /results.db)") + args = parser.parse_args() + + os.makedirs(args.out_dir, exist_ok=True) + results_db = args.results_db or os.path.join(args.out_dir, "results.db") + configs = [parse_config(c) for c in args.configs] + + all_results = [] + for i, (name, params) in enumerate(configs, 1): + reg = INDEX_REGISTRY[params["index_type"]] + desc = reg["describe"](params) + print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()})") + + build = build_index( + args.base_db, args.ext, name, params, args.subset_size, args.out_dir + ) + train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else "" + print( + f" Build: {build['insert_time_s']}s insert{train_str} " + f"{build['file_size_mb']} MB" + ) + + print(f" Measuring KNN (k={args.k}, n={args.n})...") + knn = measure_knn( + build["db_path"], args.ext, args.base_db, + params, args.subset_size, k=args.k, n=args.n, + ) + print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}") + + all_results.append({ + "name": name, + "n_vectors": args.subset_size, + "index_type": params["index_type"], + "config_desc": desc, + "db_path": build["db_path"], + "insert_time_s": build["insert_time_s"], + "train_time_s": build["train_time_s"], + "total_time_s": build["total_time_s"], + "insert_per_vec_ms": build["insert_per_vec_ms"], + "rows": build["rows"], + "file_size_mb": build["file_size_mb"], + "k": args.k, + "n_queries": args.n, + "mean_ms": knn["mean_ms"], + "median_ms": knn["median_ms"], + "p99_ms": knn["p99_ms"], + "total_ms": knn["total_ms"], + "recall": knn["recall"], + }) + + print_report(all_results) + save_results(results_db, all_results) + print(f"\nResults saved to {results_db}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/ground_truth.py b/benchmarks-ann/ground_truth.py new file mode 100644 index 0000000..636a495 --- /dev/null +++ b/benchmarks-ann/ground_truth.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +"""Compute per-subset ground truth for ANN benchmarks. + +For subset sizes < 1M, builds a temporary vec0 float table with the first N +vectors and runs brute-force KNN to get correct ground truth per subset. + +For 1M (the full dataset), converts the existing `neighbors` table. + +Output: ground_truth.{subset_size}.db with table: + ground_truth(query_vector_id, rank, neighbor_id, distance) + +Usage: + python ground_truth.py --subset-size 50000 + python ground_truth.py --subset-size 1000000 +""" +import argparse +import os +import sqlite3 +import time + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +EXT_PATH = os.path.join(_SCRIPT_DIR, "..", "dist", "vec0") +BASE_DB = os.path.join(_SCRIPT_DIR, "seed", "base.db") +FULL_DATASET_SIZE = 1_000_000 + + +def gen_ground_truth_subset(base_db, ext_path, subset_size, n_queries, k, out_path): + """Build ground truth by brute-force KNN over the first `subset_size` vectors.""" + if os.path.exists(out_path): + os.remove(out_path) + + conn = sqlite3.connect(out_path) + conn.enable_load_extension(True) + conn.load_extension(ext_path) + + conn.execute( + "CREATE TABLE ground_truth (" + " query_vector_id INTEGER NOT NULL," + " rank INTEGER NOT NULL," + " neighbor_id INTEGER NOT NULL," + " distance REAL NOT NULL," + " PRIMARY KEY (query_vector_id, rank)" + ")" + ) + + conn.execute(f"ATTACH DATABASE '{base_db}' AS base") + + print(f" Building temp vec0 table with {subset_size} vectors...") + conn.execute( + "CREATE VIRTUAL TABLE tmp_vec USING vec0(" + " id integer primary key," + " embedding float[768] distance_metric=cosine" + ")" + ) + + t0 = time.perf_counter() + conn.execute( + "INSERT INTO tmp_vec(id, embedding) " + "SELECT id, vector FROM base.train WHERE id < :n", + {"n": subset_size}, + ) + conn.commit() + build_time = time.perf_counter() - t0 + print(f" Temp table built in {build_time:.1f}s") + + query_vectors = conn.execute( + "SELECT id, vector FROM base.query_vectors ORDER BY id LIMIT :n", + {"n": n_queries}, + ).fetchall() + + print(f" Running brute-force KNN for {len(query_vectors)} queries, k={k}...") + t0 = time.perf_counter() + + for i, (qid, qvec) in enumerate(query_vectors): + results = conn.execute( + "SELECT id, distance FROM tmp_vec " + "WHERE embedding MATCH :query AND k = :k", + {"query": qvec, "k": k}, + ).fetchall() + + for rank, (nid, dist) in enumerate(results): + conn.execute( + "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id, distance) " + "VALUES (?, ?, ?, ?)", + (qid, rank, nid, dist), + ) + + if (i + 1) % 10 == 0 or i == 0: + elapsed = time.perf_counter() - t0 + eta = (elapsed / (i + 1)) * (len(query_vectors) - i - 1) + print( + f" {i+1}/{len(query_vectors)} queries " + f"elapsed={elapsed:.1f}s eta={eta:.1f}s", + flush=True, + ) + + conn.commit() + conn.execute("DROP TABLE tmp_vec") + conn.execute("DETACH DATABASE base") + conn.commit() + + elapsed = time.perf_counter() - t0 + total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0] + conn.close() + print(f" Ground truth: {total_rows} rows in {elapsed:.1f}s -> {out_path}") + + +def gen_ground_truth_full(base_db, n_queries, k, out_path): + """Convert the existing neighbors table for the full 1M dataset.""" + if os.path.exists(out_path): + os.remove(out_path) + + conn = sqlite3.connect(out_path) + conn.execute(f"ATTACH DATABASE '{base_db}' AS base") + + conn.execute( + "CREATE TABLE ground_truth (" + " query_vector_id INTEGER NOT NULL," + " rank INTEGER NOT NULL," + " neighbor_id INTEGER NOT NULL," + " distance REAL," + " PRIMARY KEY (query_vector_id, rank)" + ")" + ) + + conn.execute( + "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id) " + "SELECT query_vector_id, rank, CAST(neighbors_id AS INTEGER) " + "FROM base.neighbors " + "WHERE query_vector_id < :n AND rank < :k", + {"n": n_queries, "k": k}, + ) + conn.commit() + + total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0] + conn.execute("DETACH DATABASE base") + conn.close() + print(f" Ground truth (full): {total_rows} rows -> {out_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Generate per-subset ground truth") + parser.add_argument( + "--subset-size", type=int, required=True, help="number of vectors in subset" + ) + parser.add_argument("-n", type=int, default=100, help="number of query vectors") + parser.add_argument("-k", type=int, default=100, help="max k for ground truth") + parser.add_argument("--base-db", default=BASE_DB) + parser.add_argument("--ext", default=EXT_PATH) + parser.add_argument( + "-o", "--out-dir", default=os.path.join(_SCRIPT_DIR, "seed"), + help="output directory for ground_truth.{N}.db", + ) + args = parser.parse_args() + + os.makedirs(args.out_dir, exist_ok=True) + out_path = os.path.join(args.out_dir, f"ground_truth.{args.subset_size}.db") + + if args.subset_size >= FULL_DATASET_SIZE: + gen_ground_truth_full(args.base_db, args.n, args.k, out_path) + else: + gen_ground_truth_subset( + args.base_db, args.ext, args.subset_size, args.n, args.k, out_path + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/profile.py b/benchmarks-ann/profile.py new file mode 100644 index 0000000..0792373 --- /dev/null +++ b/benchmarks-ann/profile.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 +"""CPU profiling for sqlite-vec KNN configurations using macOS `sample` tool. + +Builds dist/sqlite3 (with -g3), generates a SQL workload (inserts + repeated +KNN queries) for each config, profiles the sqlite3 process with `sample`, and +prints the top-N hottest functions by self (exclusive) CPU samples. + +Usage: + cd benchmarks-ann + uv run profile.py --subset-size 50000 -n 50 \\ + "baseline-int8:type=baseline,variant=int8,oversample=8" \\ + "rescore-int8:type=rescore,quantizer=int8,oversample=8" +""" + +import argparse +import os +import re +import shutil +import subprocess +import sys +import tempfile + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_PROJECT_ROOT = os.path.join(_SCRIPT_DIR, "..") + +sys.path.insert(0, _SCRIPT_DIR) +from bench import ( + BASE_DB, + DEFAULT_INSERT_SQL, + INDEX_REGISTRY, + INSERT_BATCH_SIZE, + parse_config, +) + +SQLITE3_PATH = os.path.join(_PROJECT_ROOT, "dist", "sqlite3") +EXT_PATH = os.path.join(_PROJECT_ROOT, "dist", "vec0") + + +# ============================================================================ +# SQL generation +# ============================================================================ + + +def _query_sql_for_config(params, query_id, k): + """Return a SQL query string for a single KNN query by query_vector id.""" + index_type = params["index_type"] + qvec = f"(SELECT vector FROM base.query_vectors WHERE id = {query_id})" + + if index_type == "baseline": + variant = params.get("variant", "float") + oversample = params.get("oversample", 8) + oversample_k = k * oversample + + if variant == "int8": + return ( + f"WITH coarse AS (" + f" SELECT id, embedding FROM vec_items" + f" WHERE embedding_int8 MATCH vec_quantize_int8({qvec}, 'unit')" + f" LIMIT {oversample_k}" + f") " + f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance " + f"FROM coarse ORDER BY 2 LIMIT {k};" + ) + elif variant == "bit": + return ( + f"WITH coarse AS (" + f" SELECT id, embedding FROM vec_items" + f" WHERE embedding_bq MATCH vec_quantize_binary({qvec})" + f" LIMIT {oversample_k}" + f") " + f"SELECT id, vec_distance_cosine(embedding, {qvec}) as distance " + f"FROM coarse ORDER BY 2 LIMIT {k};" + ) + + # Default MATCH query (baseline-float, rescore, and others) + return ( + f"SELECT id, distance FROM vec_items" + f" WHERE embedding MATCH {qvec} AND k = {k};" + ) + + +def generate_sql(db_path, params, subset_size, n_queries, k, repeats): + """Generate a complete SQL workload: load ext, create table, insert, query.""" + lines = [] + lines.append(".bail on") + lines.append(f".load {EXT_PATH}") + lines.append(f"ATTACH DATABASE '{os.path.abspath(BASE_DB)}' AS base;") + lines.append("PRAGMA page_size=8192;") + + # Create table + reg = INDEX_REGISTRY[params["index_type"]] + lines.append(reg["create_table_sql"](params) + ";") + + # Inserts + sql_fn = reg.get("insert_sql") + insert_sql = sql_fn(params) if sql_fn else None + if insert_sql is None: + insert_sql = DEFAULT_INSERT_SQL + for lo in range(0, subset_size, INSERT_BATCH_SIZE): + hi = min(lo + INSERT_BATCH_SIZE, subset_size) + stmt = insert_sql.replace(":lo", str(lo)).replace(":hi", str(hi)) + lines.append(stmt + ";") + if hi % 10000 == 0 or hi == subset_size: + lines.append("-- progress: inserted %d/%d" % (hi, subset_size)) + + # Queries (repeated) + lines.append("-- BEGIN QUERIES") + for _rep in range(repeats): + for qid in range(n_queries): + lines.append(_query_sql_for_config(params, qid, k)) + + return "\n".join(lines) + + +# ============================================================================ +# Profiling with macOS `sample` +# ============================================================================ + + +def run_profile(sqlite3_path, db_path, sql_file, sample_output, duration=120): + """Run sqlite3 under macOS `sample` profiler. + + Starts sqlite3 directly with stdin from the SQL file, then immediately + attaches `sample` to its PID with -mayDie (tolerates process exit). + The workload must be long enough for sample to attach and capture useful data. + """ + sql_fd = open(sql_file, "r") + proc = subprocess.Popen( + [sqlite3_path, db_path], + stdin=sql_fd, + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + ) + + pid = proc.pid + print(f" sqlite3 PID: {pid}") + + # Attach sample immediately (1ms interval, -mayDie tolerates process exit) + sample_proc = subprocess.Popen( + ["sample", str(pid), str(duration), "1", "-mayDie", "-file", sample_output], + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + ) + + # Wait for sqlite3 to finish + _, stderr = proc.communicate() + sql_fd.close() + rc = proc.returncode + if rc != 0: + print(f" sqlite3 failed (rc={rc}):", file=sys.stderr) + print(f" {stderr.decode().strip()}", file=sys.stderr) + sample_proc.kill() + return False + + # Wait for sample to finish + sample_proc.wait() + return True + + +# ============================================================================ +# Parse `sample` output +# ============================================================================ + +# Tree-drawing characters used by macOS `sample` to represent hierarchy. +# We replace them with spaces so indentation depth reflects tree depth. +_TREE_CHARS_RE = re.compile(r"[+!:|]") + +# After tree chars are replaced with spaces, each call-graph line looks like: +# " 800 rescore_knn (in vec0.dylib) + 3808,3640,... [0x1a,0x2b,...] file.c:123" +# We extract just (indent, count, symbol, module) — everything after "(in ...)" +# is decoration we don't need. +_LEADING_RE = re.compile(r"^(\s+)(\d+)\s+(.+)") + + +def _extract_symbol_and_module(rest): + """Given the text after 'count ', extract (symbol, module). + + Handles patterns like: + 'rescore_knn (in vec0.dylib) + 3808,3640,... [0x...]' + 'pread (in libsystem_kernel.dylib) + 8 [0x...]' + '??? (in ) [0x...]' + 'start (in dyld) + 2840 [0x198650274]' + 'Thread_26759239 DispatchQueue_1: ...' + """ + # Try to find "(in ...)" to split symbol from module + m = re.match(r"^(.+?)\s+\(in\s+(.+?)\)", rest) + if m: + return m.group(1).strip(), m.group(2).strip() + # No module — return whole thing as symbol, strip trailing junk + sym = re.sub(r"\s+\[0x[0-9a-f].*", "", rest).strip() + return sym, "" + + +def _parse_call_graph_lines(text): + """Parse call-graph section into list of (depth, count, symbol, module).""" + entries = [] + for raw_line in text.split("\n"): + # Strip tree-drawing characters, replace with spaces to preserve depth + line = _TREE_CHARS_RE.sub(" ", raw_line) + m = _LEADING_RE.match(line) + if not m: + continue + depth = len(m.group(1)) + count = int(m.group(2)) + rest = m.group(3) + symbol, module = _extract_symbol_and_module(rest) + entries.append((depth, count, symbol, module)) + return entries + + +def parse_sample_output(filepath): + """Parse `sample` call-graph output, compute exclusive (self) samples per function. + + Returns dict of {display_name: self_sample_count}. + """ + with open(filepath, "r") as f: + text = f.read() + + # Find "Call graph:" section + cg_start = text.find("Call graph:") + if cg_start == -1: + print(" Warning: no 'Call graph:' section found in sample output") + return {} + + # End at "Total number in stack" or EOF + cg_end = text.find("\nTotal number in stack", cg_start) + if cg_end == -1: + cg_end = len(text) + + entries = _parse_call_graph_lines(text[cg_start:cg_end]) + + if not entries: + print(" Warning: no call graph entries parsed") + return {} + + # Compute self (exclusive) samples per function: + # self = count - sum(direct_children_counts) + self_samples = {} + for i, (depth, count, sym, mod) in enumerate(entries): + children_sum = 0 + child_depth = None + for j in range(i + 1, len(entries)): + j_depth = entries[j][0] + if j_depth <= depth: + break + if child_depth is None: + child_depth = j_depth + if j_depth == child_depth: + children_sum += entries[j][1] + + self_count = count - children_sum + if self_count > 0: + key = f"{sym} ({mod})" if mod else sym + self_samples[key] = self_samples.get(key, 0) + self_count + + return self_samples + + +# ============================================================================ +# Display +# ============================================================================ + + +def print_profile(title, self_samples, top_n=20): + total = sum(self_samples.values()) + if total == 0: + print(f"\n=== {title} (no samples) ===") + return + + sorted_syms = sorted(self_samples.items(), key=lambda x: -x[1]) + + print(f"\n=== {title} (top {top_n}, {total} total self-samples) ===") + for sym, count in sorted_syms[:top_n]: + pct = 100.0 * count / total + print(f" {pct:5.1f}% {count:>6} {sym}") + + +# ============================================================================ +# Main +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="CPU profiling for sqlite-vec KNN configurations", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "configs", nargs="+", help="config specs (name:type=X,key=val,...)" + ) + parser.add_argument("--subset-size", type=int, required=True) + parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)") + parser.add_argument( + "-n", type=int, default=50, help="number of distinct queries (default 50)" + ) + parser.add_argument( + "--repeats", + type=int, + default=10, + help="repeat query set N times for more samples (default 10)", + ) + parser.add_argument( + "--top", type=int, default=20, help="show top N functions (default 20)" + ) + parser.add_argument("--base-db", default=BASE_DB) + parser.add_argument("--sqlite3", default=SQLITE3_PATH) + parser.add_argument( + "--keep-temp", + action="store_true", + help="keep temp directory with DBs, SQL, and sample output", + ) + args = parser.parse_args() + + # Check prerequisites + if not os.path.exists(args.base_db): + print(f"Error: base DB not found at {args.base_db}", file=sys.stderr) + print("Run 'make seed' in benchmarks-ann/ first.", file=sys.stderr) + sys.exit(1) + + if not shutil.which("sample"): + print("Error: macOS 'sample' tool not found.", file=sys.stderr) + sys.exit(1) + + # Build CLI + print("Building dist/sqlite3...") + result = subprocess.run( + ["make", "cli"], cwd=_PROJECT_ROOT, capture_output=True, text=True + ) + if result.returncode != 0: + print(f"Error: make cli failed:\n{result.stderr}", file=sys.stderr) + sys.exit(1) + print(" done.") + + if not os.path.exists(args.sqlite3): + print(f"Error: sqlite3 not found at {args.sqlite3}", file=sys.stderr) + sys.exit(1) + + configs = [parse_config(c) for c in args.configs] + + tmpdir = tempfile.mkdtemp(prefix="sqlite-vec-profile-") + print(f"Working directory: {tmpdir}") + + all_profiles = [] + + for i, (name, params) in enumerate(configs, 1): + reg = INDEX_REGISTRY[params["index_type"]] + desc = reg["describe"](params) + print(f"\n[{i}/{len(configs)}] {name} ({desc})") + + # Generate SQL workload + db_path = os.path.join(tmpdir, f"{name}.db") + sql_text = generate_sql( + db_path, params, args.subset_size, args.n, args.k, args.repeats + ) + sql_file = os.path.join(tmpdir, f"{name}.sql") + with open(sql_file, "w") as f: + f.write(sql_text) + + total_queries = args.n * args.repeats + print( + f" SQL workload: {args.subset_size} inserts + " + f"{total_queries} queries ({args.n} x {args.repeats} repeats)" + ) + + # Profile + sample_file = os.path.join(tmpdir, f"{name}.sample.txt") + print(f" Profiling...") + ok = run_profile(args.sqlite3, db_path, sql_file, sample_file) + if not ok: + print(f" FAILED — skipping {name}") + all_profiles.append((name, desc, {})) + continue + + if not os.path.exists(sample_file): + print(f" Warning: sample output not created") + all_profiles.append((name, desc, {})) + continue + + # Parse + self_samples = parse_sample_output(sample_file) + all_profiles.append((name, desc, self_samples)) + + # Show individual profile + print_profile(f"{name} ({desc})", self_samples, args.top) + + # Side-by-side comparison if multiple configs + if len(all_profiles) > 1: + print("\n" + "=" * 80) + print("COMPARISON") + print("=" * 80) + + # Collect all symbols that appear in top-N of any config + all_syms = set() + for _name, _desc, prof in all_profiles: + sorted_syms = sorted(prof.items(), key=lambda x: -x[1]) + for sym, _count in sorted_syms[: args.top]: + all_syms.add(sym) + + # Build comparison table + rows = [] + for sym in all_syms: + row = [sym] + for _name, _desc, prof in all_profiles: + total = sum(prof.values()) + count = prof.get(sym, 0) + pct = 100.0 * count / total if total > 0 else 0.0 + row.append((pct, count)) + max_pct = max(r[0] for r in row[1:]) + rows.append((max_pct, row)) + + rows.sort(key=lambda x: -x[0]) + + # Header + header = f"{'function':>40}" + for name, desc, _ in all_profiles: + header += f" {name:>14}" + print(header) + print("-" * len(header)) + + for _sort_key, row in rows[: args.top * 2]: + sym = row[0] + display_sym = sym if len(sym) <= 40 else sym[:37] + "..." + line = f"{display_sym:>40}" + for pct, count in row[1:]: + if count > 0: + line += f" {pct:>13.1f}%" + else: + line += f" {'-':>14}" + print(line) + + if args.keep_temp: + print(f"\nTemp files kept at: {tmpdir}") + else: + shutil.rmtree(tmpdir) + print(f"\nTemp files cleaned up. Use --keep-temp to preserve.") + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/schema.sql b/benchmarks-ann/schema.sql new file mode 100644 index 0000000..681df4e --- /dev/null +++ b/benchmarks-ann/schema.sql @@ -0,0 +1,35 @@ +-- Canonical results schema for vec0 KNN benchmark comparisons. +-- The index_type column is a free-form TEXT field. Baseline configs use +-- "baseline"; index-specific branches add their own types (registered +-- via INDEX_REGISTRY in bench.py). + +CREATE TABLE IF NOT EXISTS build_results ( + config_name TEXT NOT NULL, + index_type TEXT NOT NULL, + subset_size INTEGER NOT NULL, + db_path TEXT NOT NULL, + insert_time_s REAL NOT NULL, + train_time_s REAL, -- NULL when no training/build step is needed + total_time_s REAL NOT NULL, + rows INTEGER NOT NULL, + file_size_mb REAL NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + PRIMARY KEY (config_name, subset_size) +); + +CREATE TABLE IF NOT EXISTS bench_results ( + config_name TEXT NOT NULL, + index_type TEXT NOT NULL, + subset_size INTEGER NOT NULL, + k INTEGER NOT NULL, + n INTEGER NOT NULL, + mean_ms REAL NOT NULL, + median_ms REAL NOT NULL, + p99_ms REAL NOT NULL, + total_ms REAL NOT NULL, + qps REAL NOT NULL, + recall REAL NOT NULL, + db_path TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + PRIMARY KEY (config_name, subset_size, k) +); diff --git a/benchmarks-ann/seed/.gitignore b/benchmarks-ann/seed/.gitignore new file mode 100644 index 0000000..8efed50 --- /dev/null +++ b/benchmarks-ann/seed/.gitignore @@ -0,0 +1,2 @@ +*.parquet +base.db diff --git a/benchmarks-ann/seed/Makefile b/benchmarks-ann/seed/Makefile new file mode 100644 index 0000000..186bf66 --- /dev/null +++ b/benchmarks-ann/seed/Makefile @@ -0,0 +1,24 @@ +BASE_URL = https://assets.zilliz.com/benchmark/cohere_medium_1m + +PARQUETS = train.parquet test.parquet neighbors.parquet + +.PHONY: all download base.db clean + +all: base.db + +download: $(PARQUETS) + +train.parquet: + curl -L -o $@ $(BASE_URL)/train.parquet + +test.parquet: + curl -L -o $@ $(BASE_URL)/test.parquet + +neighbors.parquet: + curl -L -o $@ $(BASE_URL)/neighbors.parquet + +base.db: $(PARQUETS) build_base_db.py + uv run --with pandas --with pyarrow python build_base_db.py + +clean: + rm -f base.db diff --git a/benchmarks-ann/seed/build_base_db.py b/benchmarks-ann/seed/build_base_db.py new file mode 100644 index 0000000..33d280d --- /dev/null +++ b/benchmarks-ann/seed/build_base_db.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +"""Build base.db from downloaded parquet files. + +Reads train.parquet, test.parquet, neighbors.parquet and creates a SQLite +database with tables: train, query_vectors, neighbors. + +Usage: + uv run --with pandas --with pyarrow python build_base_db.py +""" +import json +import os +import sqlite3 +import struct +import sys +import time + +import pandas as pd + + +def float_list_to_blob(floats): + """Pack a list of floats into a little-endian f32 blob.""" + return struct.pack(f"<{len(floats)}f", *floats) + + +def main(): + seed_dir = os.path.dirname(os.path.abspath(__file__)) + db_path = os.path.join(seed_dir, "base.db") + + train_path = os.path.join(seed_dir, "train.parquet") + test_path = os.path.join(seed_dir, "test.parquet") + neighbors_path = os.path.join(seed_dir, "neighbors.parquet") + + for p in (train_path, test_path, neighbors_path): + if not os.path.exists(p): + print(f"ERROR: {p} not found. Run 'make download' first.") + sys.exit(1) + + if os.path.exists(db_path): + os.remove(db_path) + + conn = sqlite3.connect(db_path) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA page_size=4096") + + # --- query_vectors (from test.parquet) --- + print("Loading test.parquet (query vectors)...") + t0 = time.perf_counter() + df_test = pd.read_parquet(test_path) + conn.execute( + "CREATE TABLE query_vectors (id INTEGER PRIMARY KEY, vector BLOB)" + ) + rows = [] + for _, row in df_test.iterrows(): + rows.append((int(row["id"]), float_list_to_blob(row["emb"]))) + conn.executemany("INSERT INTO query_vectors (id, vector) VALUES (?, ?)", rows) + conn.commit() + print(f" {len(rows)} query vectors in {time.perf_counter() - t0:.1f}s") + + # --- neighbors (from neighbors.parquet) --- + print("Loading neighbors.parquet...") + t0 = time.perf_counter() + df_neighbors = pd.read_parquet(neighbors_path) + conn.execute( + "CREATE TABLE neighbors (" + " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT," + " UNIQUE(query_vector_id, rank))" + ) + rows = [] + for _, row in df_neighbors.iterrows(): + qid = int(row["id"]) + # neighbors_id may be a numpy array or JSON string + nids = row["neighbors_id"] + if isinstance(nids, str): + nids = json.loads(nids) + for rank, nid in enumerate(nids): + rows.append((qid, rank, str(int(nid)))) + conn.executemany( + "INSERT INTO neighbors (query_vector_id, rank, neighbors_id) VALUES (?, ?, ?)", + rows, + ) + conn.commit() + print(f" {len(rows)} neighbor rows in {time.perf_counter() - t0:.1f}s") + + # --- train (from train.parquet) --- + print("Loading train.parquet (1M vectors, this takes a few minutes)...") + t0 = time.perf_counter() + conn.execute( + "CREATE TABLE train (id INTEGER PRIMARY KEY, vector BLOB)" + ) + + batch_size = 10000 + df_iter = pd.read_parquet(train_path) + total = len(df_iter) + + for start in range(0, total, batch_size): + chunk = df_iter.iloc[start : start + batch_size] + rows = [] + for _, row in chunk.iterrows(): + rows.append((int(row["id"]), float_list_to_blob(row["emb"]))) + conn.executemany("INSERT INTO train (id, vector) VALUES (?, ?)", rows) + conn.commit() + + done = min(start + batch_size, total) + elapsed = time.perf_counter() - t0 + rate = done / elapsed if elapsed > 0 else 0 + eta = (total - done) / rate if rate > 0 else 0 + print( + f" {done:>8}/{total} {elapsed:.0f}s {rate:.0f} rows/s eta {eta:.0f}s", + flush=True, + ) + + elapsed = time.perf_counter() - t0 + print(f" {total} train vectors in {elapsed:.1f}s") + + conn.close() + size_mb = os.path.getsize(db_path) / (1024 * 1024) + print(f"\nDone: {db_path} ({size_mb:.0f} MB)") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/exhaustive-memory/bench.py b/benchmarks/exhaustive-memory/bench.py index c9da831..7c969d6 100644 --- a/benchmarks/exhaustive-memory/bench.py +++ b/benchmarks/exhaustive-memory/bench.py @@ -248,59 +248,6 @@ def bench_libsql(base, query, page_size, k) -> BenchResult: return BenchResult(f"libsql ({page_size})", build_time, times) -def register_np(db, array, name): - ptr = array.__array_interface__["data"][0] - nvectors, dimensions = array.__array_interface__["shape"] - element_type = array.__array_interface__["typestr"] - - assert element_type == " BenchResult: - print(f"sqlite-vec static...") - - db = sqlite3.connect(":memory:") - db.enable_load_extension(True) - db.load_extension("../../dist/vec0") - - - - t = time.time() - register_np(db, base, "base") - build_time = time.time() - t - - times = [] - results = [] - for ( - idx, - q, - ) in enumerate(query): - t0 = time.time() - result = db.execute( - """ - select - rowid - from base - where vector match ? - and k = ? - order by distance - """, - [q.tobytes(), k], - ).fetchall() - assert len(result) == k - times.append(time.time() - t0) - return BenchResult(f"sqlite-vec static", build_time, times) - def bench_faiss(base, query, k) -> BenchResult: import faiss dimensions = base.shape[1] @@ -438,8 +385,6 @@ def suite(name, base, query, k, benchmarks): for b in benchmarks: if b == "faiss": results.append(bench_faiss(base, query, k=k)) - elif b == "vec-static": - results.append(bench_sqlite_vec_static(base, query, k=k)) elif b.startswith("vec-scalar"): _, page_size = b.split('.') results.append(bench_sqlite_vec_scalar(base, query, page_size, k=k)) @@ -541,7 +486,7 @@ def parse_args(): help="Number of queries to use. Defaults all", ) parser.add_argument( - "-x", help="type of runs to make", default="faiss,vec-scalar.4096,vec-static,vec-vec0.4096.16,usearch,duckdb,hnswlib,numpy" + "-x", help="type of runs to make", default="faiss,vec-scalar.4096,vec-vec0.4096.16,usearch,duckdb,hnswlib,numpy" ) args = parser.parse_args() diff --git a/benchmarks/profiling/build-from-npy.sql b/benchmarks/profiling/build-from-npy.sql index 134df70..92ef59c 100644 --- a/benchmarks/profiling/build-from-npy.sql +++ b/benchmarks/profiling/build-from-npy.sql @@ -8,10 +8,3 @@ create virtual table vec_items using vec0( embedding float[1536] ); --- 65s (limit 1e5), ~615MB on disk -insert into vec_items - select - rowid, - vector - from vec_npy_each(vec_npy_file('examples/dbpedia-openai/data/vectors.npy')) - limit 1e5; diff --git a/benchmarks/self-params/build.py b/benchmarks/self-params/build.py index bc6e388..c5d9fc1 100644 --- a/benchmarks/self-params/build.py +++ b/benchmarks/self-params/build.py @@ -6,7 +6,6 @@ def connect(path): db = sqlite3.connect(path) db.enable_load_extension(True) db.load_extension("../dist/vec0") - db.execute("select load_extension('../dist/vec0', 'sqlite3_vec_fs_read_init')") db.enable_load_extension(False) return db @@ -18,8 +17,6 @@ page_sizes = [ # 4096, 8192, chunk_sizes = [128, 256, 1024, 2048] types = ["f32", "int8", "bit"] -SRC = "../examples/dbpedia-openai/data/vectors.npy" - for page_size in page_sizes: for chunk_size in chunk_sizes: for t in types: @@ -42,15 +39,8 @@ for page_size in page_sizes: func = "vec_quantize_i8(vector, 'unit')" if t == "bit": func = "vec_quantize_binary(vector)" - db.execute( - f""" - insert into vec_items - select rowid, {func} - from vec_npy_each(vec_npy_file(?)) - limit 100000 - """, - [SRC], - ) + # TODO: replace with non-npy data loading + pass elapsed = time.time() - t0 print(elapsed) diff --git a/bindings/go/ncruces/go-sqlite3.patch b/bindings/go/ncruces/go-sqlite3.patch index f202bc3..03bead9 100644 --- a/bindings/go/ncruces/go-sqlite3.patch +++ b/bindings/go/ncruces/go-sqlite3.patch @@ -6,7 +6,6 @@ index ed2aaec..4cc0b0e 100755 -Wl,--initial-memory=327680 \ -D_HAVE_SQLITE_CONFIG_H \ -DSQLITE_CUSTOM_INCLUDE=sqlite_opt.h \ -+ -DSQLITE_VEC_OMIT_FS=1 \ $(awk '{print "-Wl,--export="$0}' exports.txt) "$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp diff --git a/bindings/python/extra_init.py b/bindings/python/extra_init.py index 267bc41..4408855 100644 --- a/bindings/python/extra_init.py +++ b/bindings/python/extra_init.py @@ -1,6 +1,5 @@ from typing import List from struct import pack -from sqlite3 import Connection def serialize_float32(vector: List[float]) -> bytes: @@ -13,33 +12,3 @@ def serialize_int8(vector: List[int]) -> bytes: return pack("%sb" % len(vector), *vector) -try: - import numpy.typing as npt - - def register_numpy(db: Connection, name: str, array: npt.NDArray): - """ayoo""" - - ptr = array.__array_interface__["data"][0] - nvectors, dimensions = array.__array_interface__["shape"] - element_type = array.__array_interface__["typestr"] - - assert element_type == " dist/sqlite-vec.c +""" + +import re +import sys +import os + + +def strip_lsp_block(content): + """Remove the LSP-support pattern: + #ifndef SQLITE_VEC_H + #include "sqlite-vec.c" // ... + #endif + """ + pattern = re.compile( + r'^\s*#ifndef\s+SQLITE_VEC_H\s*\n' + r'\s*#include\s+"sqlite-vec\.c"[^\n]*\n' + r'\s*#endif[^\n]*\n', + re.MULTILINE, + ) + return pattern.sub('', content) + + +def strip_include_guard(content, guard_macro): + """Remove the include guard pair: + #ifndef GUARD_MACRO + #define GUARD_MACRO + ...content... + (trailing #endif removed) + """ + # Strip the #ifndef / #define pair at the top + header_pattern = re.compile( + r'^\s*#ifndef\s+' + re.escape(guard_macro) + r'\s*\n' + r'\s*#define\s+' + re.escape(guard_macro) + r'\s*\n', + re.MULTILINE, + ) + content = header_pattern.sub('', content, count=1) + + # Strip the trailing #endif (last one in file that closes the guard) + # Find the last #endif and remove it + lines = content.rstrip('\n').split('\n') + for i in range(len(lines) - 1, -1, -1): + if re.match(r'^\s*#endif', lines[i]): + lines.pop(i) + break + + return '\n'.join(lines) + '\n' + + +def detect_include_guard(content): + """Detect an include guard macro like SQLITE_VEC_IVF_C.""" + m = re.match( + r'\s*(?:/\*[\s\S]*?\*/\s*)?' # optional block comment + r'#ifndef\s+(SQLITE_VEC_\w+_C)\s*\n' + r'#define\s+\1', + content, + ) + return m.group(1) if m else None + + +def inline_include(match, base_dir): + """Replace an #include "sqlite-vec-*.c" with the file's contents.""" + filename = match.group(1) + filepath = os.path.join(base_dir, filename) + + if not os.path.exists(filepath): + print(f"Warning: {filepath} not found, leaving #include in place", file=sys.stderr) + return match.group(0) + + with open(filepath, 'r') as f: + content = f.read() + + # Strip LSP-support block + content = strip_lsp_block(content) + + # Strip include guard if present + guard = detect_include_guard(content) + if guard: + content = strip_include_guard(content, guard) + + separator = '/' * 78 + header = f'\n{separator}\n// Begin inlined: {filename}\n{separator}\n\n' + footer = f'\n{separator}\n// End inlined: {filename}\n{separator}\n' + + return header + content.strip('\n') + footer + + +def amalgamate(input_path): + base_dir = os.path.dirname(os.path.abspath(input_path)) + + with open(input_path, 'r') as f: + content = f.read() + + # Replace #include "sqlite-vec-*.c" with inlined contents + include_pattern = re.compile(r'^#include\s+"(sqlite-vec-[^"]+\.c)"\s*$', re.MULTILINE) + content = include_pattern.sub(lambda m: inline_include(m, base_dir), content) + + return content + + +def main(): + if len(sys.argv) != 2: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + sys.exit(1) + + result = amalgamate(sys.argv[1]) + sys.stdout.write(result) + + +if __name__ == '__main__': + main() diff --git a/site/api-reference.md b/site/api-reference.md index bd144ea..ba8c648 100644 --- a/site/api-reference.md +++ b/site/api-reference.md @@ -568,65 +568,6 @@ select 'todo'; -- 'todo' -``` - -## NumPy Utilities {#numpy} - -Functions to read data from or work with [NumPy arrays](https://numpy.org/doc/stable/reference/generated/numpy.array.html). - -### `vec_npy_each(vector)` {#vec_npy_each} - -xxx - - -```sql --- db.execute('select quote(?)', [to_npy(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))]).fetchone() -select - rowid, - vector, - vec_type(vector), - vec_to_json(vector) -from vec_npy_each( - X'934E554D5059010076007B276465736372273A20273C6634272C2027666F727472616E5F6F72646572273A2046616C73652C20277368617065273A2028332C2031292C207D202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020200A0000803F0000004000004040' -) -/* -┌───────┬─────────────┬──────────────────┬─────────────────────┐ -│ rowid │ vector │ vec_type(vector) │ vec_to_json(vector) │ -├───────┼─────────────┼──────────────────┼─────────────────────┤ -│ 0 │ X'0000803F' │ 'float32' │ '[1.000000]' │ -├───────┼─────────────┼──────────────────┼─────────────────────┤ -│ 1 │ X'00000040' │ 'float32' │ '[2.000000]' │ -├───────┼─────────────┼──────────────────┼─────────────────────┤ -│ 2 │ X'00004040' │ 'float32' │ '[3.000000]' │ -└───────┴─────────────┴──────────────────┴─────────────────────┘ - -*/ - - --- db.execute('select quote(?)', [to_npy(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))]).fetchone() -select - rowid, - vector, - vec_type(vector), - vec_to_json(vector) -from vec_npy_each( - X'934E554D5059010076007B276465736372273A20273C6634272C2027666F727472616E5F6F72646572273A2046616C73652C20277368617065273A2028332C2031292C207D202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020202020200A0000803F0000004000004040' -) -/* -┌───────┬─────────────┬──────────────────┬─────────────────────┐ -│ rowid │ vector │ vec_type(vector) │ vec_to_json(vector) │ -├───────┼─────────────┼──────────────────┼─────────────────────┤ -│ 0 │ X'0000803F' │ 'float32' │ '[1.000000]' │ -├───────┼─────────────┼──────────────────┼─────────────────────┤ -│ 1 │ X'00000040' │ 'float32' │ '[2.000000]' │ -├───────┼─────────────┼──────────────────┼─────────────────────┤ -│ 2 │ X'00004040' │ 'float32' │ '[3.000000]' │ -└───────┴─────────────┴──────────────────┴─────────────────────┘ - -*/ - - - ``` ## Meta {#meta} diff --git a/site/compiling.md b/site/compiling.md index 9ce3c83..b3b2e33 100644 --- a/site/compiling.md +++ b/site/compiling.md @@ -59,5 +59,4 @@ The current compile-time flags are: - `SQLITE_VEC_ENABLE_AVX`, enables AVX CPU instructions for some vector search operations - `SQLITE_VEC_ENABLE_NEON`, enables NEON CPU instructions for some vector search operations -- `SQLITE_VEC_OMIT_FS`, removes some obsure SQL functions and features that use the filesystem, meant for some WASM builds where there's no available filesystem - `SQLITE_VEC_STATIC`, meant for statically linking `sqlite-vec` diff --git a/sqlite-vec.c b/sqlite-vec.c index c1874a7..390123b 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -11,7 +11,7 @@ #include #include -#ifndef SQLITE_VEC_OMIT_FS +#ifdef SQLITE_VEC_DEBUG #include #endif @@ -224,6 +224,63 @@ static f32 l2_sqr_float_neon(const void *pVect1v, const void *pVect2v, return sqrt(sum_scalar); } +static f32 cosine_float_neon(const void *pVect1v, const void *pVect2v, + const void *qty_ptr) { + f32 *pVect1 = (f32 *)pVect1v; + f32 *pVect2 = (f32 *)pVect2v; + size_t qty = *((size_t *)qty_ptr); + size_t qty16 = qty >> 4; + const f32 *pEnd1 = pVect1 + (qty16 << 4); + + float32x4_t dot0 = vdupq_n_f32(0), dot1 = vdupq_n_f32(0); + float32x4_t dot2 = vdupq_n_f32(0), dot3 = vdupq_n_f32(0); + float32x4_t amag0 = vdupq_n_f32(0), amag1 = vdupq_n_f32(0); + float32x4_t amag2 = vdupq_n_f32(0), amag3 = vdupq_n_f32(0); + float32x4_t bmag0 = vdupq_n_f32(0), bmag1 = vdupq_n_f32(0); + float32x4_t bmag2 = vdupq_n_f32(0), bmag3 = vdupq_n_f32(0); + + while (pVect1 < pEnd1) { + float32x4_t v1, v2; + v1 = vld1q_f32(pVect1); pVect1 += 4; + v2 = vld1q_f32(pVect2); pVect2 += 4; + dot0 = vfmaq_f32(dot0, v1, v2); + amag0 = vfmaq_f32(amag0, v1, v1); + bmag0 = vfmaq_f32(bmag0, v2, v2); + + v1 = vld1q_f32(pVect1); pVect1 += 4; + v2 = vld1q_f32(pVect2); pVect2 += 4; + dot1 = vfmaq_f32(dot1, v1, v2); + amag1 = vfmaq_f32(amag1, v1, v1); + bmag1 = vfmaq_f32(bmag1, v2, v2); + + v1 = vld1q_f32(pVect1); pVect1 += 4; + v2 = vld1q_f32(pVect2); pVect2 += 4; + dot2 = vfmaq_f32(dot2, v1, v2); + amag2 = vfmaq_f32(amag2, v1, v1); + bmag2 = vfmaq_f32(bmag2, v2, v2); + + v1 = vld1q_f32(pVect1); pVect1 += 4; + v2 = vld1q_f32(pVect2); pVect2 += 4; + dot3 = vfmaq_f32(dot3, v1, v2); + amag3 = vfmaq_f32(amag3, v1, v1); + bmag3 = vfmaq_f32(bmag3, v2, v2); + } + + f32 dot_s = vaddvq_f32(vaddq_f32(vaddq_f32(dot0, dot1), vaddq_f32(dot2, dot3))); + f32 amag_s = vaddvq_f32(vaddq_f32(vaddq_f32(amag0, amag1), vaddq_f32(amag2, amag3))); + f32 bmag_s = vaddvq_f32(vaddq_f32(vaddq_f32(bmag0, bmag1), vaddq_f32(bmag2, bmag3))); + + const f32 *pEnd2 = pVect1 + (qty - (qty16 << 4)); + while (pVect1 < pEnd2) { + dot_s += *pVect1 * *pVect2; + amag_s += *pVect1 * *pVect1; + bmag_s += *pVect2 * *pVect2; + pVect1++; pVect2++; + } + + return 1.0f - (dot_s / (sqrtf(amag_s) * sqrtf(bmag_s))); +} + static f32 l2_sqr_int8_neon(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { i8 *pVect1 = (i8 *)pVect1v; @@ -462,6 +519,11 @@ static double distance_l1_f32(const void *a, const void *b, const void *d) { static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { +#ifdef SQLITE_VEC_ENABLE_NEON + if ((*(const size_t *)qty_ptr) > 16) { + return cosine_float_neon(pVect1v, pVect2v, qty_ptr); + } +#endif f32 *pVect1 = (f32 *)pVect1v; f32 *pVect2 = (f32 *)pVect2v; size_t qty = *((size_t *)qty_ptr); @@ -478,8 +540,7 @@ static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v, } return 1 - (dot / (sqrt(aMag) * sqrt(bMag))); } -static f32 distance_cosine_int8(const void *pA, const void *pB, - const void *pD) { +static f32 cosine_int8(const void *pA, const void *pB, const void *pD) { i8 *a = (i8 *)pA; i8 *b = (i8 *)pB; size_t d = *((size_t *)pD); @@ -497,6 +558,125 @@ static f32 distance_cosine_int8(const void *pA, const void *pB, return 1 - (dot / (sqrt(aMag) * sqrt(bMag))); } +#ifdef SQLITE_VEC_ENABLE_NEON +static f32 cosine_int8_neon(const void *pA, const void *pB, const void *pD) { + const i8 *a = (const i8 *)pA; + const i8 *b = (const i8 *)pB; + size_t d = *((const size_t *)pD); + const i8 *aEnd = a + d; + + int32x4_t dot_acc1 = vdupq_n_s32(0); + int32x4_t dot_acc2 = vdupq_n_s32(0); + int32x4_t aMag_acc1 = vdupq_n_s32(0); + int32x4_t aMag_acc2 = vdupq_n_s32(0); + int32x4_t bMag_acc1 = vdupq_n_s32(0); + int32x4_t bMag_acc2 = vdupq_n_s32(0); + + while (a < aEnd - 31) { + int8x16_t va1 = vld1q_s8(a); + int8x16_t vb1 = vld1q_s8(b); + int16x8_t a1_lo = vmovl_s8(vget_low_s8(va1)); + int16x8_t a1_hi = vmovl_s8(vget_high_s8(va1)); + int16x8_t b1_lo = vmovl_s8(vget_low_s8(vb1)); + int16x8_t b1_hi = vmovl_s8(vget_high_s8(vb1)); + + dot_acc1 = vmlal_s16(dot_acc1, vget_low_s16(a1_lo), vget_low_s16(b1_lo)); + dot_acc1 = vmlal_s16(dot_acc1, vget_high_s16(a1_lo), vget_high_s16(b1_lo)); + dot_acc2 = vmlal_s16(dot_acc2, vget_low_s16(a1_hi), vget_low_s16(b1_hi)); + dot_acc2 = vmlal_s16(dot_acc2, vget_high_s16(a1_hi), vget_high_s16(b1_hi)); + + aMag_acc1 = vmlal_s16(aMag_acc1, vget_low_s16(a1_lo), vget_low_s16(a1_lo)); + aMag_acc1 = vmlal_s16(aMag_acc1, vget_high_s16(a1_lo), vget_high_s16(a1_lo)); + aMag_acc2 = vmlal_s16(aMag_acc2, vget_low_s16(a1_hi), vget_low_s16(a1_hi)); + aMag_acc2 = vmlal_s16(aMag_acc2, vget_high_s16(a1_hi), vget_high_s16(a1_hi)); + + bMag_acc1 = vmlal_s16(bMag_acc1, vget_low_s16(b1_lo), vget_low_s16(b1_lo)); + bMag_acc1 = vmlal_s16(bMag_acc1, vget_high_s16(b1_lo), vget_high_s16(b1_lo)); + bMag_acc2 = vmlal_s16(bMag_acc2, vget_low_s16(b1_hi), vget_low_s16(b1_hi)); + bMag_acc2 = vmlal_s16(bMag_acc2, vget_high_s16(b1_hi), vget_high_s16(b1_hi)); + + int8x16_t va2 = vld1q_s8(a + 16); + int8x16_t vb2 = vld1q_s8(b + 16); + int16x8_t a2_lo = vmovl_s8(vget_low_s8(va2)); + int16x8_t a2_hi = vmovl_s8(vget_high_s8(va2)); + int16x8_t b2_lo = vmovl_s8(vget_low_s8(vb2)); + int16x8_t b2_hi = vmovl_s8(vget_high_s8(vb2)); + + dot_acc1 = vmlal_s16(dot_acc1, vget_low_s16(a2_lo), vget_low_s16(b2_lo)); + dot_acc1 = vmlal_s16(dot_acc1, vget_high_s16(a2_lo), vget_high_s16(b2_lo)); + dot_acc2 = vmlal_s16(dot_acc2, vget_low_s16(a2_hi), vget_low_s16(b2_hi)); + dot_acc2 = vmlal_s16(dot_acc2, vget_high_s16(a2_hi), vget_high_s16(b2_hi)); + + aMag_acc1 = vmlal_s16(aMag_acc1, vget_low_s16(a2_lo), vget_low_s16(a2_lo)); + aMag_acc1 = vmlal_s16(aMag_acc1, vget_high_s16(a2_lo), vget_high_s16(a2_lo)); + aMag_acc2 = vmlal_s16(aMag_acc2, vget_low_s16(a2_hi), vget_low_s16(a2_hi)); + aMag_acc2 = vmlal_s16(aMag_acc2, vget_high_s16(a2_hi), vget_high_s16(a2_hi)); + + bMag_acc1 = vmlal_s16(bMag_acc1, vget_low_s16(b2_lo), vget_low_s16(b2_lo)); + bMag_acc1 = vmlal_s16(bMag_acc1, vget_high_s16(b2_lo), vget_high_s16(b2_lo)); + bMag_acc2 = vmlal_s16(bMag_acc2, vget_low_s16(b2_hi), vget_low_s16(b2_hi)); + bMag_acc2 = vmlal_s16(bMag_acc2, vget_high_s16(b2_hi), vget_high_s16(b2_hi)); + + a += 32; + b += 32; + } + + while (a < aEnd - 15) { + int8x16_t va = vld1q_s8(a); + int8x16_t vb = vld1q_s8(b); + int16x8_t a_lo = vmovl_s8(vget_low_s8(va)); + int16x8_t a_hi = vmovl_s8(vget_high_s8(va)); + int16x8_t b_lo = vmovl_s8(vget_low_s8(vb)); + int16x8_t b_hi = vmovl_s8(vget_high_s8(vb)); + + dot_acc1 = vmlal_s16(dot_acc1, vget_low_s16(a_lo), vget_low_s16(b_lo)); + dot_acc1 = vmlal_s16(dot_acc1, vget_high_s16(a_lo), vget_high_s16(b_lo)); + dot_acc1 = vmlal_s16(dot_acc1, vget_low_s16(a_hi), vget_low_s16(b_hi)); + dot_acc1 = vmlal_s16(dot_acc1, vget_high_s16(a_hi), vget_high_s16(b_hi)); + + aMag_acc1 = vmlal_s16(aMag_acc1, vget_low_s16(a_lo), vget_low_s16(a_lo)); + aMag_acc1 = vmlal_s16(aMag_acc1, vget_high_s16(a_lo), vget_high_s16(a_lo)); + aMag_acc1 = vmlal_s16(aMag_acc1, vget_low_s16(a_hi), vget_low_s16(a_hi)); + aMag_acc1 = vmlal_s16(aMag_acc1, vget_high_s16(a_hi), vget_high_s16(a_hi)); + + bMag_acc1 = vmlal_s16(bMag_acc1, vget_low_s16(b_lo), vget_low_s16(b_lo)); + bMag_acc1 = vmlal_s16(bMag_acc1, vget_high_s16(b_lo), vget_high_s16(b_lo)); + bMag_acc1 = vmlal_s16(bMag_acc1, vget_low_s16(b_hi), vget_low_s16(b_hi)); + bMag_acc1 = vmlal_s16(bMag_acc1, vget_high_s16(b_hi), vget_high_s16(b_hi)); + + a += 16; + b += 16; + } + + int32x4_t dot_sum = vaddq_s32(dot_acc1, dot_acc2); + int32x4_t aMag_sum = vaddq_s32(aMag_acc1, aMag_acc2); + int32x4_t bMag_sum = vaddq_s32(bMag_acc1, bMag_acc2); + + i32 dot = vaddvq_s32(dot_sum); + i32 aMag = vaddvq_s32(aMag_sum); + i32 bMag = vaddvq_s32(bMag_sum); + + while (a < aEnd) { + dot += (i32)*a * (i32)*b; + aMag += (i32)*a * (i32)*a; + bMag += (i32)*b * (i32)*b; + a++; + b++; + } + + return 1.0f - ((f32)dot / (sqrtf((f32)aMag) * sqrtf((f32)bMag))); +} +#endif + +static f32 distance_cosine_int8(const void *a, const void *b, const void *d) { +#ifdef SQLITE_VEC_ENABLE_NEON + if ((*(const size_t *)d) > 15) { + return cosine_int8_neon(a, b, d); + } +#endif + return cosine_int8(a, b, d); +} + // https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34 static u8 hamdist_table[256] = { 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, @@ -511,6 +691,59 @@ static u8 hamdist_table[256] = { 4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; +#ifdef SQLITE_VEC_ENABLE_NEON +static f32 distance_hamming_neon(const u8 *a, const u8 *b, size_t n_bytes) { + const u8 *pEnd = a + n_bytes; + + uint32x4_t acc1 = vdupq_n_u32(0); + uint32x4_t acc2 = vdupq_n_u32(0); + uint32x4_t acc3 = vdupq_n_u32(0); + uint32x4_t acc4 = vdupq_n_u32(0); + + while (a <= pEnd - 64) { + uint8x16_t v1 = vld1q_u8(a); + uint8x16_t v2 = vld1q_u8(b); + acc1 = vaddq_u32(acc1, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2))))); + + v1 = vld1q_u8(a + 16); + v2 = vld1q_u8(b + 16); + acc2 = vaddq_u32(acc2, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2))))); + + v1 = vld1q_u8(a + 32); + v2 = vld1q_u8(b + 32); + acc3 = vaddq_u32(acc3, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2))))); + + v1 = vld1q_u8(a + 48); + v2 = vld1q_u8(b + 48); + acc4 = vaddq_u32(acc4, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2))))); + + a += 64; + b += 64; + } + + while (a <= pEnd - 16) { + uint8x16_t v1 = vld1q_u8(a); + uint8x16_t v2 = vld1q_u8(b); + acc1 = vaddq_u32(acc1, vpaddlq_u16(vpaddlq_u8(vcntq_u8(veorq_u8(v1, v2))))); + a += 16; + b += 16; + } + + acc1 = vaddq_u32(acc1, acc2); + acc3 = vaddq_u32(acc3, acc4); + acc1 = vaddq_u32(acc1, acc3); + u32 sum = vaddvq_u32(acc1); + + while (a < pEnd) { + sum += hamdist_table[*a ^ *b]; + a++; + b++; + } + + return (f32)sum; +} +#endif + static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) { int same = 0; for (unsigned long i = 0; i < n; i++) { @@ -555,11 +788,18 @@ static f32 distance_hamming_u64(u64 *a, u64 *b, size_t n) { */ static f32 distance_hamming(const void *a, const void *b, const void *d) { size_t dimensions = *((size_t *)d); + size_t n_bytes = dimensions / CHAR_BIT; + +#ifdef SQLITE_VEC_ENABLE_NEON + if (dimensions >= 128) { + return distance_hamming_neon((const u8 *)a, (const u8 *)b, n_bytes); + } +#endif if ((dimensions % 64) == 0) { - return distance_hamming_u64((u64 *)a, (u64 *)b, dimensions / 8 / CHAR_BIT); + return distance_hamming_u64((u64 *)a, (u64 *)b, n_bytes / sizeof(u64)); } - return distance_hamming_u8((u8 *)a, (u8 *)b, dimensions / CHAR_BIT); + return distance_hamming_u8((u8 *)a, (u8 *)b, n_bytes); } #ifdef SQLITE_VEC_TEST @@ -1065,33 +1305,6 @@ int ensure_vector_match(sqlite3_value *aValue, sqlite3_value *bValue, void **a, int _cmp(const void *a, const void *b) { return (*(i64 *)a - *(i64 *)b); } -struct VecNpyFile { - char *path; - size_t pathLength; -}; -#define SQLITE_VEC_NPY_FILE_NAME "vec0-npy-file" - -#ifndef SQLITE_VEC_OMIT_FS -static void vec_npy_file(sqlite3_context *context, int argc, - sqlite3_value **argv) { - assert(argc == 1); - char *path = (char *)sqlite3_value_text(argv[0]); - size_t pathLength = sqlite3_value_bytes(argv[0]); - struct VecNpyFile *f; - - f = sqlite3_malloc(sizeof(*f)); - if (!f) { - sqlite3_result_error_nomem(context); - return; - } - memset(f, 0, sizeof(*f)); - - f->path = path; - f->pathLength = pathLength; - sqlite3_result_pointer(context, f, SQLITE_VEC_NPY_FILE_NAME, sqlite3_free); -} -#endif - #pragma region scalar functions static void vec_f32(sqlite3_context *context, int argc, sqlite3_value **argv) { assert(argc == 1); @@ -2281,12 +2494,53 @@ enum Vec0DistanceMetrics { VEC0_DISTANCE_METRIC_L1 = 3, }; +/** + * Compute distance between two full-precision vectors using the appropriate + * distance function for the given element type and metric. + * Shared utility used by ANN index implementations. + */ +static f32 vec0_distance_full( + const void *a, const void *b, size_t dimensions, + enum VectorElementType elementType, + enum Vec0DistanceMetrics metric) { + switch (elementType) { + case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: + switch (metric) { + case VEC0_DISTANCE_METRIC_L2: + return distance_l2_sqr_float(a, b, &dimensions); + case VEC0_DISTANCE_METRIC_COSINE: + return distance_cosine_float(a, b, &dimensions); + case VEC0_DISTANCE_METRIC_L1: + return (f32)distance_l1_f32(a, b, &dimensions); + } + break; + case SQLITE_VEC_ELEMENT_TYPE_INT8: + switch (metric) { + case VEC0_DISTANCE_METRIC_L2: + return distance_l2_sqr_int8(a, b, &dimensions); + case VEC0_DISTANCE_METRIC_COSINE: + return distance_cosine_int8(a, b, &dimensions); + case VEC0_DISTANCE_METRIC_L1: + return (f32)distance_l1_int8(a, b, &dimensions); + } + break; + case SQLITE_VEC_ELEMENT_TYPE_BIT: + return distance_hamming(a, b, &dimensions); + } + return 0.0f; +} + +enum Vec0IndexType { + VEC0_INDEX_TYPE_FLAT = 1, +}; + struct VectorColumnDefinition { char *name; int name_length; size_t dimensions; enum VectorElementType element_type; enum Vec0DistanceMetrics distance_metric; + enum Vec0IndexType index_type; }; struct Vec0PartitionColumnDefinition { @@ -2346,6 +2600,7 @@ int vec0_parse_vector_column(const char *source, int source_length, int nameLength; enum VectorElementType elementType; enum Vec0DistanceMetrics distanceMetric = VEC0_DISTANCE_METRIC_L2; + enum Vec0IndexType indexType = VEC0_INDEX_TYPE_FLAT; int dimensions; vec0_scanner_init(&scanner, source, source_length); @@ -2449,6 +2704,40 @@ int vec0_parse_vector_column(const char *source, int source_length, return SQLITE_ERROR; } } + else if (sqlite3_strnicmp(key, "indexed", keyLength) == 0) { + // expect "by" + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || + token.token_type != TOKEN_TYPE_IDENTIFIER || + sqlite3_strnicmp(token.start, "by", token.end - token.start) != 0) { + return SQLITE_ERROR; + } + // expect index type name + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || + token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_ERROR; + } + int indexNameLen = token.end - token.start; + if (sqlite3_strnicmp(token.start, "flat", indexNameLen) == 0) { + indexType = VEC0_INDEX_TYPE_FLAT; + // expect '(' + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || + token.token_type != TOKEN_TYPE_LPAREN) { + return SQLITE_ERROR; + } + // expect ')' + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || + token.token_type != TOKEN_TYPE_RPAREN) { + return SQLITE_ERROR; + } + } else { + // unknown index type + return SQLITE_ERROR; + } + } // unknown key else { return SQLITE_ERROR; @@ -2463,6 +2752,7 @@ int vec0_parse_vector_column(const char *source, int source_length, outColumn->distance_metric = distanceMetric; outColumn->element_type = elementType; outColumn->dimensions = dimensions; + outColumn->index_type = indexType; return SQLITE_OK; } @@ -2660,758 +2950,6 @@ static sqlite3_module vec_eachModule = { #pragma endregion -#pragma region vec_npy_each table function - -enum NpyTokenType { - NPY_TOKEN_TYPE_IDENTIFIER, - NPY_TOKEN_TYPE_NUMBER, - NPY_TOKEN_TYPE_LPAREN, - NPY_TOKEN_TYPE_RPAREN, - NPY_TOKEN_TYPE_LBRACE, - NPY_TOKEN_TYPE_RBRACE, - NPY_TOKEN_TYPE_COLON, - NPY_TOKEN_TYPE_COMMA, - NPY_TOKEN_TYPE_STRING, - NPY_TOKEN_TYPE_FALSE, -}; - -struct NpyToken { - enum NpyTokenType token_type; - unsigned char *start; - unsigned char *end; -}; - -int npy_token_next(unsigned char *start, unsigned char *end, - struct NpyToken *out) { - unsigned char *ptr = start; - while (ptr < end) { - unsigned char curr = *ptr; - if (is_whitespace(curr)) { - ptr++; - continue; - } else if (curr == '(') { - out->start = ptr++; - out->end = ptr; - out->token_type = NPY_TOKEN_TYPE_LPAREN; - return VEC0_TOKEN_RESULT_SOME; - } else if (curr == ')') { - out->start = ptr++; - out->end = ptr; - out->token_type = NPY_TOKEN_TYPE_RPAREN; - return VEC0_TOKEN_RESULT_SOME; - } else if (curr == '{') { - out->start = ptr++; - out->end = ptr; - out->token_type = NPY_TOKEN_TYPE_LBRACE; - return VEC0_TOKEN_RESULT_SOME; - } else if (curr == '}') { - out->start = ptr++; - out->end = ptr; - out->token_type = NPY_TOKEN_TYPE_RBRACE; - return VEC0_TOKEN_RESULT_SOME; - } else if (curr == ':') { - out->start = ptr++; - out->end = ptr; - out->token_type = NPY_TOKEN_TYPE_COLON; - return VEC0_TOKEN_RESULT_SOME; - } else if (curr == ',') { - out->start = ptr++; - out->end = ptr; - out->token_type = NPY_TOKEN_TYPE_COMMA; - return VEC0_TOKEN_RESULT_SOME; - } else if (curr == '\'') { - unsigned char *start = ptr; - ptr++; - while (ptr < end) { - if ((*ptr) == '\'') { - break; - } - ptr++; - } - if (ptr >= end || (*ptr) != '\'') { - return VEC0_TOKEN_RESULT_ERROR; - } - out->start = start; - out->end = ++ptr; - out->token_type = NPY_TOKEN_TYPE_STRING; - return VEC0_TOKEN_RESULT_SOME; - } else if (curr == 'F' && - strncmp((char *)ptr, "False", strlen("False")) == 0) { - out->start = ptr; - out->end = (ptr + (int)strlen("False")); - ptr = out->end; - out->token_type = NPY_TOKEN_TYPE_FALSE; - return VEC0_TOKEN_RESULT_SOME; - } else if (is_digit(curr)) { - unsigned char *start = ptr; - while (ptr < end && (is_digit(*ptr))) { - ptr++; - } - out->start = start; - out->end = ptr; - out->token_type = NPY_TOKEN_TYPE_NUMBER; - return VEC0_TOKEN_RESULT_SOME; - } else { - return VEC0_TOKEN_RESULT_ERROR; - } - } - return VEC0_TOKEN_RESULT_ERROR; -} - -struct NpyScanner { - unsigned char *start; - unsigned char *end; - unsigned char *ptr; -}; - -void npy_scanner_init(struct NpyScanner *scanner, const unsigned char *source, - int source_length) { - scanner->start = (unsigned char *)source; - scanner->end = (unsigned char *)source + source_length; - scanner->ptr = (unsigned char *)source; -} - -int npy_scanner_next(struct NpyScanner *scanner, struct NpyToken *out) { - int rc = npy_token_next(scanner->start, scanner->end, out); - if (rc == VEC0_TOKEN_RESULT_SOME) { - scanner->start = out->end; - } - return rc; -} - -#define NPY_PARSE_ERROR "Error parsing numpy array: " -int parse_npy_header(sqlite3_vtab *pVTab, const unsigned char *header, - size_t headerLength, - enum VectorElementType *out_element_type, - int *fortran_order, size_t *numElements, - size_t *numDimensions) { - - struct NpyScanner scanner; - struct NpyToken token; - int rc; - npy_scanner_init(&scanner, header, headerLength); - - if (npy_scanner_next(&scanner, &token) != VEC0_TOKEN_RESULT_SOME && - token.token_type != NPY_TOKEN_TYPE_LBRACE) { - vtab_set_error(pVTab, - NPY_PARSE_ERROR "numpy header did not start with '{'"); - return SQLITE_ERROR; - } - while (1) { - rc = npy_scanner_next(&scanner, &token); - if (rc != VEC0_TOKEN_RESULT_SOME) { - vtab_set_error(pVTab, NPY_PARSE_ERROR "expected key in numpy header"); - return SQLITE_ERROR; - } - - if (token.token_type == NPY_TOKEN_TYPE_RBRACE) { - break; - } - if (token.token_type != NPY_TOKEN_TYPE_STRING) { - vtab_set_error(pVTab, NPY_PARSE_ERROR - "expected a string as key in numpy header"); - return SQLITE_ERROR; - } - unsigned char *key = token.start; - - rc = npy_scanner_next(&scanner, &token); - if ((rc != VEC0_TOKEN_RESULT_SOME) || - (token.token_type != NPY_TOKEN_TYPE_COLON)) { - vtab_set_error(pVTab, NPY_PARSE_ERROR - "expected a ':' after key in numpy header"); - return SQLITE_ERROR; - } - - if (strncmp((char *)key, "'descr'", strlen("'descr'")) == 0) { - rc = npy_scanner_next(&scanner, &token); - if ((rc != VEC0_TOKEN_RESULT_SOME) || - (token.token_type != NPY_TOKEN_TYPE_STRING)) { - vtab_set_error(pVTab, NPY_PARSE_ERROR - "expected a string value after 'descr' key"); - return SQLITE_ERROR; - } - if (strncmp((char *)token.start, "'maxChunks = 1024; - pCur->chunksBufferSize = - (vector_byte_size(element_type, numDimensions)) * pCur->maxChunks; - pCur->chunksBuffer = sqlite3_malloc(pCur->chunksBufferSize); - if (pCur->chunksBufferSize && !pCur->chunksBuffer) { - return SQLITE_NOMEM; - } - - pCur->currentChunkSize = - fread(pCur->chunksBuffer, vector_byte_size(element_type, numDimensions), - pCur->maxChunks, file); - - pCur->currentChunkIndex = 0; - pCur->elementType = element_type; - pCur->nElements = numElements; - pCur->nDimensions = numDimensions; - pCur->input_type = VEC_NPY_EACH_INPUT_FILE; - - pCur->eof = pCur->currentChunkSize == 0; - pCur->file = file; - return SQLITE_OK; -} -#endif - -int parse_npy_buffer(sqlite3_vtab *pVTab, const unsigned char *buffer, - int bufferLength, void **data, size_t *numElements, - size_t *numDimensions, - enum VectorElementType *element_type) { - - if (bufferLength < 10) { - // IMP: V03312_20150 - vtab_set_error(pVTab, "numpy array too short"); - return SQLITE_ERROR; - } - if (memcmp(NPY_MAGIC, buffer, sizeof(NPY_MAGIC)) != 0) { - // V11954_28792 - vtab_set_error(pVTab, "numpy array does not contain the 'magic' header"); - return SQLITE_ERROR; - } - - u8 major = buffer[6]; - u8 minor = buffer[7]; - uint16_t headerLength = 0; - memcpy(&headerLength, &buffer[8], sizeof(uint16_t)); - - i32 totalHeaderLength = sizeof(NPY_MAGIC) + sizeof(major) + sizeof(minor) + - sizeof(headerLength) + headerLength; - i32 dataSize = bufferLength - totalHeaderLength; - - if (dataSize < 0) { - vtab_set_error(pVTab, "numpy array header length is invalid"); - return SQLITE_ERROR; - } - - const unsigned char *header = &buffer[10]; - int fortran_order; - - int rc = parse_npy_header(pVTab, header, headerLength, element_type, - &fortran_order, numElements, numDimensions); - if (rc != SQLITE_OK) { - return rc; - } - - i32 expectedDataSize = - (*numElements * vector_byte_size(*element_type, *numDimensions)); - if (expectedDataSize != dataSize) { - vtab_set_error(pVTab, - "numpy array error: Expected a data size of %d, found %d", - expectedDataSize, dataSize); - return SQLITE_ERROR; - } - - *data = (void *)&buffer[totalHeaderLength]; - return SQLITE_OK; -} - -static int vec_npy_eachConnect(sqlite3 *db, void *pAux, int argc, - const char *const *argv, sqlite3_vtab **ppVtab, - char **pzErr) { - UNUSED_PARAMETER(pAux); - UNUSED_PARAMETER(argc); - UNUSED_PARAMETER(argv); - UNUSED_PARAMETER(pzErr); - vec_npy_each_vtab *pNew; - int rc; - - rc = sqlite3_declare_vtab(db, "CREATE TABLE x(vector, input hidden)"); -#define VEC_NPY_EACH_COLUMN_VECTOR 0 -#define VEC_NPY_EACH_COLUMN_INPUT 1 - if (rc == SQLITE_OK) { - pNew = sqlite3_malloc(sizeof(*pNew)); - *ppVtab = (sqlite3_vtab *)pNew; - if (pNew == 0) - return SQLITE_NOMEM; - memset(pNew, 0, sizeof(*pNew)); - } - return rc; -} - -static int vec_npy_eachDisconnect(sqlite3_vtab *pVtab) { - vec_npy_each_vtab *p = (vec_npy_each_vtab *)pVtab; - sqlite3_free(p); - return SQLITE_OK; -} - -static int vec_npy_eachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { - UNUSED_PARAMETER(p); - vec_npy_each_cursor *pCur; - pCur = sqlite3_malloc(sizeof(*pCur)); - if (pCur == 0) - return SQLITE_NOMEM; - memset(pCur, 0, sizeof(*pCur)); - *ppCursor = &pCur->base; - return SQLITE_OK; -} - -static int vec_npy_eachClose(sqlite3_vtab_cursor *cur) { - vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; -#ifndef SQLITE_VEC_OMIT_FS - if (pCur->file) { - fclose(pCur->file); - pCur->file = NULL; - } -#endif - if (pCur->chunksBuffer) { - sqlite3_free(pCur->chunksBuffer); - pCur->chunksBuffer = NULL; - } - if (pCur->vector) { - pCur->vector = NULL; - } - sqlite3_free(pCur); - return SQLITE_OK; -} - -static int vec_npy_eachBestIndex(sqlite3_vtab *pVTab, - sqlite3_index_info *pIdxInfo) { - int hasInput; - for (int i = 0; i < pIdxInfo->nConstraint; i++) { - const struct sqlite3_index_constraint *pCons = &pIdxInfo->aConstraint[i]; - // printf("i=%d iColumn=%d, op=%d, usable=%d\n", i, pCons->iColumn, - // pCons->op, pCons->usable); - switch (pCons->iColumn) { - case VEC_NPY_EACH_COLUMN_INPUT: { - if (pCons->op == SQLITE_INDEX_CONSTRAINT_EQ && pCons->usable) { - hasInput = 1; - pIdxInfo->aConstraintUsage[i].argvIndex = 1; - pIdxInfo->aConstraintUsage[i].omit = 1; - } - break; - } - } - } - if (!hasInput) { - pVTab->zErrMsg = sqlite3_mprintf("input argument is required"); - return SQLITE_ERROR; - } - - pIdxInfo->estimatedCost = (double)100000; - pIdxInfo->estimatedRows = 100000; - - return SQLITE_OK; -} - -static int vec_npy_eachFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, - const char *idxStr, int argc, - sqlite3_value **argv) { - UNUSED_PARAMETER(idxNum); - UNUSED_PARAMETER(idxStr); - assert(argc == 1); - int rc; - - vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)pVtabCursor; - -#ifndef SQLITE_VEC_OMIT_FS - if (pCur->file) { - fclose(pCur->file); - pCur->file = NULL; - } -#endif - if (pCur->chunksBuffer) { - sqlite3_free(pCur->chunksBuffer); - pCur->chunksBuffer = NULL; - } - if (pCur->vector) { - pCur->vector = NULL; - } - -#ifndef SQLITE_VEC_OMIT_FS - struct VecNpyFile *f = NULL; - if ((f = sqlite3_value_pointer(argv[0], SQLITE_VEC_NPY_FILE_NAME))) { - FILE *file = fopen(f->path, "r"); - if (!file) { - vtab_set_error(pVtabCursor->pVtab, "Could not open numpy file"); - return SQLITE_ERROR; - } - - rc = parse_npy_file(pVtabCursor->pVtab, file, pCur); - if (rc != SQLITE_OK) { -#ifndef SQLITE_VEC_OMIT_FS - fclose(file); -#endif - return rc; - } - - } else -#endif - { - - const unsigned char *input = sqlite3_value_blob(argv[0]); - int inputLength = sqlite3_value_bytes(argv[0]); - void *data; - size_t numElements; - size_t numDimensions; - enum VectorElementType element_type; - - rc = parse_npy_buffer(pVtabCursor->pVtab, input, inputLength, &data, - &numElements, &numDimensions, &element_type); - if (rc != SQLITE_OK) { - return rc; - } - - pCur->vector = data; - pCur->elementType = element_type; - pCur->nElements = numElements; - pCur->nDimensions = numDimensions; - pCur->input_type = VEC_NPY_EACH_INPUT_BUFFER; - } - - pCur->iRowid = 0; - return SQLITE_OK; -} - -static int vec_npy_eachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { - vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; - *pRowid = pCur->iRowid; - return SQLITE_OK; -} - -static int vec_npy_eachEof(sqlite3_vtab_cursor *cur) { - vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; - if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) { - return (!pCur->nElements) || (size_t)pCur->iRowid >= pCur->nElements; - } - return pCur->eof; -} - -static int vec_npy_eachNext(sqlite3_vtab_cursor *cur) { - vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; - pCur->iRowid++; - if (pCur->input_type == VEC_NPY_EACH_INPUT_BUFFER) { - return SQLITE_OK; - } - -#ifndef SQLITE_VEC_OMIT_FS - // else: input is a file - pCur->currentChunkIndex++; - if (pCur->currentChunkIndex >= pCur->currentChunkSize) { - pCur->currentChunkSize = - fread(pCur->chunksBuffer, - vector_byte_size(pCur->elementType, pCur->nDimensions), - pCur->maxChunks, pCur->file); - if (!pCur->currentChunkSize) { - pCur->eof = 1; - } - pCur->currentChunkIndex = 0; - } -#endif - return SQLITE_OK; -} - -static int vec_npy_eachColumnBuffer(vec_npy_each_cursor *pCur, - sqlite3_context *context, int i) { - switch (i) { - case VEC_NPY_EACH_COLUMN_VECTOR: { - sqlite3_result_subtype(context, pCur->elementType); - switch (pCur->elementType) { - case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { - sqlite3_result_blob( - context, - &((unsigned char *) - pCur->vector)[pCur->iRowid * pCur->nDimensions * sizeof(f32)], - pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT); - - break; - } - case SQLITE_VEC_ELEMENT_TYPE_INT8: - case SQLITE_VEC_ELEMENT_TYPE_BIT: { - // https://github.com/asg017/sqlite-vec/issues/42 - sqlite3_result_error(context, - "vec_npy_each only supports float32 vectors", -1); - break; - } - } - - break; - } - } - return SQLITE_OK; -} -static int vec_npy_eachColumnFile(vec_npy_each_cursor *pCur, - sqlite3_context *context, int i) { - switch (i) { - case VEC_NPY_EACH_COLUMN_VECTOR: { - switch (pCur->elementType) { - case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { - sqlite3_result_blob( - context, - &((unsigned char *) - pCur->chunksBuffer)[pCur->currentChunkIndex * - pCur->nDimensions * sizeof(f32)], - pCur->nDimensions * sizeof(f32), SQLITE_TRANSIENT); - break; - } - case SQLITE_VEC_ELEMENT_TYPE_INT8: - case SQLITE_VEC_ELEMENT_TYPE_BIT: { - // https://github.com/asg017/sqlite-vec/issues/42 - sqlite3_result_error(context, - "vec_npy_each only supports float32 vectors", -1); - break; - } - } - break; - } - } - return SQLITE_OK; -} -static int vec_npy_eachColumn(sqlite3_vtab_cursor *cur, - sqlite3_context *context, int i) { - vec_npy_each_cursor *pCur = (vec_npy_each_cursor *)cur; - switch (pCur->input_type) { - case VEC_NPY_EACH_INPUT_BUFFER: - return vec_npy_eachColumnBuffer(pCur, context, i); - case VEC_NPY_EACH_INPUT_FILE: - return vec_npy_eachColumnFile(pCur, context, i); - } - return SQLITE_ERROR; -} - -static sqlite3_module vec_npy_eachModule = { - /* iVersion */ 0, - /* xCreate */ 0, - /* xConnect */ vec_npy_eachConnect, - /* xBestIndex */ vec_npy_eachBestIndex, - /* xDisconnect */ vec_npy_eachDisconnect, - /* xDestroy */ 0, - /* xOpen */ vec_npy_eachOpen, - /* xClose */ vec_npy_eachClose, - /* xFilter */ vec_npy_eachFilter, - /* xNext */ vec_npy_eachNext, - /* xEof */ vec_npy_eachEof, - /* xColumn */ vec_npy_eachColumn, - /* xRowid */ vec_npy_eachRowid, - /* xUpdate */ 0, - /* xBegin */ 0, - /* xSync */ 0, - /* xCommit */ 0, - /* xRollback */ 0, - /* xFindMethod */ 0, - /* xRename */ 0, - /* xSavepoint */ 0, - /* xRelease */ 0, - /* xRollbackTo */ 0, - /* xShadowName */ 0, -#if SQLITE_VERSION_NUMBER >= 3044000 - /* xIntegrity */ 0, -#endif -}; - -#pragma endregion #pragma region vec0 virtual table @@ -5959,6 +5497,65 @@ int min_idx(const f32 *distances, i32 n, u8 *candidates, i32 *out, i32 k, assert(k > 0); assert(k <= n); +#ifdef SQLITE_VEC_EXPERIMENTAL_MIN_IDX + // Max-heap variant: O(n log k) single-pass. + // out[0..heap_size-1] stores indices; heap ordered by distances descending + // so out[0] is always the index of the LARGEST distance in the top-k. + (void)bTaken; + int heap_size = 0; + + #define HEAP_SIFT_UP(pos) do { \ + int _c = (pos); \ + while (_c > 0) { \ + int _p = (_c - 1) / 2; \ + if (distances[out[_p]] < distances[out[_c]]) { \ + i32 _tmp = out[_p]; out[_p] = out[_c]; out[_c] = _tmp; \ + _c = _p; \ + } else break; \ + } \ + } while(0) + + #define HEAP_SIFT_DOWN(pos, sz) do { \ + int _p = (pos); \ + for (;;) { \ + int _l = 2*_p + 1, _r = 2*_p + 2, _largest = _p; \ + if (_l < (sz) && distances[out[_l]] > distances[out[_largest]]) \ + _largest = _l; \ + if (_r < (sz) && distances[out[_r]] > distances[out[_largest]]) \ + _largest = _r; \ + if (_largest == _p) break; \ + i32 _tmp = out[_p]; out[_p] = out[_largest]; out[_largest] = _tmp; \ + _p = _largest; \ + } \ + } while(0) + + for (int i = 0; i < n; i++) { + if (!bitmap_get(candidates, i)) + continue; + if (heap_size < k) { + out[heap_size] = i; + heap_size++; + HEAP_SIFT_UP(heap_size - 1); + } else if (distances[i] < distances[out[0]]) { + out[0] = i; + HEAP_SIFT_DOWN(0, heap_size); + } + } + + // Heapsort to produce ascending order. + for (int i = heap_size - 1; i > 0; i--) { + i32 tmp = out[0]; out[0] = out[i]; out[i] = tmp; + HEAP_SIFT_DOWN(0, i); + } + + #undef HEAP_SIFT_UP + #undef HEAP_SIFT_DOWN + + *k_used = heap_size; + return SQLITE_OK; + +#else + // Original: O(n*k) repeated linear scan with bitmap. bitmap_clear(bTaken, n); for (int ik = 0; ik < k; ik++) { @@ -5984,6 +5581,7 @@ int min_idx(const f32 *distances, i32 n, u8 *candidates, i32 *out, i32 k, } *k_used = k; return SQLITE_OK; +#endif } int vec0_get_metadata_text_long_value( @@ -9388,652 +8986,6 @@ static sqlite3_module vec0Module = { }; #pragma endregion -static char *POINTER_NAME_STATIC_BLOB_DEF = "vec0-static_blob_def"; -struct static_blob_definition { - void *p; - size_t dimensions; - size_t nvectors; - enum VectorElementType element_type; -}; -static void vec_static_blob_from_raw(sqlite3_context *context, int argc, - sqlite3_value **argv) { - - assert(argc == 4); - struct static_blob_definition *p; - p = sqlite3_malloc(sizeof(*p)); - if (!p) { - sqlite3_result_error_nomem(context); - return; - } - memset(p, 0, sizeof(*p)); - p->p = (void *)sqlite3_value_int64(argv[0]); - p->element_type = SQLITE_VEC_ELEMENT_TYPE_FLOAT32; - p->dimensions = sqlite3_value_int64(argv[2]); - p->nvectors = sqlite3_value_int64(argv[3]); - sqlite3_result_pointer(context, p, POINTER_NAME_STATIC_BLOB_DEF, - sqlite3_free); -} -#pragma region vec_static_blobs() table function - -#define MAX_STATIC_BLOBS 16 - -typedef struct static_blob static_blob; -struct static_blob { - char *name; - void *p; - size_t dimensions; - size_t nvectors; - enum VectorElementType element_type; -}; - -typedef struct vec_static_blob_data vec_static_blob_data; -struct vec_static_blob_data { - static_blob static_blobs[MAX_STATIC_BLOBS]; -}; - -typedef struct vec_static_blobs_vtab vec_static_blobs_vtab; -struct vec_static_blobs_vtab { - sqlite3_vtab base; - vec_static_blob_data *data; -}; - -typedef struct vec_static_blobs_cursor vec_static_blobs_cursor; -struct vec_static_blobs_cursor { - sqlite3_vtab_cursor base; - sqlite3_int64 iRowid; -}; - -static int vec_static_blobsConnect(sqlite3 *db, void *pAux, int argc, - const char *const *argv, - sqlite3_vtab **ppVtab, char **pzErr) { - UNUSED_PARAMETER(argc); - UNUSED_PARAMETER(argv); - UNUSED_PARAMETER(pzErr); - - vec_static_blobs_vtab *pNew; -#define VEC_STATIC_BLOBS_NAME 0 -#define VEC_STATIC_BLOBS_DATA 1 -#define VEC_STATIC_BLOBS_DIMENSIONS 2 -#define VEC_STATIC_BLOBS_COUNT 3 - int rc = sqlite3_declare_vtab( - db, "CREATE TABLE x(name, data, dimensions hidden, count hidden)"); - if (rc == SQLITE_OK) { - pNew = sqlite3_malloc(sizeof(*pNew)); - *ppVtab = (sqlite3_vtab *)pNew; - if (pNew == 0) - return SQLITE_NOMEM; - memset(pNew, 0, sizeof(*pNew)); - pNew->data = pAux; - } - return rc; -} - -static int vec_static_blobsDisconnect(sqlite3_vtab *pVtab) { - vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pVtab; - sqlite3_free(p); - return SQLITE_OK; -} - -static int vec_static_blobsUpdate(sqlite3_vtab *pVTab, int argc, - sqlite3_value **argv, sqlite_int64 *pRowid) { - UNUSED_PARAMETER(pRowid); - vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pVTab; - // DELETE operation - if (argc == 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) { - return SQLITE_ERROR; - } - // INSERT operation - else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) { - const char *key = - (const char *)sqlite3_value_text(argv[2 + VEC_STATIC_BLOBS_NAME]); - int idx = -1; - for (int i = 0; i < MAX_STATIC_BLOBS; i++) { - if (!p->data->static_blobs[i].name) { - p->data->static_blobs[i].name = sqlite3_mprintf("%s", key); - idx = i; - break; - } - } - if (idx < 0) - abort(); - struct static_blob_definition *def = sqlite3_value_pointer( - argv[2 + VEC_STATIC_BLOBS_DATA], POINTER_NAME_STATIC_BLOB_DEF); - p->data->static_blobs[idx].p = def->p; - p->data->static_blobs[idx].dimensions = def->dimensions; - p->data->static_blobs[idx].nvectors = def->nvectors; - p->data->static_blobs[idx].element_type = def->element_type; - - return SQLITE_OK; - } - // UPDATE operation - else if (argc > 1 && sqlite3_value_type(argv[0]) != SQLITE_NULL) { - return SQLITE_ERROR; - } - return SQLITE_ERROR; -} - -static int vec_static_blobsOpen(sqlite3_vtab *p, - sqlite3_vtab_cursor **ppCursor) { - UNUSED_PARAMETER(p); - vec_static_blobs_cursor *pCur; - pCur = sqlite3_malloc(sizeof(*pCur)); - if (pCur == 0) - return SQLITE_NOMEM; - memset(pCur, 0, sizeof(*pCur)); - *ppCursor = &pCur->base; - return SQLITE_OK; -} - -static int vec_static_blobsClose(sqlite3_vtab_cursor *cur) { - vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur; - sqlite3_free(pCur); - return SQLITE_OK; -} - -static int vec_static_blobsBestIndex(sqlite3_vtab *pVTab, - sqlite3_index_info *pIdxInfo) { - UNUSED_PARAMETER(pVTab); - pIdxInfo->idxNum = 1; - pIdxInfo->estimatedCost = (double)10; - pIdxInfo->estimatedRows = 10; - return SQLITE_OK; -} - -static int vec_static_blobsNext(sqlite3_vtab_cursor *cur); -static int vec_static_blobsFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, - const char *idxStr, int argc, - sqlite3_value **argv) { - UNUSED_PARAMETER(idxNum); - UNUSED_PARAMETER(idxStr); - UNUSED_PARAMETER(argc); - UNUSED_PARAMETER(argv); - vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)pVtabCursor; - pCur->iRowid = -1; - vec_static_blobsNext(pVtabCursor); - return SQLITE_OK; -} - -static int vec_static_blobsRowid(sqlite3_vtab_cursor *cur, - sqlite_int64 *pRowid) { - vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur; - *pRowid = pCur->iRowid; - return SQLITE_OK; -} - -static int vec_static_blobsNext(sqlite3_vtab_cursor *cur) { - vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur; - vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)pCur->base.pVtab; - pCur->iRowid++; - while (pCur->iRowid < MAX_STATIC_BLOBS) { - if (p->data->static_blobs[pCur->iRowid].name) { - return SQLITE_OK; - } - pCur->iRowid++; - } - return SQLITE_OK; -} - -static int vec_static_blobsEof(sqlite3_vtab_cursor *cur) { - vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur; - return pCur->iRowid >= MAX_STATIC_BLOBS; -} - -static int vec_static_blobsColumn(sqlite3_vtab_cursor *cur, - sqlite3_context *context, int i) { - vec_static_blobs_cursor *pCur = (vec_static_blobs_cursor *)cur; - vec_static_blobs_vtab *p = (vec_static_blobs_vtab *)cur->pVtab; - switch (i) { - case VEC_STATIC_BLOBS_NAME: - sqlite3_result_text(context, p->data->static_blobs[pCur->iRowid].name, -1, - SQLITE_TRANSIENT); - break; - case VEC_STATIC_BLOBS_DATA: - sqlite3_result_null(context); - break; - case VEC_STATIC_BLOBS_DIMENSIONS: - sqlite3_result_int64(context, - p->data->static_blobs[pCur->iRowid].dimensions); - break; - case VEC_STATIC_BLOBS_COUNT: - sqlite3_result_int64(context, p->data->static_blobs[pCur->iRowid].nvectors); - break; - } - return SQLITE_OK; -} - -static sqlite3_module vec_static_blobsModule = { - /* iVersion */ 3, - /* xCreate */ 0, - /* xConnect */ vec_static_blobsConnect, - /* xBestIndex */ vec_static_blobsBestIndex, - /* xDisconnect */ vec_static_blobsDisconnect, - /* xDestroy */ 0, - /* xOpen */ vec_static_blobsOpen, - /* xClose */ vec_static_blobsClose, - /* xFilter */ vec_static_blobsFilter, - /* xNext */ vec_static_blobsNext, - /* xEof */ vec_static_blobsEof, - /* xColumn */ vec_static_blobsColumn, - /* xRowid */ vec_static_blobsRowid, - /* xUpdate */ vec_static_blobsUpdate, - /* xBegin */ 0, - /* xSync */ 0, - /* xCommit */ 0, - /* xRollback */ 0, - /* xFindMethod */ 0, - /* xRename */ 0, - /* xSavepoint */ 0, - /* xRelease */ 0, - /* xRollbackTo */ 0, - /* xShadowName */ 0, -#if SQLITE_VERSION_NUMBER >= 3044000 - /* xIntegrity */ 0 -#endif -}; -#pragma endregion - -#pragma region vec_static_blob_entries() table function - -typedef struct vec_static_blob_entries_vtab vec_static_blob_entries_vtab; -struct vec_static_blob_entries_vtab { - sqlite3_vtab base; - static_blob *blob; -}; -typedef enum { - VEC_SBE__QUERYPLAN_FULLSCAN = 1, - VEC_SBE__QUERYPLAN_KNN = 2 -} vec_sbe_query_plan; - -struct sbe_query_knn_data { - i64 k; - i64 k_used; - // Array of rowids of size k. Must be freed with sqlite3_free(). - i32 *rowids; - // Array of distances of size k. Must be freed with sqlite3_free(). - f32 *distances; - i64 current_idx; -}; -void sbe_query_knn_data_clear(struct sbe_query_knn_data *knn_data) { - if (!knn_data) - return; - - if (knn_data->rowids) { - sqlite3_free(knn_data->rowids); - knn_data->rowids = NULL; - } - if (knn_data->distances) { - sqlite3_free(knn_data->distances); - knn_data->distances = NULL; - } -} - -typedef struct vec_static_blob_entries_cursor vec_static_blob_entries_cursor; -struct vec_static_blob_entries_cursor { - sqlite3_vtab_cursor base; - sqlite3_int64 iRowid; - vec_sbe_query_plan query_plan; - struct sbe_query_knn_data *knn_data; -}; - -static int vec_static_blob_entriesConnect(sqlite3 *db, void *pAux, int argc, - const char *const *argv, - sqlite3_vtab **ppVtab, char **pzErr) { - UNUSED_PARAMETER(argc); - UNUSED_PARAMETER(argv); - UNUSED_PARAMETER(pzErr); - vec_static_blob_data *blob_data = pAux; - int idx = -1; - for (int i = 0; i < MAX_STATIC_BLOBS; i++) { - if (!blob_data->static_blobs[i].name) - continue; - if (strncmp(blob_data->static_blobs[i].name, argv[3], - strlen(blob_data->static_blobs[i].name)) == 0) { - idx = i; - break; - } - } - if (idx < 0) - abort(); - vec_static_blob_entries_vtab *pNew; -#define VEC_STATIC_BLOB_ENTRIES_VECTOR 0 -#define VEC_STATIC_BLOB_ENTRIES_DISTANCE 1 -#define VEC_STATIC_BLOB_ENTRIES_K 2 - int rc = sqlite3_declare_vtab( - db, "CREATE TABLE x(vector, distance hidden, k hidden)"); - if (rc == SQLITE_OK) { - pNew = sqlite3_malloc(sizeof(*pNew)); - *ppVtab = (sqlite3_vtab *)pNew; - if (pNew == 0) - return SQLITE_NOMEM; - memset(pNew, 0, sizeof(*pNew)); - pNew->blob = &blob_data->static_blobs[idx]; - } - return rc; -} - -static int vec_static_blob_entriesCreate(sqlite3 *db, void *pAux, int argc, - const char *const *argv, - sqlite3_vtab **ppVtab, char **pzErr) { - return vec_static_blob_entriesConnect(db, pAux, argc, argv, ppVtab, pzErr); -} - -static int vec_static_blob_entriesDisconnect(sqlite3_vtab *pVtab) { - vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVtab; - sqlite3_free(p); - return SQLITE_OK; -} - -static int vec_static_blob_entriesOpen(sqlite3_vtab *p, - sqlite3_vtab_cursor **ppCursor) { - UNUSED_PARAMETER(p); - vec_static_blob_entries_cursor *pCur; - pCur = sqlite3_malloc(sizeof(*pCur)); - if (pCur == 0) - return SQLITE_NOMEM; - memset(pCur, 0, sizeof(*pCur)); - *ppCursor = &pCur->base; - return SQLITE_OK; -} - -static int vec_static_blob_entriesClose(sqlite3_vtab_cursor *cur) { - vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; - sqlite3_free(pCur->knn_data); - sqlite3_free(pCur); - return SQLITE_OK; -} - -static int vec_static_blob_entriesBestIndex(sqlite3_vtab *pVTab, - sqlite3_index_info *pIdxInfo) { - vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)pVTab; - int iMatchTerm = -1; - int iLimitTerm = -1; - // int iRowidTerm = -1; // https://github.com/asg017/sqlite-vec/issues/47 - int iKTerm = -1; - - for (int i = 0; i < pIdxInfo->nConstraint; i++) { - if (!pIdxInfo->aConstraint[i].usable) - continue; - - int iColumn = pIdxInfo->aConstraint[i].iColumn; - int op = pIdxInfo->aConstraint[i].op; - if (op == SQLITE_INDEX_CONSTRAINT_MATCH && - iColumn == VEC_STATIC_BLOB_ENTRIES_VECTOR) { - if (iMatchTerm > -1) { - // https://github.com/asg017/sqlite-vec/issues/51 - return SQLITE_ERROR; - } - iMatchTerm = i; - } - if (op == SQLITE_INDEX_CONSTRAINT_LIMIT) { - iLimitTerm = i; - } - if (op == SQLITE_INDEX_CONSTRAINT_EQ && - iColumn == VEC_STATIC_BLOB_ENTRIES_K) { - iKTerm = i; - } - } - if (iMatchTerm >= 0) { - if (iLimitTerm < 0 && iKTerm < 0) { - // https://github.com/asg017/sqlite-vec/issues/51 - return SQLITE_ERROR; - } - if (iLimitTerm >= 0 && iKTerm >= 0) { - return SQLITE_ERROR; // limit or k, not both - } - if (pIdxInfo->nOrderBy < 1) { - vtab_set_error(pVTab, "ORDER BY distance required"); - return SQLITE_CONSTRAINT; - } - if (pIdxInfo->nOrderBy > 1) { - // https://github.com/asg017/sqlite-vec/issues/51 - vtab_set_error(pVTab, "more than 1 ORDER BY clause provided"); - return SQLITE_CONSTRAINT; - } - if (pIdxInfo->aOrderBy[0].iColumn != VEC_STATIC_BLOB_ENTRIES_DISTANCE) { - vtab_set_error(pVTab, "ORDER BY must be on the distance column"); - return SQLITE_CONSTRAINT; - } - if (pIdxInfo->aOrderBy[0].desc) { - vtab_set_error(pVTab, - "Only ascending in ORDER BY distance clause is supported, " - "DESC is not supported yet."); - return SQLITE_CONSTRAINT; - } - - pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_KNN; - pIdxInfo->estimatedCost = (double)10; - pIdxInfo->estimatedRows = 10; - - pIdxInfo->orderByConsumed = 1; - pIdxInfo->aConstraintUsage[iMatchTerm].argvIndex = 1; - pIdxInfo->aConstraintUsage[iMatchTerm].omit = 1; - if (iLimitTerm >= 0) { - pIdxInfo->aConstraintUsage[iLimitTerm].argvIndex = 2; - pIdxInfo->aConstraintUsage[iLimitTerm].omit = 1; - } else { - pIdxInfo->aConstraintUsage[iKTerm].argvIndex = 2; - pIdxInfo->aConstraintUsage[iKTerm].omit = 1; - } - - } else { - pIdxInfo->idxNum = VEC_SBE__QUERYPLAN_FULLSCAN; - pIdxInfo->estimatedCost = (double)p->blob->nvectors; - pIdxInfo->estimatedRows = p->blob->nvectors; - } - return SQLITE_OK; -} - -static int vec_static_blob_entriesFilter(sqlite3_vtab_cursor *pVtabCursor, - int idxNum, const char *idxStr, - int argc, sqlite3_value **argv) { - UNUSED_PARAMETER(idxStr); - assert(argc >= 0 && argc <= 3); - vec_static_blob_entries_cursor *pCur = - (vec_static_blob_entries_cursor *)pVtabCursor; - vec_static_blob_entries_vtab *p = - (vec_static_blob_entries_vtab *)pCur->base.pVtab; - - if (idxNum == VEC_SBE__QUERYPLAN_KNN) { - assert(argc == 2); - pCur->query_plan = VEC_SBE__QUERYPLAN_KNN; - struct sbe_query_knn_data *knn_data; - knn_data = sqlite3_malloc(sizeof(*knn_data)); - if (!knn_data) { - return SQLITE_NOMEM; - } - memset(knn_data, 0, sizeof(*knn_data)); - - void *queryVector; - size_t dimensions; - enum VectorElementType elementType; - vector_cleanup cleanup; - char *err; - int rc = vector_from_value(argv[0], &queryVector, &dimensions, &elementType, - &cleanup, &err); - if (rc != SQLITE_OK) { - return SQLITE_ERROR; - } - if (elementType != p->blob->element_type) { - return SQLITE_ERROR; - } - if (dimensions != p->blob->dimensions) { - return SQLITE_ERROR; - } - - i64 k = min(sqlite3_value_int64(argv[1]), (i64)p->blob->nvectors); - if (k < 0) { - // HANDLE https://github.com/asg017/sqlite-vec/issues/55 - return SQLITE_ERROR; - } - if (k == 0) { - knn_data->k = 0; - pCur->knn_data = knn_data; - return SQLITE_OK; - } - - size_t bsize = (p->blob->nvectors + 7) & ~7; - - i32 *topk_rowids = sqlite3_malloc(k * sizeof(i32)); - if (!topk_rowids) { - // HANDLE https://github.com/asg017/sqlite-vec/issues/55 - return SQLITE_ERROR; - } - f32 *distances = sqlite3_malloc(bsize * sizeof(f32)); - if (!distances) { - // HANDLE https://github.com/asg017/sqlite-vec/issues/55 - return SQLITE_ERROR; - } - - for (size_t i = 0; i < p->blob->nvectors; i++) { - // https://github.com/asg017/sqlite-vec/issues/52 - float *v = ((float *)p->blob->p) + (i * p->blob->dimensions); - distances[i] = - distance_l2_sqr_float(v, (float *)queryVector, &p->blob->dimensions); - } - u8 *candidates = bitmap_new(bsize); - assert(candidates); - - u8 *taken = bitmap_new(bsize); - assert(taken); - - bitmap_fill(candidates, bsize); - for (size_t i = bsize; i >= p->blob->nvectors; i--) { - bitmap_set(candidates, i, 0); - } - i32 k_used = 0; - min_idx(distances, bsize, candidates, topk_rowids, k, taken, &k_used); - knn_data->current_idx = 0; - knn_data->distances = distances; - knn_data->k = k; - knn_data->rowids = topk_rowids; - - pCur->knn_data = knn_data; - } else { - pCur->query_plan = VEC_SBE__QUERYPLAN_FULLSCAN; - pCur->iRowid = 0; - } - - return SQLITE_OK; -} - -static int vec_static_blob_entriesRowid(sqlite3_vtab_cursor *cur, - sqlite_int64 *pRowid) { - vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; - switch (pCur->query_plan) { - case VEC_SBE__QUERYPLAN_FULLSCAN: { - *pRowid = pCur->iRowid; - return SQLITE_OK; - } - case VEC_SBE__QUERYPLAN_KNN: { - i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx]; - *pRowid = (sqlite3_int64)rowid; - return SQLITE_OK; - } - } - return SQLITE_ERROR; -} - -static int vec_static_blob_entriesNext(sqlite3_vtab_cursor *cur) { - vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; - switch (pCur->query_plan) { - case VEC_SBE__QUERYPLAN_FULLSCAN: { - pCur->iRowid++; - return SQLITE_OK; - } - case VEC_SBE__QUERYPLAN_KNN: { - pCur->knn_data->current_idx++; - return SQLITE_OK; - } - } - return SQLITE_ERROR; -} - -static int vec_static_blob_entriesEof(sqlite3_vtab_cursor *cur) { - vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; - vec_static_blob_entries_vtab *p = - (vec_static_blob_entries_vtab *)pCur->base.pVtab; - switch (pCur->query_plan) { - case VEC_SBE__QUERYPLAN_FULLSCAN: { - return (size_t)pCur->iRowid >= p->blob->nvectors; - } - case VEC_SBE__QUERYPLAN_KNN: { - return pCur->knn_data->current_idx >= pCur->knn_data->k; - } - } - return SQLITE_ERROR; -} - -static int vec_static_blob_entriesColumn(sqlite3_vtab_cursor *cur, - sqlite3_context *context, int i) { - vec_static_blob_entries_cursor *pCur = (vec_static_blob_entries_cursor *)cur; - vec_static_blob_entries_vtab *p = (vec_static_blob_entries_vtab *)cur->pVtab; - - switch (pCur->query_plan) { - case VEC_SBE__QUERYPLAN_FULLSCAN: { - switch (i) { - case VEC_STATIC_BLOB_ENTRIES_VECTOR: - - sqlite3_result_blob( - context, - ((unsigned char *)p->blob->p) + - (pCur->iRowid * p->blob->dimensions * sizeof(float)), - p->blob->dimensions * sizeof(float), SQLITE_TRANSIENT); - sqlite3_result_subtype(context, p->blob->element_type); - break; - } - return SQLITE_OK; - } - case VEC_SBE__QUERYPLAN_KNN: { - switch (i) { - case VEC_STATIC_BLOB_ENTRIES_VECTOR: { - i32 rowid = ((i32 *)pCur->knn_data->rowids)[pCur->knn_data->current_idx]; - sqlite3_result_blob(context, - ((unsigned char *)p->blob->p) + - (rowid * p->blob->dimensions * sizeof(float)), - p->blob->dimensions * sizeof(float), - SQLITE_TRANSIENT); - sqlite3_result_subtype(context, p->blob->element_type); - break; - } - } - return SQLITE_OK; - } - } - return SQLITE_ERROR; -} - -static sqlite3_module vec_static_blob_entriesModule = { - /* iVersion */ 3, - /* xCreate */ - vec_static_blob_entriesCreate, // handle rm? - // https://github.com/asg017/sqlite-vec/issues/55 - /* xConnect */ vec_static_blob_entriesConnect, - /* xBestIndex */ vec_static_blob_entriesBestIndex, - /* xDisconnect */ vec_static_blob_entriesDisconnect, - /* xDestroy */ vec_static_blob_entriesDisconnect, - /* xOpen */ vec_static_blob_entriesOpen, - /* xClose */ vec_static_blob_entriesClose, - /* xFilter */ vec_static_blob_entriesFilter, - /* xNext */ vec_static_blob_entriesNext, - /* xEof */ vec_static_blob_entriesEof, - /* xColumn */ vec_static_blob_entriesColumn, - /* xRowid */ vec_static_blob_entriesRowid, - /* xUpdate */ 0, - /* xBegin */ 0, - /* xSync */ 0, - /* xCommit */ 0, - /* xRollback */ 0, - /* xFindMethod */ 0, - /* xRename */ 0, - /* xSavepoint */ 0, - /* xRelease */ 0, - /* xRollbackTo */ 0, - /* xShadowName */ 0, -#if SQLITE_VERSION_NUMBER >= 3044000 - /* xIntegrity */ 0 -#endif -}; -#pragma endregion #ifdef SQLITE_VEC_ENABLE_AVX #define SQLITE_VEC_DEBUG_BUILD_AVX "avx" @@ -10139,55 +9091,4 @@ SQLITE_VEC_API int sqlite3_vec_init(sqlite3 *db, char **pzErrMsg, return SQLITE_OK; } -#ifndef SQLITE_VEC_OMIT_FS -SQLITE_VEC_API int sqlite3_vec_numpy_init(sqlite3 *db, char **pzErrMsg, - const sqlite3_api_routines *pApi) { - UNUSED_PARAMETER(pzErrMsg); -#ifndef SQLITE_CORE - SQLITE_EXTENSION_INIT2(pApi); -#endif - int rc = SQLITE_OK; - rc = sqlite3_create_function_v2(db, "vec_npy_file", 1, SQLITE_RESULT_SUBTYPE, - NULL, vec_npy_file, NULL, NULL, NULL); - if(rc != SQLITE_OK) { - return rc; - } - rc = sqlite3_create_module_v2(db, "vec_npy_each", &vec_npy_eachModule, NULL, NULL); - return rc; -} -#endif -SQLITE_VEC_API int -sqlite3_vec_static_blobs_init(sqlite3 *db, char **pzErrMsg, - const sqlite3_api_routines *pApi) { - UNUSED_PARAMETER(pzErrMsg); -#ifndef SQLITE_CORE - SQLITE_EXTENSION_INIT2(pApi); -#endif - - int rc = SQLITE_OK; - vec_static_blob_data *static_blob_data; - static_blob_data = sqlite3_malloc(sizeof(*static_blob_data)); - if (!static_blob_data) { - return SQLITE_NOMEM; - } - memset(static_blob_data, 0, sizeof(*static_blob_data)); - - rc = sqlite3_create_function_v2( - db, "vec_static_blob_from_raw", 4, - DEFAULT_FLAGS | SQLITE_SUBTYPE | SQLITE_RESULT_SUBTYPE, NULL, - vec_static_blob_from_raw, NULL, NULL, NULL); - if (rc != SQLITE_OK) - return rc; - - rc = sqlite3_create_module_v2(db, "vec_static_blobs", &vec_static_blobsModule, - static_blob_data, sqlite3_free); - if (rc != SQLITE_OK) - return rc; - rc = sqlite3_create_module_v2(db, "vec_static_blob_entries", - &vec_static_blob_entriesModule, - static_blob_data, NULL); - if (rc != SQLITE_OK) - return rc; - return rc; -} diff --git a/tests/correctness/test-correctness.py b/tests/correctness/test-correctness.py index cb01f8f..9ed0319 100644 --- a/tests/correctness/test-correctness.py +++ b/tests/correctness/test-correctness.py @@ -48,7 +48,6 @@ import json db = sqlite3.connect(":memory:") db.enable_load_extension(True) db.load_extension("../../dist/vec0") -db.execute("select load_extension('../../dist/vec0', 'sqlite3_vec_fs_read_init')") db.enable_load_extension(False) results = db.execute( @@ -75,17 +74,21 @@ print(b) db.execute('PRAGMA page_size=16384') -print("Loading into sqlite-vec vec0 table...") -t0 = time.time() -db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)") -db.execute('insert into v select rowid, vector from vec_npy_each(vec_npy_file("dbpedia_openai_3_large_00.npy"))') -print(time.time() - t0) - print("loading numpy array...") t0 = time.time() base = np.load('dbpedia_openai_3_large_00.npy') print(time.time() - t0) +print("Loading into sqlite-vec vec0 table...") +t0 = time.time() +db.execute("create virtual table v using vec0(a float[3072], chunk_size=16)") +with db: + db.executemany( + "insert into v(rowid, a) values (?, ?)", + [(i, row.tobytes()) for i, row in enumerate(base)], + ) +print(time.time() - t0) + np.random.seed(1) queries = base[np.random.choice(base.shape[0], 20, replace=False), :] diff --git a/tests/fuzz/numpy.c b/tests/fuzz/numpy.c deleted file mode 100644 index 9e2900b..0000000 --- a/tests/fuzz/numpy.c +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#include -#include -#include -#include "sqlite-vec.h" -#include "sqlite3.h" -#include - -extern int sqlite3_vec_numpy_init(sqlite3 *db, char **pzErrMsg, - const sqlite3_api_routines *pApi); - -int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { - int rc = SQLITE_OK; - sqlite3 *db; - sqlite3_stmt *stmt; - - rc = sqlite3_open(":memory:", &db); - assert(rc == SQLITE_OK); - rc = sqlite3_vec_init(db, NULL, NULL); - assert(rc == SQLITE_OK); - rc = sqlite3_vec_numpy_init(db, NULL, NULL); - assert(rc == SQLITE_OK); - - rc = sqlite3_prepare_v2(db, "select * from vec_npy_each(?)", -1, &stmt, NULL); - assert(rc == SQLITE_OK); - sqlite3_bind_blob(stmt, 1, data, size, SQLITE_STATIC); - rc = sqlite3_step(stmt); - while (rc == SQLITE_ROW) { - rc = sqlite3_step(stmt); - } - - sqlite3_finalize(stmt); - sqlite3_close(db); - return 0; -} diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h index a540849..a02c72a 100644 --- a/tests/sqlite-vec-internal.h +++ b/tests/sqlite-vec-internal.h @@ -3,6 +3,7 @@ #include #include +#include int min_idx( const float *distances, @@ -62,12 +63,17 @@ enum Vec0DistanceMetrics { VEC0_DISTANCE_METRIC_L1 = 3, }; +enum Vec0IndexType { + VEC0_INDEX_TYPE_FLAT = 1, +}; + struct VectorColumnDefinition { char *name; int name_length; size_t dimensions; enum VectorElementType element_type; enum Vec0DistanceMetrics distance_metric; + enum Vec0IndexType index_type; }; int vec0_parse_vector_column(const char *source, int source_length, diff --git a/tests/test-loadable.py b/tests/test-loadable.py index bc4eed1..40c6a5e 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -119,151 +119,9 @@ FUNCTIONS = [ MODULES = [ "vec0", "vec_each", - # "vec_static_blob_entries", - # "vec_static_blobs", ] -def register_numpy(db, name: str, array): - ptr = array.__array_interface__["data"][0] - nvectors, dimensions = array.__array_interface__["shape"] - element_type = array.__array_interface__["typestr"] - - assert element_type == "\x9a\x99\x99>", - }, - { - "vector": b"fff?\xcd\xccL?", - }, - ] - assert execute_all(db, "select rowid, (vector) from z") == [ - { - "rowid": 0, - "vector": b"\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=\xcd\xcc\xcc=", - }, - { - "rowid": 1, - "vector": b"\xcd\xccL>\xcd\xccL>\xcd\xccL>\xcd\xccL>", - }, - { - "rowid": 2, - "vector": b"\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>\x9a\x99\x99>", - }, - { - "rowid": 3, - "vector": b"\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>\xcd\xcc\xcc>", - }, - { - "rowid": 4, - "vector": b"\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?\x00\x00\x00?", - }, - ] - assert execute_all( - db, - "select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;", - [np.array([0.3, 0.3, 0.3, 0.3], dtype=np.float32)], - ) == [ - { - "rowid": 2, - "v": "[0.300000,0.300000,0.300000,0.300000]", - }, - { - "rowid": 3, - "v": "[0.400000,0.400000,0.400000,0.400000]", - }, - { - "rowid": 1, - "v": "[0.200000,0.200000,0.200000,0.200000]", - }, - ] - assert execute_all( - db, - "select rowid, vec_to_json(vector) as v from z where vector match ? and k = 3 order by distance;", - [np.array([0.6, 0.6, 0.6, 0.6], dtype=np.float32)], - ) == [ - { - "rowid": 4, - "v": "[0.500000,0.500000,0.500000,0.500000]", - }, - { - "rowid": 3, - "v": "[0.400000,0.400000,0.400000,0.400000]", - }, - { - "rowid": 2, - "v": "[0.300000,0.300000,0.300000,0.300000]", - }, - ] - - def test_limits(): db = connect(EXT_PATH) with _raises( @@ -1618,231 +1476,6 @@ def test_vec_each(): vec_each_f32(None) -import io - - -def to_npy(arr): - buf = io.BytesIO() - np.save(buf, arr) - buf.seek(0) - return buf.read() - - -def test_vec_npy_each(): - db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init") - vec_npy_each = lambda *args: execute_all( - db, "select rowid, * from vec_npy_each(?)", args - ) - assert vec_npy_each(to_npy(np.array([1.1, 2.2, 3.3], dtype=np.float32))) == [ - { - "rowid": 0, - "vector": _f32([1.1, 2.2, 3.3]), - }, - ] - assert vec_npy_each(to_npy(np.array([[1.1, 2.2, 3.3]], dtype=np.float32))) == [ - { - "rowid": 0, - "vector": _f32([1.1, 2.2, 3.3]), - }, - ] - assert vec_npy_each( - to_npy(np.array([[1.1, 2.2, 3.3], [9.9, 8.8, 7.7]], dtype=np.float32)) - ) == [ - { - "rowid": 0, - "vector": _f32([1.1, 2.2, 3.3]), - }, - { - "rowid": 1, - "vector": _f32([9.9, 8.8, 7.7]), - }, - ] - - assert vec_npy_each(to_npy(np.array([], dtype=np.float32))) == [] - - -def test_vec_npy_each_errors(): - db = connect(EXT_PATH, extra_entrypoint="sqlite3_vec_numpy_init") - vec_npy_each = lambda *args: execute_all( - db, "select rowid, * from vec_npy_each(?)", args - ) - - full = b"\x93NUMPY\x01\x00v\x00{'descr': ' 8 bits per byte * 64 bytes = 512 + for (int i = 0; i < 128; i += 2) { + a[i] = 0xFF; + } + d = _test_distance_hamming(a, b, 1024); + assert(d == 512.0f); + } + printf(" All distance_hamming tests passed.\n"); } diff --git a/tmp-static.py b/tmp-static.py deleted file mode 100644 index a3b5f37..0000000 --- a/tmp-static.py +++ /dev/null @@ -1,56 +0,0 @@ -import sqlite3 -import numpy as np - -db = sqlite3.connect(":memory:") - -db.enable_load_extension(True) -db.load_extension("./dist/vec0") -db.execute("select load_extension('./dist/vec0', 'sqlite3_vec_raw_init')") -db.enable_load_extension(False) - -x = np.array([[0.1, 0.2, 0.3, 0.4], [0.9, 0.8, 0.7, 0.6]], dtype=np.float32) -y = np.array([[0.2, 0.3], [0.9, 0.8], [0.6, 0.5]], dtype=np.float32) -z = np.array( - [ - [0.1, 0.1, 0.1, 0.1], - [0.2, 0.2, 0.2, 0.2], - [0.3, 0.3, 0.3, 0.3], - [0.4, 0.4, 0.4, 0.4], - [0.5, 0.5, 0.5, 0.5], - ], - dtype=np.float32, -) - - -def register_np(array, name): - ptr = array.__array_interface__["data"][0] - nvectors, dimensions = array.__array_interface__["shape"] - element_type = array.__array_interface__["typestr"] - - assert element_type == " Date: Sun, 29 Mar 2026 19:45:54 -0700 Subject: [PATCH 2/3] Add rescore index for ANN queries Add rescore index type: stores full-precision float vectors in a rowid-keyed shadow table, quantizes to int8 for fast initial scan, then rescores top candidates with original vectors. Includes config parser, shadow table management, insert/delete support, KNN integration, compile flag (SQLITE_VEC_ENABLE_RESCORE), fuzz targets, and tests. --- Makefile | 2 +- benchmarks-ann/Makefile | 13 +- benchmarks-ann/bench.py | 33 ++ sqlite-vec-rescore.c | 662 ++++++++++++++++++++++++++++ sqlite-vec.c | 435 +++++++++++++++++- tests/fuzz/.gitignore | 5 + tests/fuzz/Makefile | 26 +- tests/fuzz/rescore-create.c | 36 ++ tests/fuzz/rescore-create.dict | 20 + tests/fuzz/rescore-interleave.c | 151 +++++++ tests/fuzz/rescore-knn-deep.c | 178 ++++++++ tests/fuzz/rescore-operations.c | 96 ++++ tests/fuzz/rescore-quantize-edge.c | 177 ++++++++ tests/fuzz/rescore-quantize.c | 54 +++ tests/fuzz/rescore-shadow-corrupt.c | 230 ++++++++++ tests/sqlite-vec-internal.h | 25 ++ tests/test-rescore-mutations.py | 470 ++++++++++++++++++++ tests/test-rescore.py | 568 ++++++++++++++++++++++++ tests/test-unit.c | 205 +++++++++ 19 files changed, 3378 insertions(+), 8 deletions(-) create mode 100644 sqlite-vec-rescore.c create mode 100644 tests/fuzz/rescore-create.c create mode 100644 tests/fuzz/rescore-create.dict create mode 100644 tests/fuzz/rescore-interleave.c create mode 100644 tests/fuzz/rescore-knn-deep.c create mode 100644 tests/fuzz/rescore-operations.c create mode 100644 tests/fuzz/rescore-quantize-edge.c create mode 100644 tests/fuzz/rescore-quantize.c create mode 100644 tests/fuzz/rescore-shadow-corrupt.c create mode 100644 tests/test-rescore-mutations.py create mode 100644 tests/test-rescore.py diff --git a/Makefile b/Makefile index 051590e..b604171 100644 --- a/Makefile +++ b/Makefile @@ -202,7 +202,7 @@ test-loadable-watch: watchexec --exts c,py,Makefile --clear -- make test-loadable test-unit: - $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit + $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit fuzz-build: $(MAKE) -C tests/fuzz all diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile index 59e2dcd..762abea 100644 --- a/benchmarks-ann/Makefile +++ b/benchmarks-ann/Makefile @@ -21,9 +21,14 @@ BASELINES = \ # ANNOY_CONFIGS = \ # "annoy-t50:type=annoy,n_trees=50" -ALL_CONFIGS = $(BASELINES) +RESCORE_CONFIGS = \ + "rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" \ + "rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \ + "rescore-int8-os8:type=rescore,quantizer=int8,oversample=8" -.PHONY: seed ground-truth bench-smoke bench-10k bench-50k bench-100k bench-all \ +ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) + +.PHONY: seed ground-truth bench-smoke bench-rescore bench-10k bench-50k bench-100k bench-all \ report clean # --- Data preparation --- @@ -40,6 +45,10 @@ bench-smoke: seed $(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \ $(BASELINES) +bench-rescore: seed + $(BENCH) --subset-size 10000 -k 10 -o runs/rescore \ + $(RESCORE_CONFIGS) + # --- Standard sizes --- bench-10k: seed $(BENCH) --subset-size 10000 -k 10 -o runs/10k $(ALL_CONFIGS) diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py index 93f8f82..c1179d6 100644 --- a/benchmarks-ann/bench.py +++ b/benchmarks-ann/bench.py @@ -140,6 +140,39 @@ INDEX_REGISTRY["baseline"] = { } +# ============================================================================ +# Rescore implementation +# ============================================================================ + + +def _rescore_create_table_sql(params): + quantizer = params.get("quantizer", "bit") + oversample = params.get("oversample", 8) + return ( + f"CREATE VIRTUAL TABLE vec_items USING vec0(" + f" chunk_size=256," + f" id integer primary key," + f" embedding float[768] distance_metric=cosine" + f" indexed by rescore(quantizer={quantizer}, oversample={oversample}))" + ) + + +def _rescore_describe(params): + q = params.get("quantizer", "bit") + os = params.get("oversample", 8) + return f"rescore {q} (os={os})" + + +INDEX_REGISTRY["rescore"] = { + "defaults": {"quantizer": "bit", "oversample": 8}, + "create_table_sql": _rescore_create_table_sql, + "insert_sql": None, + "post_insert_hook": None, + "run_query": None, # default MATCH query works — rescore is automatic + "describe": _rescore_describe, +} + + # ============================================================================ # Config parsing # ============================================================================ diff --git a/sqlite-vec-rescore.c b/sqlite-vec-rescore.c new file mode 100644 index 0000000..a45f52f --- /dev/null +++ b/sqlite-vec-rescore.c @@ -0,0 +1,662 @@ +/** + * sqlite-vec-rescore.c — Rescore index logic for sqlite-vec. + * + * This file is #included into sqlite-vec.c after the vec0_vtab definition. + * All functions receive a vec0_vtab *p and access p->vector_columns[i].rescore. + * + * Shadow tables per rescore-enabled vector column: + * _rescore_chunks{NN} — quantized vectors in chunk layout (for coarse scan) + * _rescore_vectors{NN} — float vectors keyed by rowid (for fast rescore lookup) + */ + +// ============================================================================ +// Shadow table lifecycle +// ============================================================================ + +static int rescore_create_tables(vec0_vtab *p, sqlite3 *db, char **pzErr) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE) + continue; + + // Quantized chunk table (same layout as _vector_chunks) + char *zSql = sqlite3_mprintf( + "CREATE TABLE \"%w\".\"%w_rescore_chunks%02d\"" + "(rowid PRIMARY KEY, vectors BLOB NOT NULL)", + p->schemaName, p->tableName, i); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + int rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + *pzErr = sqlite3_mprintf( + "Could not create '_rescore_chunks%02d' shadow table: %s", i, + sqlite3_errmsg(db)); + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); + + // Float vector table (rowid-keyed for fast random access) + zSql = sqlite3_mprintf( + "CREATE TABLE \"%w\".\"%w_rescore_vectors%02d\"" + "(rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL)", + p->schemaName, p->tableName, i); + if (!zSql) + return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + *pzErr = sqlite3_mprintf( + "Could not create '_rescore_vectors%02d' shadow table: %s", i, + sqlite3_errmsg(db)); + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); + } + return SQLITE_OK; +} + +static int rescore_drop_tables(vec0_vtab *p) { + for (int i = 0; i < p->numVectorColumns; i++) { + sqlite3_stmt *stmt; + int rc; + char *zSql; + + if (p->shadowRescoreChunksNames[i]) { + zSql = sqlite3_mprintf("DROP TABLE IF EXISTS \"%w\".\"%w\"", + p->schemaName, p->shadowRescoreChunksNames[i]); + if (!zSql) + return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); + } + + if (p->shadowRescoreVectorsNames[i]) { + zSql = sqlite3_mprintf("DROP TABLE IF EXISTS \"%w\".\"%w\"", + p->schemaName, p->shadowRescoreVectorsNames[i]); + if (!zSql) + return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + sqlite3_finalize(stmt); + return SQLITE_ERROR; + } + sqlite3_finalize(stmt); + } + } + return SQLITE_OK; +} + +static size_t rescore_quantized_byte_size(struct VectorColumnDefinition *col) { + switch (col->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: + return col->dimensions / CHAR_BIT; + case VEC0_RESCORE_QUANTIZER_INT8: + return col->dimensions; + default: + return 0; + } +} + +/** + * Insert a new chunk row into each _rescore_chunks{NN} table with a zeroblob. + */ +static int rescore_new_chunk(vec0_vtab *p, i64 chunk_rowid) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE) + continue; + size_t quantized_size = + rescore_quantized_byte_size(&p->vector_columns[i]); + i64 blob_size = (i64)p->chunk_size * (i64)quantized_size; + + char *zSql = sqlite3_mprintf( + "INSERT INTO \"%w\".\"%w\"(_rowid_, rowid, vectors) VALUES (?, ?, ?)", + p->schemaName, p->shadowRescoreChunksNames[i]); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + sqlite3_finalize(stmt); + return rc; + } + sqlite3_bind_int64(stmt, 1, chunk_rowid); + sqlite3_bind_int64(stmt, 2, chunk_rowid); + sqlite3_bind_zeroblob64(stmt, 3, blob_size); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) + return rc; + } + return SQLITE_OK; +} + +// ============================================================================ +// Quantization +// ============================================================================ + +static void rescore_quantize_float_to_bit(const float *src, uint8_t *dst, + size_t dimensions) { + memset(dst, 0, dimensions / CHAR_BIT); + for (size_t i = 0; i < dimensions; i++) { + if (src[i] >= 0.0f) { + dst[i / CHAR_BIT] |= (1 << (i % CHAR_BIT)); + } + } +} + +static void rescore_quantize_float_to_int8(const float *src, int8_t *dst, + size_t dimensions) { + float vmin = src[0], vmax = src[0]; + for (size_t i = 1; i < dimensions; i++) { + if (src[i] < vmin) vmin = src[i]; + if (src[i] > vmax) vmax = src[i]; + } + float range = vmax - vmin; + if (range == 0.0f) { + memset(dst, 0, dimensions); + return; + } + float scale = 255.0f / range; + for (size_t i = 0; i < dimensions; i++) { + float v = (src[i] - vmin) * scale - 128.0f; + if (v < -128.0f) v = -128.0f; + if (v > 127.0f) v = 127.0f; + dst[i] = (int8_t)v; + } +} + +// ============================================================================ +// Insert path +// ============================================================================ + +/** + * Quantize float vector to _rescore_chunks and store in _rescore_vectors. + */ +static int rescore_on_insert(vec0_vtab *p, i64 chunk_rowid, i64 chunk_offset, + i64 rowid, void *vectorDatas[]) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE) + continue; + + struct VectorColumnDefinition *col = &p->vector_columns[i]; + size_t qsize = rescore_quantized_byte_size(col); + size_t fsize = vector_column_byte_size(*col); + int rc; + + // 1. Write quantized vector to _rescore_chunks blob + { + void *qbuf = sqlite3_malloc(qsize); + if (!qbuf) + return SQLITE_NOMEM; + + switch (col->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: + rescore_quantize_float_to_bit((const float *)vectorDatas[i], + (uint8_t *)qbuf, col->dimensions); + break; + case VEC0_RESCORE_QUANTIZER_INT8: + rescore_quantize_float_to_int8((const float *)vectorDatas[i], + (int8_t *)qbuf, col->dimensions); + break; + } + + sqlite3_blob *blob = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreChunksNames[i], "vectors", + chunk_rowid, 1, &blob); + if (rc != SQLITE_OK) { + sqlite3_free(qbuf); + return rc; + } + rc = sqlite3_blob_write(blob, qbuf, qsize, chunk_offset * qsize); + sqlite3_free(qbuf); + int brc = sqlite3_blob_close(blob); + if (rc != SQLITE_OK) + return rc; + if (brc != SQLITE_OK) + return brc; + } + + // 2. Insert float vector into _rescore_vectors (rowid-keyed) + { + char *zSql = sqlite3_mprintf( + "INSERT INTO \"%w\".\"%w\"(rowid, vector) VALUES (?, ?)", + p->schemaName, p->shadowRescoreVectorsNames[i]); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + sqlite3_finalize(stmt); + return rc; + } + sqlite3_bind_int64(stmt, 1, rowid); + sqlite3_bind_blob(stmt, 2, vectorDatas[i], fsize, SQLITE_TRANSIENT); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) + return SQLITE_ERROR; + } + } + return SQLITE_OK; +} + +// ============================================================================ +// Delete path +// ============================================================================ + +/** + * Zero out quantized vector in _rescore_chunks and delete from _rescore_vectors. + */ +static int rescore_on_delete(vec0_vtab *p, i64 chunk_id, u64 chunk_offset, + i64 rowid) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_RESCORE) + continue; + int rc; + + // 1. Zero out quantized data in _rescore_chunks + { + size_t qsize = rescore_quantized_byte_size(&p->vector_columns[i]); + void *zeroBuf = sqlite3_malloc(qsize); + if (!zeroBuf) + return SQLITE_NOMEM; + memset(zeroBuf, 0, qsize); + + sqlite3_blob *blob = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreChunksNames[i], "vectors", + chunk_id, 1, &blob); + if (rc != SQLITE_OK) { + sqlite3_free(zeroBuf); + return rc; + } + rc = sqlite3_blob_write(blob, zeroBuf, qsize, chunk_offset * qsize); + sqlite3_free(zeroBuf); + int brc = sqlite3_blob_close(blob); + if (rc != SQLITE_OK) + return rc; + if (brc != SQLITE_OK) + return brc; + } + + // 2. Delete from _rescore_vectors + { + char *zSql = sqlite3_mprintf( + "DELETE FROM \"%w\".\"%w\" WHERE rowid = ?", + p->schemaName, p->shadowRescoreVectorsNames[i]); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) + return rc; + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) + return SQLITE_ERROR; + } + } + return SQLITE_OK; +} + +/** + * Delete a chunk row from _rescore_chunks{NN} tables. + * (_rescore_vectors rows were already deleted per-row in rescore_on_delete) + */ +static int rescore_delete_chunk(vec0_vtab *p, i64 chunk_id) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (!p->shadowRescoreChunksNames[i]) + continue; + char *zSql = sqlite3_mprintf( + "DELETE FROM \"%w\".\"%w\" WHERE rowid = ?", + p->schemaName, p->shadowRescoreChunksNames[i]); + if (!zSql) + return SQLITE_NOMEM; + sqlite3_stmt *stmt; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) + return rc; + sqlite3_bind_int64(stmt, 1, chunk_id); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) + return SQLITE_ERROR; + } + return SQLITE_OK; +} + +// ============================================================================ +// KNN rescore query +// ============================================================================ + +/** + * Phase 1: Coarse scan of quantized chunks → top k*oversample candidates (rowids). + * Phase 2: For each candidate, blob_open _rescore_vectors by rowid, read float + * vector, compute float distance. Sort, return top k. + * + * Phase 2 is fast because _rescore_vectors has INTEGER PRIMARY KEY, so + * sqlite3_blob_open/reopen addresses rows directly by rowid — no index lookup. + */ +static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur, + struct VectorColumnDefinition *vector_column, + int vectorColumnIdx, struct Array *arrayRowidsIn, + struct Array *aMetadataIn, const char *idxStr, int argc, + sqlite3_value **argv, void *queryVector, i64 k, + struct vec0_query_knn_data *knn_data) { + (void)pCur; + (void)aMetadataIn; + int rc = SQLITE_OK; + int oversample = vector_column->rescore.oversample; + i64 k_oversample = k * oversample; + if (k_oversample > 4096) + k_oversample = 4096; + + size_t qdim = vector_column->dimensions; + size_t qsize = rescore_quantized_byte_size(vector_column); + size_t fsize = vector_column_byte_size(*vector_column); + + // Quantize the query vector + void *quantizedQuery = sqlite3_malloc(qsize); + if (!quantizedQuery) + return SQLITE_NOMEM; + + switch (vector_column->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: + rescore_quantize_float_to_bit((const float *)queryVector, + (uint8_t *)quantizedQuery, qdim); + break; + case VEC0_RESCORE_QUANTIZER_INT8: + rescore_quantize_float_to_int8((const float *)queryVector, + (int8_t *)quantizedQuery, qdim); + break; + } + + // Phase 1: Scan quantized chunks for k*oversample candidates + sqlite3_stmt *stmtChunks = NULL; + rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks); + if (rc != SQLITE_OK) { + sqlite3_free(quantizedQuery); + return rc; + } + + i64 *cand_rowids = sqlite3_malloc(k_oversample * sizeof(i64)); + f32 *cand_distances = sqlite3_malloc(k_oversample * sizeof(f32)); + i64 *tmp_rowids = sqlite3_malloc(k_oversample * sizeof(i64)); + f32 *tmp_distances = sqlite3_malloc(k_oversample * sizeof(f32)); + f32 *chunk_distances = sqlite3_malloc(p->chunk_size * sizeof(f32)); + i32 *chunk_topk_idxs = sqlite3_malloc(k_oversample * sizeof(i32)); + u8 *b = sqlite3_malloc(p->chunk_size / CHAR_BIT); + u8 *bTaken = sqlite3_malloc(p->chunk_size / CHAR_BIT); + u8 *bmRowids = NULL; + void *baseVectors = sqlite3_malloc((i64)p->chunk_size * (i64)qsize); + + if (!cand_rowids || !cand_distances || !tmp_rowids || !tmp_distances || + !chunk_distances || !chunk_topk_idxs || !b || !bTaken || !baseVectors) { + rc = SQLITE_NOMEM; + goto cleanup; + } + memset(cand_rowids, 0, k_oversample * sizeof(i64)); + memset(cand_distances, 0, k_oversample * sizeof(f32)); + + if (arrayRowidsIn) { + bmRowids = sqlite3_malloc(p->chunk_size / CHAR_BIT); + if (!bmRowids) { + rc = SQLITE_NOMEM; + goto cleanup; + } + } + + i64 cand_used = 0; + + while (1) { + rc = sqlite3_step(stmtChunks); + if (rc == SQLITE_DONE) + break; + if (rc != SQLITE_ROW) { + rc = SQLITE_ERROR; + goto cleanup; + } + + i64 chunk_id = sqlite3_column_int64(stmtChunks, 0); + unsigned char *chunkValidity = + (unsigned char *)sqlite3_column_blob(stmtChunks, 1); + i64 *chunkRowids = (i64 *)sqlite3_column_blob(stmtChunks, 2); + + memset(chunk_distances, 0, p->chunk_size * sizeof(f32)); + memset(chunk_topk_idxs, 0, k_oversample * sizeof(i32)); + bitmap_copy(b, chunkValidity, p->chunk_size); + + if (arrayRowidsIn) { + bitmap_clear(bmRowids, p->chunk_size); + for (int j = 0; j < p->chunk_size; j++) { + if (!bitmap_get(chunkValidity, j)) + continue; + i64 rid = chunkRowids[j]; + void *found = bsearch(&rid, arrayRowidsIn->z, arrayRowidsIn->length, + sizeof(i64), _cmp); + bitmap_set(bmRowids, j, found ? 1 : 0); + } + bitmap_and_inplace(b, bmRowids, p->chunk_size); + } + + // Read quantized vectors + sqlite3_blob *blobQ = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreChunksNames[vectorColumnIdx], + "vectors", chunk_id, 0, &blobQ); + if (rc != SQLITE_OK) + goto cleanup; + rc = sqlite3_blob_read(blobQ, baseVectors, + (i64)p->chunk_size * (i64)qsize, 0); + sqlite3_blob_close(blobQ); + if (rc != SQLITE_OK) + goto cleanup; + + // Compute quantized distances + for (int j = 0; j < p->chunk_size; j++) { + if (!bitmap_get(b, j)) + continue; + f32 dist; + switch (vector_column->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: { + const u8 *base_j = ((u8 *)baseVectors) + (j * (qdim / CHAR_BIT)); + dist = distance_hamming(base_j, (u8 *)quantizedQuery, &qdim); + break; + } + case VEC0_RESCORE_QUANTIZER_INT8: { + const i8 *base_j = ((i8 *)baseVectors) + (j * qdim); + switch (vector_column->distance_metric) { + case VEC0_DISTANCE_METRIC_L2: + dist = distance_l2_sqr_int8(base_j, (i8 *)quantizedQuery, &qdim); + break; + case VEC0_DISTANCE_METRIC_COSINE: + dist = distance_cosine_int8(base_j, (i8 *)quantizedQuery, &qdim); + break; + case VEC0_DISTANCE_METRIC_L1: + dist = (f32)distance_l1_int8(base_j, (i8 *)quantizedQuery, &qdim); + break; + } + break; + } + } + chunk_distances[j] = dist; + } + + int used1; + min_idx(chunk_distances, p->chunk_size, b, chunk_topk_idxs, + min(k_oversample, p->chunk_size), bTaken, &used1); + + i64 merged_used; + merge_sorted_lists(cand_distances, cand_rowids, cand_used, chunk_distances, + chunkRowids, chunk_topk_idxs, + min(min(k_oversample, p->chunk_size), used1), + tmp_distances, tmp_rowids, k_oversample, &merged_used); + + for (i64 j = 0; j < merged_used; j++) { + cand_rowids[j] = tmp_rowids[j]; + cand_distances[j] = tmp_distances[j]; + } + cand_used = merged_used; + } + rc = SQLITE_OK; + + // Phase 2: Rescore candidates using _rescore_vectors (rowid-keyed) + if (cand_used == 0) { + knn_data->current_idx = 0; + knn_data->k = 0; + knn_data->rowids = NULL; + knn_data->distances = NULL; + knn_data->k_used = 0; + goto cleanup; + } + { + f32 *float_distances = sqlite3_malloc(cand_used * sizeof(f32)); + void *fBuf = sqlite3_malloc(fsize); + if (!float_distances || !fBuf) { + sqlite3_free(float_distances); + sqlite3_free(fBuf); + rc = SQLITE_NOMEM; + goto cleanup; + } + + // Open blob on _rescore_vectors, then reopen for each candidate rowid. + // blob_reopen is O(1) for INTEGER PRIMARY KEY tables. + sqlite3_blob *blobFloat = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreVectorsNames[vectorColumnIdx], + "vector", cand_rowids[0], 0, &blobFloat); + if (rc != SQLITE_OK) { + sqlite3_free(float_distances); + sqlite3_free(fBuf); + goto cleanup; + } + + rc = sqlite3_blob_read(blobFloat, fBuf, fsize, 0); + if (rc != SQLITE_OK) { + sqlite3_blob_close(blobFloat); + sqlite3_free(float_distances); + sqlite3_free(fBuf); + goto cleanup; + } + float_distances[0] = + vec0_distance_full(fBuf, queryVector, vector_column->dimensions, + vector_column->element_type, + vector_column->distance_metric); + + for (i64 j = 1; j < cand_used; j++) { + rc = sqlite3_blob_reopen(blobFloat, cand_rowids[j]); + if (rc != SQLITE_OK) { + sqlite3_blob_close(blobFloat); + sqlite3_free(float_distances); + sqlite3_free(fBuf); + goto cleanup; + } + rc = sqlite3_blob_read(blobFloat, fBuf, fsize, 0); + if (rc != SQLITE_OK) { + sqlite3_blob_close(blobFloat); + sqlite3_free(float_distances); + sqlite3_free(fBuf); + goto cleanup; + } + float_distances[j] = + vec0_distance_full(fBuf, queryVector, vector_column->dimensions, + vector_column->element_type, + vector_column->distance_metric); + } + sqlite3_blob_close(blobFloat); + sqlite3_free(fBuf); + + // Sort by float distance + for (i64 a = 0; a + 1 < cand_used; a++) { + i64 minIdx = a; + for (i64 c = a + 1; c < cand_used; c++) { + if (float_distances[c] < float_distances[minIdx]) + minIdx = c; + } + if (minIdx != a) { + f32 td = float_distances[a]; + float_distances[a] = float_distances[minIdx]; + float_distances[minIdx] = td; + i64 tr = cand_rowids[a]; + cand_rowids[a] = cand_rowids[minIdx]; + cand_rowids[minIdx] = tr; + } + } + + i64 result_k = min(k, cand_used); + i64 *out_rowids = sqlite3_malloc(result_k * sizeof(i64)); + f32 *out_distances = sqlite3_malloc(result_k * sizeof(f32)); + if (!out_rowids || !out_distances) { + sqlite3_free(out_rowids); + sqlite3_free(out_distances); + sqlite3_free(float_distances); + rc = SQLITE_NOMEM; + goto cleanup; + } + for (i64 j = 0; j < result_k; j++) { + out_rowids[j] = cand_rowids[j]; + out_distances[j] = float_distances[j]; + } + + knn_data->current_idx = 0; + knn_data->k = result_k; + knn_data->rowids = out_rowids; + knn_data->distances = out_distances; + knn_data->k_used = result_k; + + sqlite3_free(float_distances); + } + +cleanup: + sqlite3_finalize(stmtChunks); + sqlite3_free(quantizedQuery); + sqlite3_free(cand_rowids); + sqlite3_free(cand_distances); + sqlite3_free(tmp_rowids); + sqlite3_free(tmp_distances); + sqlite3_free(chunk_distances); + sqlite3_free(chunk_topk_idxs); + sqlite3_free(b); + sqlite3_free(bTaken); + sqlite3_free(bmRowids); + sqlite3_free(baseVectors); + return rc; +} + +#ifdef SQLITE_VEC_TEST +void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim) { + rescore_quantize_float_to_bit(src, dst, dim); +} +void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim) { + rescore_quantize_float_to_int8(src, dst, dim); +} +size_t _test_rescore_quantized_byte_size_bit(size_t dimensions) { + struct VectorColumnDefinition col; + memset(&col, 0, sizeof(col)); + col.dimensions = dimensions; + col.rescore.quantizer_type = VEC0_RESCORE_QUANTIZER_BIT; + return rescore_quantized_byte_size(&col); +} +size_t _test_rescore_quantized_byte_size_int8(size_t dimensions) { + struct VectorColumnDefinition col; + memset(&col, 0, sizeof(col)); + col.dimensions = dimensions; + col.rescore.quantizer_type = VEC0_RESCORE_QUANTIZER_INT8; + return rescore_quantized_byte_size(&col); +} +#endif diff --git a/sqlite-vec.c b/sqlite-vec.c index 390123b..ff9e0da 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -112,6 +112,10 @@ typedef size_t usize; #define countof(x) (sizeof(x) / sizeof((x)[0])) #define min(a, b) (((a) <= (b)) ? (a) : (b)) +#ifndef SQLITE_VEC_ENABLE_RESCORE +#define SQLITE_VEC_ENABLE_RESCORE 1 +#endif + enum VectorElementType { // clang-format off SQLITE_VEC_ELEMENT_TYPE_FLOAT32 = 223 + 0, @@ -2532,8 +2536,23 @@ static f32 vec0_distance_full( enum Vec0IndexType { VEC0_INDEX_TYPE_FLAT = 1, +#if SQLITE_VEC_ENABLE_RESCORE + VEC0_INDEX_TYPE_RESCORE = 2, +#endif }; +#if SQLITE_VEC_ENABLE_RESCORE +enum Vec0RescoreQuantizerType { + VEC0_RESCORE_QUANTIZER_BIT = 1, + VEC0_RESCORE_QUANTIZER_INT8 = 2, +}; + +struct Vec0RescoreConfig { + enum Vec0RescoreQuantizerType quantizer_type; + int oversample; +}; +#endif + struct VectorColumnDefinition { char *name; int name_length; @@ -2541,6 +2560,9 @@ struct VectorColumnDefinition { enum VectorElementType element_type; enum Vec0DistanceMetrics distance_metric; enum Vec0IndexType index_type; +#if SQLITE_VEC_ENABLE_RESCORE + struct Vec0RescoreConfig rescore; +#endif }; struct Vec0PartitionColumnDefinition { @@ -2577,6 +2599,111 @@ size_t vector_column_byte_size(struct VectorColumnDefinition column) { return vector_byte_size(column.element_type, column.dimensions); } +#if SQLITE_VEC_ENABLE_RESCORE +/** + * @brief Parse rescore options from an "INDEXED BY rescore(...)" clause. + * + * @param scanner Scanner positioned right after the opening '(' of rescore(...) + * @param outConfig Output rescore config + * @param pzErr Error message output + * @return int SQLITE_OK on success, SQLITE_ERROR on error. + */ +static int vec0_parse_rescore_options(struct Vec0Scanner *scanner, + struct Vec0RescoreConfig *outConfig, + char **pzErr) { + struct Vec0Token token; + int rc; + int hasQuantizer = 0; + outConfig->oversample = 8; + outConfig->quantizer_type = 0; + + while (1) { + rc = vec0_scanner_next(scanner, &token); + if (rc == VEC0_TOKEN_RESULT_EOF) { + break; + } + // ')' closes rescore options + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) { + break; + } + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_IDENTIFIER) { + *pzErr = sqlite3_mprintf("Expected option name in rescore(...)"); + return SQLITE_ERROR; + } + + char *key = token.start; + int keyLength = token.end - token.start; + + // expect '=' + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ) { + *pzErr = sqlite3_mprintf("Expected '=' after option name in rescore(...)"); + return SQLITE_ERROR; + } + + // value + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME) { + *pzErr = sqlite3_mprintf("Expected value after '=' in rescore(...)"); + return SQLITE_ERROR; + } + + if (sqlite3_strnicmp(key, "quantizer", keyLength) == 0) { + if (token.token_type != TOKEN_TYPE_IDENTIFIER) { + *pzErr = sqlite3_mprintf("Expected identifier for quantizer value in rescore(...)"); + return SQLITE_ERROR; + } + int valLen = token.end - token.start; + if (sqlite3_strnicmp(token.start, "bit", valLen) == 0) { + outConfig->quantizer_type = VEC0_RESCORE_QUANTIZER_BIT; + } else if (sqlite3_strnicmp(token.start, "int8", valLen) == 0) { + outConfig->quantizer_type = VEC0_RESCORE_QUANTIZER_INT8; + } else { + *pzErr = sqlite3_mprintf("Unknown quantizer type '%.*s' in rescore(...). Expected 'bit' or 'int8'.", valLen, token.start); + return SQLITE_ERROR; + } + hasQuantizer = 1; + } else if (sqlite3_strnicmp(key, "oversample", keyLength) == 0) { + if (token.token_type != TOKEN_TYPE_DIGIT) { + *pzErr = sqlite3_mprintf("Expected integer for oversample value in rescore(...)"); + return SQLITE_ERROR; + } + outConfig->oversample = atoi(token.start); + if (outConfig->oversample <= 0 || outConfig->oversample > 128) { + *pzErr = sqlite3_mprintf("oversample in rescore(...) must be between 1 and 128, got %d", outConfig->oversample); + return SQLITE_ERROR; + } + } else { + *pzErr = sqlite3_mprintf("Unknown option '%.*s' in rescore(...)", keyLength, key); + return SQLITE_ERROR; + } + + // optional comma between options + rc = vec0_scanner_next(scanner, &token); + if (rc == VEC0_TOKEN_RESULT_EOF) { + break; + } + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) { + break; + } + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_COMMA) { + continue; + } + // If it's not a comma or rparen, it might be the next key — push back isn't + // possible with this scanner, so we'll treat unexpected tokens as errors + *pzErr = sqlite3_mprintf("Unexpected token in rescore(...) options"); + return SQLITE_ERROR; + } + + if (!hasQuantizer) { + *pzErr = sqlite3_mprintf("rescore(...) requires a 'quantizer' option (quantizer=bit or quantizer=int8)"); + return SQLITE_ERROR; + } + + return SQLITE_OK; +} +#endif /* SQLITE_VEC_ENABLE_RESCORE */ + /** * @brief Parse an vec0 vtab argv[i] column definition and see if * it's a vector column defintion, ex `contents_embedding float[768]`. @@ -2601,6 +2728,10 @@ int vec0_parse_vector_column(const char *source, int source_length, enum VectorElementType elementType; enum Vec0DistanceMetrics distanceMetric = VEC0_DISTANCE_METRIC_L2; enum Vec0IndexType indexType = VEC0_INDEX_TYPE_FLAT; +#if SQLITE_VEC_ENABLE_RESCORE + struct Vec0RescoreConfig rescoreConfig; + memset(&rescoreConfig, 0, sizeof(rescoreConfig)); +#endif int dimensions; vec0_scanner_init(&scanner, source, source_length); @@ -2704,6 +2835,7 @@ int vec0_parse_vector_column(const char *source, int source_length, return SQLITE_ERROR; } } + // INDEXED BY flat() | rescore(...) else if (sqlite3_strnicmp(key, "indexed", keyLength) == 0) { // expect "by" rc = vec0_scanner_next(&scanner, &token); @@ -2733,7 +2865,32 @@ int vec0_parse_vector_column(const char *source, int source_length, token.token_type != TOKEN_TYPE_RPAREN) { return SQLITE_ERROR; } - } else { + } +#if SQLITE_VEC_ENABLE_RESCORE + else if (sqlite3_strnicmp(token.start, "rescore", indexNameLen) == 0) { + indexType = VEC0_INDEX_TYPE_RESCORE; + if (elementType != SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + return SQLITE_ERROR; + } + // expect '(' + rc = vec0_scanner_next(&scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_LPAREN) { + return SQLITE_ERROR; + } + char *rescoreErr = NULL; + rc = vec0_parse_rescore_options(&scanner, &rescoreConfig, &rescoreErr); + if (rc != SQLITE_OK) { + if (rescoreErr) sqlite3_free(rescoreErr); + return SQLITE_ERROR; + } + // validate dimensions for bit quantizer + if (rescoreConfig.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT && + (dimensions % CHAR_BIT) != 0) { + return SQLITE_ERROR; + } + } +#endif + else { // unknown index type return SQLITE_ERROR; } @@ -2753,6 +2910,9 @@ int vec0_parse_vector_column(const char *source, int source_length, outColumn->element_type = elementType; outColumn->dimensions = dimensions; outColumn->index_type = indexType; +#if SQLITE_VEC_ENABLE_RESCORE + outColumn->rescore = rescoreConfig; +#endif return SQLITE_OK; } @@ -3093,6 +3253,19 @@ struct vec0_vtab { // The first numVectorColumns entries must be freed with sqlite3_free() char *shadowVectorChunksNames[VEC0_MAX_VECTOR_COLUMNS]; +#if SQLITE_VEC_ENABLE_RESCORE + // Name of all rescore chunk shadow tables, ie `_rescore_chunks00` + // Only populated for vector columns with rescore enabled. + // Must be freed with sqlite3_free() + char *shadowRescoreChunksNames[VEC0_MAX_VECTOR_COLUMNS]; + + // Name of all rescore vector shadow tables, ie `_rescore_vectors00` + // Rowid-keyed table for fast random-access float vector reads during rescore. + // Only populated for vector columns with rescore enabled. + // Must be freed with sqlite3_free() + char *shadowRescoreVectorsNames[VEC0_MAX_VECTOR_COLUMNS]; +#endif + // Name of all metadata chunk shadow tables, ie `_metadatachunks00` // Only the first numMetadataColumns entries will be available. // The first numMetadataColumns entries must be freed with sqlite3_free() @@ -3162,6 +3335,18 @@ struct vec0_vtab { sqlite3_stmt *stmtRowidsGetChunkPosition; }; +#if SQLITE_VEC_ENABLE_RESCORE +// Forward declarations for rescore functions (defined in sqlite-vec-rescore.c, +// included later after all helpers they depend on are defined). +static int rescore_create_tables(vec0_vtab *p, sqlite3 *db, char **pzErr); +static int rescore_drop_tables(vec0_vtab *p); +static int rescore_new_chunk(vec0_vtab *p, i64 chunk_rowid); +static int rescore_on_insert(vec0_vtab *p, i64 chunk_rowid, i64 chunk_offset, + i64 rowid, void *vectorDatas[]); +static int rescore_on_delete(vec0_vtab *p, i64 chunk_id, u64 chunk_offset, i64 rowid); +static int rescore_delete_chunk(vec0_vtab *p, i64 chunk_id); +#endif + /** * @brief Finalize all the sqlite3_stmt members in a vec0_vtab. * @@ -3201,6 +3386,14 @@ void vec0_free(vec0_vtab *p) { sqlite3_free(p->shadowVectorChunksNames[i]); p->shadowVectorChunksNames[i] = NULL; +#if SQLITE_VEC_ENABLE_RESCORE + sqlite3_free(p->shadowRescoreChunksNames[i]); + p->shadowRescoreChunksNames[i] = NULL; + + sqlite3_free(p->shadowRescoreVectorsNames[i]); + p->shadowRescoreVectorsNames[i] = NULL; +#endif + sqlite3_free(p->vector_columns[i].name); p->vector_columns[i].name = NULL; } @@ -3493,6 +3686,41 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx, assert((vector_column_idx >= 0) && (vector_column_idx < pVtab->numVectorColumns)); +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns store float vectors in _rescore_vectors (rowid-keyed) + if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE) { + size = vector_column_byte_size(p->vector_columns[vector_column_idx]); + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreVectorsNames[vector_column_idx], + "vector", rowid, 0, &vectorBlob); + if (rc != SQLITE_OK) { + vtab_set_error(&pVtab->base, + "Could not fetch vector data for %lld from rescore vectors", + rowid); + rc = SQLITE_ERROR; + goto cleanup; + } + buf = sqlite3_malloc(size); + if (!buf) { + rc = SQLITE_NOMEM; + goto cleanup; + } + rc = sqlite3_blob_read(vectorBlob, buf, size, 0); + if (rc != SQLITE_OK) { + sqlite3_free(buf); + buf = NULL; + rc = SQLITE_ERROR; + goto cleanup; + } + *outVector = buf; + if (outVectorSize) { + *outVectorSize = size; + } + rc = SQLITE_OK; + goto cleanup; + } +#endif /* SQLITE_VEC_ENABLE_RESCORE */ + rc = vec0_get_chunk_position(pVtab, rowid, NULL, &chunk_id, &chunk_offset); if (rc == SQLITE_EMPTY) { vtab_set_error(&pVtab->base, "Could not find a row with rowid %lld", rowid); @@ -4096,6 +4324,14 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_value ** partitionKeyValues, i64 *chunk continue; } int vector_column_idx = p->user_column_idxs[i]; + +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns don't use _vector_chunks for float storage + if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_RESCORE) { + continue; + } +#endif + i64 vectorsSize = p->chunk_size * vector_column_byte_size(p->vector_columns[vector_column_idx]); @@ -4126,6 +4362,14 @@ int vec0_new_chunk(vec0_vtab *p, sqlite3_value ** partitionKeyValues, i64 *chunk } } +#if SQLITE_VEC_ENABLE_RESCORE + // Create new rescore chunks for each rescore-enabled vector column + rc = rescore_new_chunk(p, rowid); + if (rc != SQLITE_OK) { + return rc; + } +#endif + // Step 3: Create new metadata chunks for each metadata column for (int i = 0; i < vec0_num_defined_user_columns(p); i++) { if(p->user_column_kinds[i] != SQLITE_VEC0_USER_COLUMN_KIND_METADATA) { @@ -4487,6 +4731,35 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, goto error; } +#if SQLITE_VEC_ENABLE_RESCORE + { + int hasRescore = 0; + for (int i = 0; i < numVectorColumns; i++) { + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) { + hasRescore = 1; + break; + } + } + if (hasRescore) { + if (numAuxiliaryColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "Auxiliary columns are not supported with rescore indexes"); + goto error; + } + if (numMetadataColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "Metadata columns are not supported with rescore indexes"); + goto error; + } + if (numPartitionColumns > 0) { + *pzErr = sqlite3_mprintf(VEC_CONSTRUCTOR_ERROR + "Partition key columns are not supported with rescore indexes"); + goto error; + } + } + } +#endif + sqlite3_str *createStr = sqlite3_str_new(NULL); sqlite3_str_appendall(createStr, "CREATE TABLE x("); if (pkColumnName) { @@ -4577,6 +4850,20 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, if (!pNew->shadowVectorChunksNames[i]) { goto error; } +#if SQLITE_VEC_ENABLE_RESCORE + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) { + pNew->shadowRescoreChunksNames[i] = + sqlite3_mprintf("%s_rescore_chunks%02d", tableName, i); + if (!pNew->shadowRescoreChunksNames[i]) { + goto error; + } + pNew->shadowRescoreVectorsNames[i] = + sqlite3_mprintf("%s_rescore_vectors%02d", tableName, i); + if (!pNew->shadowRescoreVectorsNames[i]) { + goto error; + } + } +#endif } for (int i = 0; i < pNew->numMetadataColumns; i++) { pNew->shadowMetadataChunksNames[i] = @@ -4700,6 +4987,11 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_finalize(stmt); for (int i = 0; i < pNew->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns don't use _vector_chunks + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif char *zSql = sqlite3_mprintf(VEC0_SHADOW_VECTOR_N_CREATE, pNew->schemaName, pNew->tableName, i); if (!zSql) { @@ -4718,6 +5010,13 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, sqlite3_finalize(stmt); } +#if SQLITE_VEC_ENABLE_RESCORE + rc = rescore_create_tables(pNew, db, pzErr); + if (rc != SQLITE_OK) { + goto error; + } +#endif + // See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY" // without INTEGER type issue applies here. for (int i = 0; i < pNew->numMetadataColumns; i++) { @@ -4852,6 +5151,10 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { sqlite3_finalize(stmt); for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif zSql = sqlite3_mprintf("DROP TABLE \"%w\".\"%w\"", p->schemaName, p->shadowVectorChunksNames[i]); rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); @@ -4863,6 +5166,13 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { sqlite3_finalize(stmt); } +#if SQLITE_VEC_ENABLE_RESCORE + rc = rescore_drop_tables(p); + if (rc != SQLITE_OK) { + goto done; + } +#endif + if(p->numAuxiliaryColumns > 0) { zSql = sqlite3_mprintf("DROP TABLE " VEC0_SHADOW_AUXILIARY_NAME, p->schemaName, p->tableName); rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); @@ -6624,6 +6934,10 @@ cleanup: return rc; } +#if SQLITE_VEC_ENABLE_RESCORE +#include "sqlite-vec-rescore.c" +#endif + int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { assert(argc == (strlen(idxStr)-1) / 4); @@ -6856,6 +7170,21 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, } #endif +#if SQLITE_VEC_ENABLE_RESCORE + // Dispatch to rescore KNN path if this vector column has rescore enabled + if (vector_column->index_type == VEC0_INDEX_TYPE_RESCORE) { + rc = rescore_knn(p, pCur, vector_column, vectorColumnIdx, arrayRowidsIn, + aMetadataIn, idxStr, argc, argv, queryVector, k, knn_data); + if (rc != SQLITE_OK) { + goto cleanup; + } + pCur->knn_data = knn_data; + pCur->query_plan = VEC0_QUERY_PLAN_KNN; + rc = SQLITE_OK; + goto cleanup; + } +#endif + rc = vec0_chunks_iter(p, idxStr, argc, argv, &stmtChunks); if (rc != SQLITE_OK) { // IMP: V06942_23781 @@ -7680,6 +8009,12 @@ int vec0Update_InsertWriteFinalStep(vec0_vtab *p, i64 chunk_rowid, // Go insert the vector data into the vector chunk shadow tables for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns store float vectors in _rescore_vectors instead + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif + sqlite3_blob *blobVectors; rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i], "vectors", chunk_rowid, 1, &blobVectors); @@ -8082,6 +8417,13 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, goto cleanup; } +#if SQLITE_VEC_ENABLE_RESCORE + rc = rescore_on_insert(p, chunk_rowid, chunk_offset, rowid, vectorDatas); + if (rc != SQLITE_OK) { + goto cleanup; + } +#endif + if(p->numAuxiliaryColumns > 0) { sqlite3_stmt *stmt; sqlite3_str * s = sqlite3_str_new(NULL); @@ -8272,6 +8614,11 @@ int vec0Update_Delete_ClearVectors(vec0_vtab *p, i64 chunk_id, u64 chunk_offset) { int rc, brc; for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + // Rescore columns don't use _vector_chunks + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif sqlite3_blob *blobVectors = NULL; size_t n = vector_column_byte_size(p->vector_columns[i]); @@ -8383,6 +8730,10 @@ int vec0Update_Delete_DeleteChunkIfEmpty(vec0_vtab *p, i64 chunk_id, // Delete from each _vector_chunksNN for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_RESCORE + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) + continue; +#endif zSql = sqlite3_mprintf( "DELETE FROM " VEC0_SHADOW_VECTOR_N_NAME " WHERE rowid = ?", p->schemaName, p->tableName, i); @@ -8399,6 +8750,12 @@ int vec0Update_Delete_DeleteChunkIfEmpty(vec0_vtab *p, i64 chunk_id, return SQLITE_ERROR; } +#if SQLITE_VEC_ENABLE_RESCORE + rc = rescore_delete_chunk(p, chunk_id); + if (rc != SQLITE_OK) + return rc; +#endif + // Delete from each _metadatachunksNN for (int i = 0; i < p->numMetadataColumns; i++) { zSql = sqlite3_mprintf( @@ -8606,6 +8963,14 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) { return rc; } +#if SQLITE_VEC_ENABLE_RESCORE + // 4b. zero out quantized data in rescore chunk tables, delete from rescore vectors + rc = rescore_on_delete(p, chunk_id, chunk_offset, rowid); + if (rc != SQLITE_OK) { + return rc; + } +#endif + // 5. delete from _rowids table rc = vec0Update_Delete_DeleteRowids(p, rowid); if (rc != SQLITE_OK) { @@ -8663,8 +9028,11 @@ int vec0Update_UpdateAuxColumn(vec0_vtab *p, int auxiliary_column_idx, sqlite3_v } int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset, - int i, sqlite3_value *valueVector) { + int i, sqlite3_value *valueVector, i64 rowid) { int rc; +#if !SQLITE_VEC_ENABLE_RESCORE + UNUSED_PARAMETER(rowid); +#endif sqlite3_blob *blobVectors = NULL; @@ -8708,6 +9076,59 @@ int vec0Update_UpdateVectorColumn(vec0_vtab *p, i64 chunk_id, i64 chunk_offset, goto cleanup; } +#if SQLITE_VEC_ENABLE_RESCORE + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_RESCORE) { + // For rescore columns, update _rescore_vectors and _rescore_chunks + struct VectorColumnDefinition *col = &p->vector_columns[i]; + size_t qsize = rescore_quantized_byte_size(col); + size_t fsize = vector_column_byte_size(*col); + + // 1. Update quantized chunk + { + void *qbuf = sqlite3_malloc(qsize); + if (!qbuf) { rc = SQLITE_NOMEM; goto cleanup; } + switch (col->rescore.quantizer_type) { + case VEC0_RESCORE_QUANTIZER_BIT: + rescore_quantize_float_to_bit((const float *)vector, (uint8_t *)qbuf, col->dimensions); + break; + case VEC0_RESCORE_QUANTIZER_INT8: + rescore_quantize_float_to_int8((const float *)vector, (int8_t *)qbuf, col->dimensions); + break; + } + sqlite3_blob *blobQ = NULL; + rc = sqlite3_blob_open(p->db, p->schemaName, + p->shadowRescoreChunksNames[i], "vectors", + chunk_id, 1, &blobQ); + if (rc != SQLITE_OK) { sqlite3_free(qbuf); goto cleanup; } + rc = sqlite3_blob_write(blobQ, qbuf, qsize, chunk_offset * qsize); + sqlite3_free(qbuf); + int brc2 = sqlite3_blob_close(blobQ); + if (rc != SQLITE_OK) goto cleanup; + if (brc2 != SQLITE_OK) { rc = brc2; goto cleanup; } + } + + // 2. Update float vector in _rescore_vectors (keyed by user rowid) + { + char *zSql = sqlite3_mprintf( + "UPDATE \"%w\".\"%w\" SET vector = ? WHERE rowid = ?", + p->schemaName, p->shadowRescoreVectorsNames[i]); + if (!zSql) { rc = SQLITE_NOMEM; goto cleanup; } + sqlite3_stmt *stmtUp; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmtUp, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) goto cleanup; + sqlite3_bind_blob(stmtUp, 1, vector, fsize, SQLITE_TRANSIENT); + sqlite3_bind_int64(stmtUp, 2, rowid); + rc = sqlite3_step(stmtUp); + sqlite3_finalize(stmtUp); + if (rc != SQLITE_DONE) { rc = SQLITE_ERROR; goto cleanup; } + } + + rc = SQLITE_OK; + goto cleanup; + } +#endif + rc = sqlite3_blob_open(p->db, p->schemaName, p->shadowVectorChunksNames[i], "vectors", chunk_id, 1, &blobVectors); if (rc != SQLITE_OK) { @@ -8839,7 +9260,7 @@ int vec0Update_Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv) { } rc = vec0Update_UpdateVectorColumn(p, chunk_id, chunk_offset, vector_idx, - valueVector); + valueVector, rowid); if (rc != SQLITE_OK) { return SQLITE_ERROR; } @@ -8997,9 +9418,15 @@ static sqlite3_module vec0Module = { #else #define SQLITE_VEC_DEBUG_BUILD_NEON "" #endif +#if SQLITE_VEC_ENABLE_RESCORE +#define SQLITE_VEC_DEBUG_BUILD_RESCORE "rescore" +#else +#define SQLITE_VEC_DEBUG_BUILD_RESCORE "" +#endif #define SQLITE_VEC_DEBUG_BUILD \ - SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON + SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \ + SQLITE_VEC_DEBUG_BUILD_RESCORE #define SQLITE_VEC_DEBUG_STRING \ "Version: " SQLITE_VEC_VERSION "\n" \ diff --git a/tests/fuzz/.gitignore b/tests/fuzz/.gitignore index 757d1ac..b9c7d30 100644 --- a/tests/fuzz/.gitignore +++ b/tests/fuzz/.gitignore @@ -1,2 +1,7 @@ *.dSYM targets/ +corpus/ +crash-* +leak-* +timeout-* +*.log diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile index 21629ef..0030c2e 100644 --- a/tests/fuzz/Makefile +++ b/tests/fuzz/Makefile @@ -72,10 +72,34 @@ $(TARGET_DIR)/vec_mismatch: vec-mismatch.c $(FUZZ_SRCS) | $(TARGET_DIR) $(TARGET_DIR)/vec0_delete_completeness: vec0-delete-completeness.c $(FUZZ_SRCS) | $(TARGET_DIR) $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ +$(TARGET_DIR)/rescore_operations: rescore-operations.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_create: rescore-create.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_quantize: rescore-quantize.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_shadow_corrupt: rescore-shadow-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_knn_deep: rescore-knn-deep.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_quantize_edge: rescore-quantize-edge.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_TEST $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/rescore_interleave: rescore-interleave.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) -DSQLITE_VEC_ENABLE_RESCORE $(FUZZ_SRCS) $< -o $@ + FUZZ_TARGETS = vec0_create exec json numpy \ shadow_corrupt vec0_operations scalar_functions \ vec0_create_full metadata_columns vec_each vec_mismatch \ - vec0_delete_completeness + vec0_delete_completeness \ + rescore_operations rescore_create rescore_quantize \ + rescore_shadow_corrupt rescore_knn_deep \ + rescore_quantize_edge rescore_interleave all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS)) diff --git a/tests/fuzz/rescore-create.c b/tests/fuzz/rescore-create.c new file mode 100644 index 0000000..3e69d6d --- /dev/null +++ b/tests/fuzz/rescore-create.c @@ -0,0 +1,36 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + int rc = SQLITE_OK; + sqlite3 *db; + sqlite3_stmt *stmt; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + sqlite3_str *s = sqlite3_str_new(NULL); + assert(s); + sqlite3_str_appendall(s, "CREATE VIRTUAL TABLE v USING vec0(emb float[128] indexed by rescore("); + sqlite3_str_appendf(s, "%.*s", (int)size, data); + sqlite3_str_appendall(s, "))"); + const char *zSql = sqlite3_str_finish(s); + assert(zSql); + + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL); + sqlite3_free((void *)zSql); + if (rc == SQLITE_OK) { + sqlite3_step(stmt); + } + sqlite3_finalize(stmt); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/rescore-create.dict b/tests/fuzz/rescore-create.dict new file mode 100644 index 0000000..a8adf71 --- /dev/null +++ b/tests/fuzz/rescore-create.dict @@ -0,0 +1,20 @@ +"rescore" +"quantizer" +"bit" +"int8" +"oversample" +"indexed" +"by" +"float" +"(" +")" +"," +"=" +"[" +"]" +"1" +"8" +"16" +"128" +"256" +"1024" diff --git a/tests/fuzz/rescore-interleave.c b/tests/fuzz/rescore-interleave.c new file mode 100644 index 0000000..74e8b8d --- /dev/null +++ b/tests/fuzz/rescore-interleave.c @@ -0,0 +1,151 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/** + * Fuzz target: interleaved insert/update/delete/KNN operations on rescore + * tables with BOTH quantizer types, exercising the int8 quantizer path + * and the update code path that the existing rescore-operations.c misses. + * + * Key differences from rescore-operations.c: + * - Tests BOTH bit and int8 quantizers (the existing target only tests bit) + * - Fuzz-controlled query vectors (not fixed [1,0,0,...]) + * - Exercises the UPDATE path (line 9080+ in sqlite-vec.c) + * - Tests with 16 dimensions (more realistic, exercises more of the + * quantization loop) + * - Interleaves KNN between mutations to stress the blob_reopen path + * when _rescore_vectors rows have been deleted/modified + */ +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 8) return 0; + + int rc; + sqlite3 *db; + sqlite3_stmt *stmtInsert = NULL; + sqlite3_stmt *stmtUpdate = NULL; + sqlite3_stmt *stmtDelete = NULL; + sqlite3_stmt *stmtKnn = NULL; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* Use first byte to pick quantizer */ + int use_int8 = data[0] & 1; + data++; size--; + + const char *create_sql = use_int8 + ? "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=int8))" + : "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=bit))"; + + rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "UPDATE v SET emb = ? WHERE rowid = ?", -1, &stmtUpdate, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 5", -1, &stmtKnn, NULL); + + if (!stmtInsert || !stmtUpdate || !stmtDelete || !stmtKnn) + goto cleanup; + + size_t i = 0; + while (i + 2 <= size) { + uint8_t op = data[i++] % 5; /* 5 operations now */ + uint8_t rowid_byte = data[i++]; + int64_t rowid = (int64_t)(rowid_byte % 24) + 1; + + switch (op) { + case 0: { + /* INSERT: consume bytes for 16 floats */ + float vec[16] = {0}; + for (int j = 0; j < 16 && i < size; j++, i++) { + vec[j] = (float)((int8_t)data[i]) / 8.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 1: { + /* DELETE */ + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 2: { + /* KNN with fuzz-controlled query vector */ + float qvec[16] = {0}; + for (int j = 0; j < 16 && i < size; j++, i++) { + qvec[j] = (float)((int8_t)data[i]) / 4.0f; + } + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) { + (void)sqlite3_column_int64(stmtKnn, 0); + (void)sqlite3_column_double(stmtKnn, 1); + } + break; + } + case 3: { + /* UPDATE: modify an existing vector (exercises rescore update path) */ + float vec[16] = {0}; + for (int j = 0; j < 16 && i < size; j++, i++) { + vec[j] = (float)((int8_t)data[i]) / 6.0f; + } + sqlite3_reset(stmtUpdate); + sqlite3_bind_blob(stmtUpdate, 1, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_bind_int64(stmtUpdate, 2, rowid); + sqlite3_step(stmtUpdate); + break; + } + case 4: { + /* INSERT then immediately UPDATE same row (stresses blob lifecycle) */ + float vec1[16] = {0}; + float vec2[16] = {0}; + for (int j = 0; j < 16 && i < size; j++, i++) { + vec1[j] = (float)((int8_t)data[i]) / 10.0f; + vec2[j] = -vec1[j]; /* opposite direction */ + } + /* Insert */ + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec1, sizeof(vec1), SQLITE_TRANSIENT); + if (sqlite3_step(stmtInsert) == SQLITE_DONE) { + /* Only update if insert succeeded (rowid might already exist) */ + sqlite3_reset(stmtUpdate); + sqlite3_bind_blob(stmtUpdate, 1, vec2, sizeof(vec2), SQLITE_TRANSIENT); + sqlite3_bind_int64(stmtUpdate, 2, rowid); + sqlite3_step(stmtUpdate); + } + break; + } + } + } + + /* Final consistency check: full scan must not crash */ + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtUpdate); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/rescore-knn-deep.c b/tests/fuzz/rescore-knn-deep.c new file mode 100644 index 0000000..8ff3c37 --- /dev/null +++ b/tests/fuzz/rescore-knn-deep.c @@ -0,0 +1,178 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/** + * Fuzz target: deep exercise of rescore KNN with fuzz-controlled query vectors + * and both quantizer types (bit + int8), multiple distance metrics. + * + * The existing rescore-operations.c only tests bit quantizer with a fixed + * query vector. This target: + * - Tests both bit and int8 quantizers + * - Uses fuzz-controlled query vectors (hits NaN/Inf/denormal paths) + * - Tests all distance metrics with int8 (L2, cosine, L1) + * - Exercises large LIMIT values (oversample multiplication) + * - Tests KNN with rowid IN constraints + * - Exercises the insert->query->update->query->delete->query cycle + */ +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 20) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* Use first 4 bytes for configuration */ + uint8_t config = data[0]; + uint8_t num_inserts = (data[1] % 20) + 3; /* 3..22 inserts */ + uint8_t limit_val = (data[2] % 50) + 1; /* 1..50 for LIMIT */ + uint8_t metric_choice = data[3] % 3; + data += 4; + size -= 4; + + int use_int8 = config & 1; + const char *metric_str; + switch (metric_choice) { + case 0: metric_str = ""; break; /* default L2 */ + case 1: metric_str = " distance_metric=cosine"; break; + case 2: metric_str = " distance_metric=l1"; break; + default: metric_str = ""; break; + } + + /* Build CREATE TABLE statement */ + char create_sql[256]; + if (use_int8) { + snprintf(create_sql, sizeof(create_sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=int8)%s)", metric_str); + } else { + /* bit quantizer ignores distance_metric for the coarse pass (always hamming), + but the float rescore phase uses the specified metric */ + snprintf(create_sql, sizeof(create_sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=bit)%s)", metric_str); + } + + rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + /* Insert vectors using fuzz data */ + { + sqlite3_stmt *ins = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL); + if (!ins) { sqlite3_close(db); return 0; } + + size_t cursor = 0; + for (int i = 0; i < num_inserts && cursor + 1 < size; i++) { + float vec[16]; + for (int j = 0; j < 16; j++) { + if (cursor < size) { + /* Map fuzz byte to float -- includes potential for + interesting float values via reinterpretation */ + int8_t sb = (int8_t)data[cursor++]; + vec[j] = (float)sb / 5.0f; + } else { + vec[j] = 0.0f; + } + } + sqlite3_reset(ins); + sqlite3_bind_int64(ins, 1, (sqlite3_int64)(i + 1)); + sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(ins); + } + sqlite3_finalize(ins); + } + + /* Build a fuzz-controlled query vector from remaining data */ + float qvec[16] = {0}; + { + size_t cursor = 0; + for (int j = 0; j < 16 && cursor < size; j++) { + int8_t sb = (int8_t)data[cursor++]; + qvec[j] = (float)sb / 3.0f; + } + } + + /* KNN query with fuzz-controlled vector and LIMIT */ + { + char knn_sql[256]; + snprintf(knn_sql, sizeof(knn_sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT %d", (int)limit_val); + + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, knn_sql, -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) { + /* Read results to ensure distance computation didn't produce garbage + that crashes the cursor iteration */ + (void)sqlite3_column_int64(knn, 0); + (void)sqlite3_column_double(knn, 1); + } + sqlite3_finalize(knn); + } + } + + /* Update some vectors, then query again */ + { + float uvec[16]; + for (int j = 0; j < 16; j++) uvec[j] = qvec[15 - j]; /* reverse of query */ + sqlite3_stmt *upd = NULL; + sqlite3_prepare_v2(db, + "UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL); + if (upd) { + sqlite3_bind_blob(upd, 1, uvec, sizeof(uvec), SQLITE_STATIC); + sqlite3_step(upd); + sqlite3_finalize(upd); + } + } + + /* Second KNN after update */ + { + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 10", -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) {} + sqlite3_finalize(knn); + } + } + + /* Delete half the rows, then KNN again */ + for (int i = 1; i <= num_inserts; i += 2) { + char del_sql[64]; + snprintf(del_sql, sizeof(del_sql), + "DELETE FROM v WHERE rowid = %d", i); + sqlite3_exec(db, del_sql, NULL, NULL, NULL); + } + + /* Third KNN after deletes -- exercises distance computation over + zeroed-out slots in the quantized chunk */ + { + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 5", -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) {} + sqlite3_finalize(knn); + } + } + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/rescore-operations.c b/tests/fuzz/rescore-operations.c new file mode 100644 index 0000000..4bb7ff1 --- /dev/null +++ b/tests/fuzz/rescore-operations.c @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 6) return 0; + + int rc; + sqlite3 *db; + sqlite3_stmt *stmtInsert = NULL; + sqlite3_stmt *stmtDelete = NULL; + sqlite3_stmt *stmtKnn = NULL; + sqlite3_stmt *stmtScan = NULL; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + rc = sqlite3_exec(db, + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[8] indexed by rescore(quantizer=bit))", + NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? ORDER BY distance LIMIT 3", + -1, &stmtKnn, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid FROM v", -1, &stmtScan, NULL); + + if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup; + + size_t i = 0; + while (i + 2 <= size) { + uint8_t op = data[i++] % 4; + uint8_t rowid_byte = data[i++]; + int64_t rowid = (int64_t)(rowid_byte % 32) + 1; + + switch (op) { + case 0: { + // INSERT: consume 32 bytes for 8 floats, or use what's left + float vec[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + for (int j = 0; j < 8 && i < size; j++, i++) { + vec[j] = (float)((int8_t)data[i]) / 10.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 1: { + // DELETE + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 2: { + // KNN query with a fixed query vector + float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 3: { + // Full scan + sqlite3_reset(stmtScan); + while (sqlite3_step(stmtScan) == SQLITE_ROW) {} + break; + } + } + } + + // Final operations -- must not crash regardless of prior state + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_finalize(stmtScan); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/rescore-quantize-edge.c b/tests/fuzz/rescore-quantize-edge.c new file mode 100644 index 0000000..4ab9e20 --- /dev/null +++ b/tests/fuzz/rescore-quantize-edge.c @@ -0,0 +1,177 @@ +#include +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/* Test wrappers from sqlite-vec-rescore.c (SQLITE_VEC_TEST build) */ +extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim); +extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim); +extern size_t _test_rescore_quantized_byte_size_bit(size_t dimensions); +extern size_t _test_rescore_quantized_byte_size_int8(size_t dimensions); + +/** + * Fuzz target: edge cases in rescore quantization functions. + * + * The existing rescore-quantize.c only tests dimensions that are multiples of 8 + * and never passes special float values. This target: + * + * - Tests rescore_quantized_byte_size with arbitrary dimension values + * (including 0, 1, 7, MAX values -- looking for integer division issues) + * - Passes raw float reinterpretation of fuzz bytes (NaN, Inf, denormals, + * negative zero -- these are the values that break min/max/range logic) + * - Tests the int8 quantizer with all-identical values (range=0 branch) + * - Tests the int8 quantizer with extreme ranges (overflow in scale calc) + * - Tests bit quantizer with exact float threshold (0.0f boundary) + */ +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 8) return 0; + + uint8_t mode = data[0] % 5; + data++; size--; + + switch (mode) { + case 0: { + /* Test rescore_quantized_byte_size with fuzz-controlled dimensions. + This function does dimensions / CHAR_BIT for bit, dimensions for int8. + We're checking it doesn't do anything weird with edge values. */ + if (size < sizeof(size_t)) return 0; + size_t dim; + memcpy(&dim, data, sizeof(dim)); + + /* These should never crash, just return values */ + size_t bit_size = _test_rescore_quantized_byte_size_bit(dim); + size_t int8_size = _test_rescore_quantized_byte_size_int8(dim); + + /* Verify basic invariants */ + (void)bit_size; + (void)int8_size; + break; + } + + case 1: { + /* Bit quantize with raw reinterpreted floats (NaN, Inf, denormal). + The key check: src[i] >= 0.0f -- NaN comparison is always false, + so NaN should produce 0-bits. But denormals and -0.0f are tricky. */ + size_t num_floats = size / sizeof(float); + if (num_floats == 0) return 0; + /* Round to multiple of 8 for bit quantizer */ + size_t dim = (num_floats / 8) * 8; + if (dim == 0) return 0; + + const float *src = (const float *)data; + size_t bit_bytes = dim / 8; + uint8_t *dst = (uint8_t *)malloc(bit_bytes); + if (!dst) return 0; + + _test_rescore_quantize_float_to_bit(src, dst, dim); + + /* Verify: for each bit, if src >= 0 then bit should be set */ + for (size_t i = 0; i < dim; i++) { + int bit_set = (dst[i / 8] >> (i % 8)) & 1; + if (src[i] >= 0.0f) { + assert(bit_set == 1); + } else if (src[i] < 0.0f) { + /* Definitely negative -- bit must be 0 */ + assert(bit_set == 0); + } + /* NaN: comparison is false, so bit_set should be 0 */ + } + + free(dst); + break; + } + + case 2: { + /* Int8 quantize with raw reinterpreted floats. + The dangerous paths: + - All values identical (range == 0) -> memset path + - vmin/vmax with NaN (NaN < anything is false, NaN > anything is false) + - Extreme range causing scale = 255/range to be Inf or 0 + - denormals near the clamping boundaries */ + size_t num_floats = size / sizeof(float); + if (num_floats == 0) return 0; + + const float *src = (const float *)data; + int8_t *dst = (int8_t *)malloc(num_floats); + if (!dst) return 0; + + _test_rescore_quantize_float_to_int8(src, dst, num_floats); + + /* Output must always be in [-128, 127] (trivially true for int8_t, + but check the actual clamping logic worked) */ + for (size_t i = 0; i < num_floats; i++) { + assert(dst[i] >= -128 && dst[i] <= 127); + } + + free(dst); + break; + } + + case 3: { + /* Int8 quantize stress: all-same values (range=0 branch) */ + size_t dim = (size < 64) ? size : 64; + if (dim == 0) return 0; + + float *src = (float *)malloc(dim * sizeof(float)); + int8_t *dst = (int8_t *)malloc(dim); + if (!src || !dst) { free(src); free(dst); return 0; } + + /* Fill with a single value derived from fuzz data */ + float val; + memcpy(&val, data, sizeof(float) < size ? sizeof(float) : size); + for (size_t i = 0; i < dim; i++) src[i] = val; + + _test_rescore_quantize_float_to_int8(src, dst, dim); + + /* All outputs should be 0 when range == 0 */ + for (size_t i = 0; i < dim; i++) { + assert(dst[i] == 0); + } + + free(src); + free(dst); + break; + } + + case 4: { + /* Int8 quantize with extreme range: one huge positive, one huge negative. + Tests scale = 255.0f / range overflow to Inf, then v * Inf = Inf, + then clamping to [-128, 127]. */ + if (size < 2 * sizeof(float)) return 0; + + float extreme[2]; + memcpy(extreme, data, 2 * sizeof(float)); + + /* Only test if both are finite (NaN/Inf tested in case 2) */ + if (!isfinite(extreme[0]) || !isfinite(extreme[1])) return 0; + + /* Build a vector with these two extreme values plus some fuzz */ + size_t dim = 16; + float src[16]; + src[0] = extreme[0]; + src[1] = extreme[1]; + for (size_t i = 2; i < dim; i++) { + if (2 * sizeof(float) + (i - 2) < size) { + src[i] = (float)((int8_t)data[2 * sizeof(float) + (i - 2)]) * 1000.0f; + } else { + src[i] = 0.0f; + } + } + + int8_t dst[16]; + _test_rescore_quantize_float_to_int8(src, dst, dim); + + for (size_t i = 0; i < dim; i++) { + assert(dst[i] >= -128 && dst[i] <= 127); + } + break; + } + } + + return 0; +} diff --git a/tests/fuzz/rescore-quantize.c b/tests/fuzz/rescore-quantize.c new file mode 100644 index 0000000..6aad445 --- /dev/null +++ b/tests/fuzz/rescore-quantize.c @@ -0,0 +1,54 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/* These are SQLITE_VEC_TEST wrappers defined in sqlite-vec-rescore.c */ +extern void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim); +extern void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim); + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + /* Need at least 4 bytes for one float */ + if (size < 4) return 0; + + /* Use the input as an array of floats. Dimensions must be a multiple of 8 + * for the bit quantizer. */ + size_t num_floats = size / sizeof(float); + if (num_floats == 0) return 0; + + /* Round down to multiple of 8 for bit quantizer compatibility */ + size_t dim = (num_floats / 8) * 8; + if (dim == 0) dim = 8; + if (dim > num_floats) return 0; + + const float *src = (const float *)data; + + /* Allocate output buffers */ + size_t bit_bytes = dim / 8; + uint8_t *bit_dst = (uint8_t *)malloc(bit_bytes); + int8_t *int8_dst = (int8_t *)malloc(dim); + if (!bit_dst || !int8_dst) { + free(bit_dst); + free(int8_dst); + return 0; + } + + /* Test bit quantization */ + _test_rescore_quantize_float_to_bit(src, bit_dst, dim); + + /* Test int8 quantization */ + _test_rescore_quantize_float_to_int8(src, int8_dst, dim); + + /* Verify int8 output is in range */ + for (size_t i = 0; i < dim; i++) { + assert(int8_dst[i] >= -128 && int8_dst[i] <= 127); + } + + free(bit_dst); + free(int8_dst); + return 0; +} diff --git a/tests/fuzz/rescore-shadow-corrupt.c b/tests/fuzz/rescore-shadow-corrupt.c new file mode 100644 index 0000000..edd87ef --- /dev/null +++ b/tests/fuzz/rescore-shadow-corrupt.c @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/** + * Fuzz target: corrupt rescore shadow tables then exercise KNN/read/write. + * + * This targets the dangerous code paths in rescore_knn (Phase 1 + 2): + * - sqlite3_blob_read into baseVectors with potentially wrong-sized blobs + * - distance computation on corrupted/partial quantized data + * - blob_reopen on _rescore_vectors with missing/corrupted rows + * - insert/delete after corruption (blob_write to wrong offsets) + * + * The existing shadow-corrupt.c only tests vec0 without rescore. + */ +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 4) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* Pick quantizer type from first byte */ + int use_int8 = data[0] & 1; + int target = (data[1] % 8); + const uint8_t *payload = data + 2; + int payload_size = (int)(size - 2); + + const char *create_sql = use_int8 + ? "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=int8))" + : "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] indexed by rescore(quantizer=bit))"; + + rc = sqlite3_exec(db, create_sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + /* Insert several vectors so there's a full chunk to corrupt */ + { + sqlite3_stmt *ins = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &ins, NULL); + if (!ins) { sqlite3_close(db); return 0; } + + for (int i = 1; i <= 8; i++) { + float vec[16]; + for (int j = 0; j < 16; j++) vec[j] = (float)(i * 10 + j) / 100.0f; + sqlite3_reset(ins); + sqlite3_bind_int64(ins, 1, i); + sqlite3_bind_blob(ins, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(ins); + } + sqlite3_finalize(ins); + } + + /* Now corrupt rescore shadow tables based on fuzz input */ + sqlite3_stmt *stmt = NULL; + + switch (target) { + case 0: { + /* Corrupt _rescore_chunks00 vectors blob with fuzz data */ + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + break; + } + case 1: { + /* Corrupt _rescore_vectors00 vector blob for a specific row */ + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 3", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, payload_size, SQLITE_STATIC); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + break; + } + case 2: { + /* Truncate the quantized chunk blob to wrong size */ + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_chunks00 SET vectors = X'DEADBEEF' WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + break; + } + case 3: { + /* Delete rows from _rescore_vectors (orphan the float vectors) */ + sqlite3_exec(db, + "DELETE FROM v_rescore_vectors00 WHERE rowid IN (2, 4, 6)", + NULL, NULL, NULL); + break; + } + case 4: { + /* Delete the chunk row entirely from _rescore_chunks */ + sqlite3_exec(db, + "DELETE FROM v_rescore_chunks00 WHERE rowid = 1", + NULL, NULL, NULL); + break; + } + case 5: { + /* Set vectors to NULL in _rescore_chunks */ + sqlite3_exec(db, + "UPDATE v_rescore_chunks00 SET vectors = NULL WHERE rowid = 1", + NULL, NULL, NULL); + break; + } + case 6: { + /* Set vector to NULL in _rescore_vectors */ + sqlite3_exec(db, + "UPDATE v_rescore_vectors00 SET vector = NULL WHERE rowid = 3", + NULL, NULL, NULL); + break; + } + case 7: { + /* Corrupt BOTH tables with fuzz data */ + int half = payload_size / 2; + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_chunks00 SET vectors = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload, half, SQLITE_STATIC); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + rc = sqlite3_prepare_v2(db, + "UPDATE v_rescore_vectors00 SET vector = ? WHERE rowid = 1", + -1, &stmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(stmt, 1, payload + half, + payload_size - half, SQLITE_STATIC); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + stmt = NULL; + } + break; + } + } + + /* Exercise ALL read/write paths -- NONE should crash */ + + /* KNN query (triggers rescore_knn Phase 1 + Phase 2) */ + { + float qvec[16] = {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1}; + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 5", -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) {} + sqlite3_finalize(knn); + } + } + + /* Full scan (triggers reading from _rescore_vectors) */ + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + + /* Point lookups */ + sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 1", NULL, NULL, NULL); + sqlite3_exec(db, "SELECT * FROM v WHERE rowid = 3", NULL, NULL, NULL); + + /* Insert after corruption */ + { + float vec[16] = {0}; + sqlite3_stmt *ins = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (99, ?)", -1, &ins, NULL); + if (ins) { + sqlite3_bind_blob(ins, 1, vec, sizeof(vec), SQLITE_STATIC); + sqlite3_step(ins); + sqlite3_finalize(ins); + } + } + + /* Delete after corruption */ + sqlite3_exec(db, "DELETE FROM v WHERE rowid = 5", NULL, NULL, NULL); + + /* Update after corruption */ + { + float vec[16] = {1,1,1,1, 1,1,1,1, 1,1,1,1, 1,1,1,1}; + sqlite3_stmt *upd = NULL; + sqlite3_prepare_v2(db, + "UPDATE v SET emb = ? WHERE rowid = 1", -1, &upd, NULL); + if (upd) { + sqlite3_bind_blob(upd, 1, vec, sizeof(vec), SQLITE_STATIC); + sqlite3_step(upd); + sqlite3_finalize(upd); + } + } + + /* KNN again after modifications to corrupted state */ + { + float qvec[16] = {0,0,0,0, 0,0,0,0, 1,1,1,1, 1,1,1,1}; + sqlite3_stmt *knn = NULL; + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? " + "ORDER BY distance LIMIT 3", -1, &knn, NULL); + if (knn) { + sqlite3_bind_blob(knn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knn) == SQLITE_ROW) {} + sqlite3_finalize(knn); + } + } + + sqlite3_exec(db, "DROP TABLE v", NULL, NULL, NULL); + sqlite3_close(db); + return 0; +} diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h index a02c72a..cbc2c08 100644 --- a/tests/sqlite-vec-internal.h +++ b/tests/sqlite-vec-internal.h @@ -65,8 +65,23 @@ enum Vec0DistanceMetrics { enum Vec0IndexType { VEC0_INDEX_TYPE_FLAT = 1, +#ifdef SQLITE_VEC_ENABLE_RESCORE + VEC0_INDEX_TYPE_RESCORE = 2, +#endif }; +#ifdef SQLITE_VEC_ENABLE_RESCORE +enum Vec0RescoreQuantizerType { + VEC0_RESCORE_QUANTIZER_BIT = 1, + VEC0_RESCORE_QUANTIZER_INT8 = 2, +}; + +struct Vec0RescoreConfig { + enum Vec0RescoreQuantizerType quantizer_type; + int oversample; +}; +#endif + struct VectorColumnDefinition { char *name; int name_length; @@ -74,6 +89,9 @@ struct VectorColumnDefinition { enum VectorElementType element_type; enum Vec0DistanceMetrics distance_metric; enum Vec0IndexType index_type; +#ifdef SQLITE_VEC_ENABLE_RESCORE + struct Vec0RescoreConfig rescore; +#endif }; int vec0_parse_vector_column(const char *source, int source_length, @@ -88,6 +106,13 @@ int vec0_parse_partition_key_definition(const char *source, int source_length, float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims); float _test_distance_cosine_float(const float *a, const float *b, size_t dims); float _test_distance_hamming(const unsigned char *a, const unsigned char *b, size_t dims); + +#ifdef SQLITE_VEC_ENABLE_RESCORE +void _test_rescore_quantize_float_to_bit(const float *src, uint8_t *dst, size_t dim); +void _test_rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dim); +size_t _test_rescore_quantized_byte_size_bit(size_t dimensions); +size_t _test_rescore_quantized_byte_size_int8(size_t dimensions); +#endif #endif #endif /* SQLITE_VEC_INTERNAL_H */ diff --git a/tests/test-rescore-mutations.py b/tests/test-rescore-mutations.py new file mode 100644 index 0000000..28495c2 --- /dev/null +++ b/tests/test-rescore-mutations.py @@ -0,0 +1,470 @@ +"""Mutation and edge-case tests for the rescore index feature.""" +import struct +import sqlite3 +import pytest +import math +import random + + +@pytest.fixture() +def db(): + db = sqlite3.connect(":memory:") + db.row_factory = sqlite3.Row + db.enable_load_extension(True) + db.load_extension("dist/vec0") + db.enable_load_extension(False) + return db + + +def float_vec(values): + """Pack a list of floats into a blob for sqlite-vec.""" + return struct.pack(f"{len(values)}f", *values) + + +def unpack_float_vec(blob): + """Unpack a float vector blob.""" + n = len(blob) // 4 + return list(struct.unpack(f"{n}f", blob)) + + +# ============================================================================ +# Error cases: rescore + aux/metadata/partition +# ============================================================================ + + +def test_create_error_with_aux_column(db): + """Rescore should reject auxiliary columns.""" + with pytest.raises(sqlite3.OperationalError, match="Auxiliary columns"): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)," + " +extra text" + ")" + ) + + +def test_create_error_with_metadata_column(db): + """Rescore should reject metadata columns.""" + with pytest.raises(sqlite3.OperationalError, match="Metadata columns"): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)," + " genre text" + ")" + ) + + +def test_create_error_with_partition_key(db): + """Rescore should reject partition key columns.""" + with pytest.raises(sqlite3.OperationalError, match="Partition key"): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)," + " user_id integer partition key" + ")" + ) + + +# ============================================================================ +# Insert / batch / delete / update mutations +# ============================================================================ + + +def test_insert_single_verify_knn(db): + """Insert a single row and verify KNN returns it.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 1 + assert rows[0]["rowid"] == 1 + assert rows[0]["distance"] < 0.01 + + +def test_insert_large_batch(db): + """Insert 200+ rows (multiple chunks with default chunk_size=1024) and verify count and KNN.""" + dim = 16 + n = 200 + random.seed(99) + db.execute( + f"CREATE VIRTUAL TABLE t USING vec0(" + f" embedding float[{dim}] indexed by rescore(quantizer=int8)" + f")" + ) + for i in range(n): + v = [random.gauss(0, 1) for _ in range(dim)] + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec(v)], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == n + + # KNN should return results + query = float_vec([random.gauss(0, 1) for _ in range(dim)]) + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [query], + ).fetchall() + assert len(rows) == 10 + + +def test_delete_all_rows(db): + """Delete every row, verify count=0, KNN returns empty.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + for i in range(20): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + assert db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"] == 20 + + for i in range(20): + db.execute("DELETE FROM t WHERE rowid = ?", [i + 1]) + + assert db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"] == 0 + + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", + [float_vec([0.0] * 8)], + ).fetchall() + assert len(rows) == 0 + + +def test_delete_then_reinsert_same_rowid(db): + """Delete rowid=1, re-insert rowid=1 with different vector, verify KNN uses new vector.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + # Insert rowid=1 near origin, rowid=2 far from origin + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([0.1] * 8)], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([100.0] * 8)], + ) + + # KNN to [0]*8 -> rowid 1 is closer + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([0.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 1 + + # Delete rowid=1, re-insert with vector far from origin + db.execute("DELETE FROM t WHERE rowid = 1") + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([200.0] * 8)], + ) + + # Now KNN to [0]*8 -> rowid 2 should be closer + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([0.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 2 + + +def test_update_vector(db): + """UPDATE the vector column and verify KNN reflects new value.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([0.0] * 8)], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([10.0] * 8)], + ) + + # Update rowid=1 to be far away + db.execute( + "UPDATE t SET embedding = ? WHERE rowid = 1", + [float_vec([100.0] * 8)], + ) + + # Now KNN to [0]*8 -> rowid 2 should be closest + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([0.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 2 + + +def test_knn_after_delete_all_but_one(db): + """Insert 50 rows, delete 49, KNN should only return the survivor.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + for i in range(50): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + # Delete all except rowid=25 + for i in range(50): + if i + 1 != 25: + db.execute("DELETE FROM t WHERE rowid = ?", [i + 1]) + + assert db.execute("SELECT count(*) as cnt FROM t").fetchone()["cnt"] == 1 + + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [float_vec([0.0] * 8)], + ).fetchall() + assert len(rows) == 1 + assert rows[0]["rowid"] == 25 + + +# ============================================================================ +# Edge cases +# ============================================================================ + + +def test_single_row_knn(db): + """Table with exactly 1 row. LIMIT 1 returns it; LIMIT 5 returns 1.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 1 + + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 1 + + +def test_knn_with_all_identical_vectors(db): + """All vectors are the same. All distances should be equal.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + vec = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0] + for i in range(10): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec(vec)], + ) + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [float_vec(vec)], + ).fetchall() + assert len(rows) == 10 + # All distances should be ~0 (exact match) + for r in rows: + assert r["distance"] < 0.01 + + +def test_zero_vector_insert(db): + """Insert the zero vector [0,0,...,0]. Should not crash quantization.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([0.0] * 8)], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 1 + + # Also test int8 quantizer with zero vector + db.execute( + "CREATE VIRTUAL TABLE t2 USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t2(rowid, embedding) VALUES (1, ?)", + [float_vec([0.0] * 8)], + ) + row = db.execute("SELECT count(*) as cnt FROM t2").fetchone() + assert row["cnt"] == 1 + + +def test_very_large_values(db): + """Insert vectors with very large float values. Quantization should not crash.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1e30] * 8)], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([1e30, -1e30, 1e30, -1e30, 1e30, -1e30, 1e30, -1e30])], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 2 + + +def test_negative_values(db): + """Insert vectors with all negative values. Bit quantization maps all to 0.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8])], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 2 + + # KNN should still work + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2", + [float_vec([-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8])], + ).fetchall() + assert len(rows) == 2 + assert rows[0]["rowid"] == 2 + + +def test_single_dimension(db): + """Single-dimension vector (edge case for quantization).""" + # int8 quantizer (bit needs dim divisible by 8) + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([5.0] * 8)]) + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 1 + + +# ============================================================================ +# vec_debug() verification +# ============================================================================ + + +def test_vec_debug_contains_rescore(db): + """vec_debug() should contain 'rescore' in build flags when compiled with SQLITE_VEC_ENABLE_RESCORE.""" + row = db.execute("SELECT vec_debug() as d").fetchone() + assert "rescore" in row["d"] + + +# ============================================================================ +# Insert batch recall test +# ============================================================================ + + +def test_insert_batch_recall(db): + """Insert 150 rows and verify KNN recall is reasonable (>0.6).""" + dim = 16 + n = 150 + k = 10 + random.seed(77) + + db.execute( + f"CREATE VIRTUAL TABLE t_rescore USING vec0(" + f" embedding float[{dim}] indexed by rescore(quantizer=int8, oversample=16)" + f")" + ) + db.execute( + f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])" + ) + + vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)] + for i, v in enumerate(vectors): + blob = float_vec(v) + db.execute( + "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob] + ) + db.execute( + "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob] + ) + + query = float_vec([random.gauss(0, 1) for _ in range(dim)]) + + rescore_rows = db.execute( + "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + flat_rows = db.execute( + "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + + rescore_ids = {r["rowid"] for r in rescore_rows} + flat_ids = {r["rowid"] for r in flat_rows} + recall = len(rescore_ids & flat_ids) / k + assert recall >= 0.6, f"Recall too low: {recall}" + + +# ============================================================================ +# Distance metric variants +# ============================================================================ + + +def test_knn_int8_cosine(db): + """Rescore with quantizer=int8 and distance_metric=cosine.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (3, ?)", + [float_vec([1.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert rows[0]["rowid"] == 1 + assert rows[0]["distance"] < 0.01 diff --git a/tests/test-rescore.py b/tests/test-rescore.py new file mode 100644 index 0000000..5025857 --- /dev/null +++ b/tests/test-rescore.py @@ -0,0 +1,568 @@ +"""Tests for the rescore index feature in sqlite-vec.""" +import struct +import sqlite3 +import pytest +import math +import random + + +@pytest.fixture() +def db(): + db = sqlite3.connect(":memory:") + db.row_factory = sqlite3.Row + db.enable_load_extension(True) + db.load_extension("dist/vec0") + db.enable_load_extension(False) + return db + + +def float_vec(values): + """Pack a list of floats into a blob for sqlite-vec.""" + return struct.pack(f"{len(values)}f", *values) + + +def unpack_float_vec(blob): + """Unpack a float vector blob.""" + n = len(blob) // 4 + return list(struct.unpack(f"{n}f", blob)) + + +# ============================================================================ +# Creation tests +# ============================================================================ + + +def test_create_bit(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + # Table exists and has the right structure + row = db.execute( + "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" + ).fetchone() + assert row["cnt"] > 0 + + +def test_create_int8(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=int8)" + ")" + ) + row = db.execute( + "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" + ).fetchone() + assert row["cnt"] > 0 + + +def test_create_with_oversample(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit, oversample=16)" + ")" + ) + row = db.execute( + "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" + ).fetchone() + assert row["cnt"] > 0 + + +def test_create_with_distance_metric(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] distance_metric=cosine indexed by rescore(quantizer=bit)" + ")" + ) + row = db.execute( + "SELECT count(*) as cnt FROM sqlite_master WHERE name LIKE 't_%'" + ).fetchone() + assert row["cnt"] > 0 + + +def test_create_error_missing_quantizer(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(oversample=8)" + ")" + ) + + +def test_create_error_invalid_quantizer(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=float)" + ")" + ) + + +def test_create_error_on_bit_column(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding bit[1024] indexed by rescore(quantizer=bit)" + ")" + ) + + +def test_create_error_on_int8_column(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding int8[128] indexed by rescore(quantizer=bit)" + ")" + ) + + +def test_create_error_bad_oversample_zero(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit, oversample=0)" + ")" + ) + + +def test_create_error_bad_oversample_too_large(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit, oversample=999)" + ")" + ) + + +def test_create_error_bit_dim_not_divisible_by_8(db): + with pytest.raises(sqlite3.OperationalError): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[100] indexed by rescore(quantizer=bit)" + ")" + ) + + +# ============================================================================ +# Shadow table tests +# ============================================================================ + + +def test_shadow_tables_exist(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%' ORDER BY name" + ).fetchall() + ] + assert "t_rescore_chunks00" in tables + assert "t_rescore_vectors00" in tables + # Rescore columns don't create _vector_chunks + assert "t_vector_chunks00" not in tables + + +def test_drop_cleans_up(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[128] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("DROP TABLE t") + tables = [ + r[0] + for r in db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 't_%'" + ).fetchall() + ] + assert len(tables) == 0 + + +# ============================================================================ +# Insert tests +# ============================================================================ + + +def test_insert_single(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 1 + + +def test_insert_multiple(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 10 + + +# ============================================================================ +# Delete tests +# ============================================================================ + + +def test_delete_single(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + db.execute("DELETE FROM t WHERE rowid = 1") + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 0 + + +def test_delete_and_reinsert(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + db.execute("DELETE FROM t WHERE rowid = 1") + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)] + ) + row = db.execute("SELECT count(*) as cnt FROM t").fetchone() + assert row["cnt"] == 1 + + +def test_point_query_returns_float(db): + """SELECT by rowid should return the original float vector, not quantized.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + vals = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(vals)]) + row = db.execute("SELECT embedding FROM t WHERE rowid = 1").fetchone() + result = unpack_float_vec(row["embedding"]) + for a, b in zip(result, vals): + assert abs(a - b) < 1e-6 + + +# ============================================================================ +# KNN tests +# ============================================================================ + + +def test_knn_basic_bit(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + # Insert vectors where [1,0,0,...] is closest to query [1,0,0,...] + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (3, ?)", + [float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert len(rows) == 1 + assert rows[0]["rowid"] == 1 + + +def test_knn_basic_int8(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (3, ?)", + [float_vec([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert len(rows) == 1 + assert rows[0]["rowid"] == 1 + + +def test_knn_returns_float_distances(db): + """KNN should return float-precision distances, not quantized distances.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + v1 = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + v2 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(v1)]) + db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(v2)]) + + query = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 2", + [float_vec(query)], + ).fetchall() + + # First result should be exact match with distance ~0 + assert rows[0]["rowid"] == 1 + assert rows[0]["distance"] < 0.01 + + # Second result should have a float distance + # sqrt((1-0)^2 + (0.5-0)^2 + (0-1)^2) = sqrt(2.25) = 1.5 + assert abs(rows[1]["distance"] - 1.5) < 0.01 + + +def test_knn_recall(db): + """With enough vectors, rescore should achieve good recall (>0.9).""" + dim = 32 + n = 1000 + k = 10 + random.seed(42) + + db.execute( + "CREATE VIRTUAL TABLE t_rescore USING vec0(" + f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample=16)" + ")" + ) + db.execute( + f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])" + ) + + vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)] + for i, v in enumerate(vectors): + blob = float_vec(v) + db.execute( + "INSERT INTO t_rescore(rowid, embedding) VALUES (?, ?)", [i + 1, blob] + ) + db.execute( + "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", [i + 1, blob] + ) + + query = float_vec([random.gauss(0, 1) for _ in range(dim)]) + + rescore_rows = db.execute( + "SELECT rowid FROM t_rescore WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + flat_rows = db.execute( + "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + + rescore_ids = {r["rowid"] for r in rescore_rows} + flat_ids = {r["rowid"] for r in flat_rows} + recall = len(rescore_ids & flat_ids) / k + assert recall >= 0.7, f"Recall too low: {recall}" + + +def test_knn_cosine(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] distance_metric=cosine indexed by rescore(quantizer=bit)" + ")" + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (1, ?)", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (2, ?)", + [float_vec([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ).fetchall() + assert rows[0]["rowid"] == 1 + # cosine distance of identical vectors should be ~0 + assert rows[0]["distance"] < 0.01 + + +def test_knn_empty_table(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 0 + + +def test_knn_k_larger_than_n(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec([1.0] * 8)]) + db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec([2.0] * 8)]) + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 10", + [float_vec([1.0] * 8)], + ).fetchall() + assert len(rows) == 2 + + +# ============================================================================ +# Integration / edge case tests +# ============================================================================ + + +def test_knn_with_rowid_in(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=bit)" + ")" + ) + for i in range(5): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + # Only search within rowids 1, 3, 5 + rows = db.execute( + "SELECT rowid FROM t WHERE embedding MATCH ? AND rowid IN (1, 3, 5) ORDER BY distance LIMIT 3", + [float_vec([0.0] * 8)], + ).fetchall() + result_ids = {r["rowid"] for r in rows} + assert result_ids <= {1, 3, 5} + + +def test_knn_after_deletes(db): + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " embedding float[8] indexed by rescore(quantizer=int8)" + ")" + ) + for i in range(10): + db.execute( + "INSERT INTO t(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec([float(i)] * 8)], + ) + # Delete the closest match (rowid 1 = [0,0,...]) + db.execute("DELETE FROM t WHERE rowid = 1") + rows = db.execute( + "SELECT rowid, distance FROM t WHERE embedding MATCH ? ORDER BY distance LIMIT 5", + [float_vec([0.0] * 8)], + ).fetchall() + # Verify ordering: rowid 2 ([1]*8) should be closest, then 3 ([2]*8), etc. + assert len(rows) >= 2 + assert rows[0]["distance"] <= rows[1]["distance"] + # rowid 2 = [1,1,...] → L2 = sqrt(8) ≈ 2.83, rowid 3 = [2,2,...] → L2 = sqrt(32) ≈ 5.66 + assert rows[0]["rowid"] == 2, f"Expected rowid 2, got {rows[0]['rowid']} with dist={rows[0]['distance']}" + + +def test_oversample_effect(db): + """Higher oversample should give equal or better recall.""" + dim = 32 + n = 500 + k = 10 + random.seed(123) + + vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)] + query = float_vec([random.gauss(0, 1) for _ in range(dim)]) + + recalls = [] + for oversample in [2, 16]: + tname = f"t_os{oversample}" + db.execute( + f"CREATE VIRTUAL TABLE {tname} USING vec0(" + f" embedding float[{dim}] indexed by rescore(quantizer=bit, oversample={oversample})" + ")" + ) + for i, v in enumerate(vectors): + db.execute( + f"INSERT INTO {tname}(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec(v)], + ) + rows = db.execute( + f"SELECT rowid FROM {tname} WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + recalls.append({r["rowid"] for r in rows}) + + # Also get ground truth + db.execute(f"CREATE VIRTUAL TABLE t_flat USING vec0(embedding float[{dim}])") + for i, v in enumerate(vectors): + db.execute( + "INSERT INTO t_flat(rowid, embedding) VALUES (?, ?)", + [i + 1, float_vec(v)], + ) + gt_rows = db.execute( + "SELECT rowid FROM t_flat WHERE embedding MATCH ? ORDER BY distance LIMIT ?", + [query, k], + ).fetchall() + gt_ids = {r["rowid"] for r in gt_rows} + + recall_low = len(recalls[0] & gt_ids) / k + recall_high = len(recalls[1] & gt_ids) / k + assert recall_high >= recall_low + + +def test_multiple_vector_columns(db): + """One column with rescore, one without.""" + db.execute( + "CREATE VIRTUAL TABLE t USING vec0(" + " v1 float[8] indexed by rescore(quantizer=bit)," + " v2 float[8]" + ")" + ) + db.execute( + "INSERT INTO t(rowid, v1, v2) VALUES (1, ?, ?)", + [float_vec([1.0] * 8), float_vec([0.0] * 8)], + ) + db.execute( + "INSERT INTO t(rowid, v1, v2) VALUES (2, ?, ?)", + [float_vec([0.0] * 8), float_vec([1.0] * 8)], + ) + + # KNN on v1 (rescore path) + rows = db.execute( + "SELECT rowid FROM t WHERE v1 MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 1 + + # KNN on v2 (normal path) + rows = db.execute( + "SELECT rowid FROM t WHERE v2 MATCH ? ORDER BY distance LIMIT 1", + [float_vec([1.0] * 8)], + ).fetchall() + assert rows[0]["rowid"] == 2 diff --git a/tests/test-unit.c b/tests/test-unit.c index 9eb8704..b180625 100644 --- a/tests/test-unit.c +++ b/tests/test-unit.c @@ -760,6 +760,202 @@ void test_distance_hamming() { printf(" All distance_hamming tests passed.\n"); } +#ifdef SQLITE_VEC_ENABLE_RESCORE + +void test_rescore_quantize_float_to_bit() { + printf("Starting %s...\n", __func__); + uint8_t dst[16]; + + // All positive -> all bits 1 + { + float src[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + memset(dst, 0, sizeof(dst)); + _test_rescore_quantize_float_to_bit(src, dst, 8); + assert(dst[0] == 0xFF); + } + + // All negative -> all bits 0 + { + float src[8] = {-1.0f, -2.0f, -3.0f, -4.0f, -5.0f, -6.0f, -7.0f, -8.0f}; + memset(dst, 0xFF, sizeof(dst)); + _test_rescore_quantize_float_to_bit(src, dst, 8); + assert(dst[0] == 0x00); + } + + // Alternating positive/negative + { + float src[8] = {1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f}; + _test_rescore_quantize_float_to_bit(src, dst, 8); + // bits 0,2,4,6 set => 0b01010101 = 0x55 + assert(dst[0] == 0x55); + } + + // Zero values -> bit is set (>= 0.0f) + { + float src[8] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + _test_rescore_quantize_float_to_bit(src, dst, 8); + assert(dst[0] == 0xFF); + } + + // 128 dimensions -> 16 bytes output + { + float src[128]; + for (int i = 0; i < 128; i++) src[i] = (i % 2 == 0) ? 1.0f : -1.0f; + memset(dst, 0, 16); + _test_rescore_quantize_float_to_bit(src, dst, 128); + // Even indices set: bits 0,2,4,6 in each byte => 0x55 + for (int i = 0; i < 16; i++) { + assert(dst[i] == 0x55); + } + } + + printf(" All rescore_quantize_float_to_bit tests passed.\n"); +} + +void test_rescore_quantize_float_to_int8() { + printf("Starting %s...\n", __func__); + int8_t dst[256]; + + // Uniform vector -> all zeros (range=0) + { + float src[8] = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 8); + for (int i = 0; i < 8; i++) { + assert(dst[i] == 0); + } + } + + // [0.0, 1.0] -> should map to [-128, 127] + { + float src[2] = {0.0f, 1.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 2); + assert(dst[0] == -128); + assert(dst[1] == 127); + } + + // [-1.0, 0.0] -> should map to [-128, 127] + { + float src[2] = {-1.0f, 0.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 2); + assert(dst[0] == -128); + assert(dst[1] == 127); + } + + // Single-element: range=0 -> 0 + { + float src[1] = {42.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 1); + assert(dst[0] == 0); + } + + // Verify range: all outputs in [-128, 127], min near -128, max near 127 + { + float src[4] = {-100.0f, 0.0f, 100.0f, 50.0f}; + _test_rescore_quantize_float_to_int8(src, dst, 4); + for (int i = 0; i < 4; i++) { + assert(dst[i] >= -128 && dst[i] <= 127); + } + // Min maps to -128 (exact), max maps to ~127 (may lose 1 to float rounding) + assert(dst[0] == -128); + assert(dst[2] >= 126 && dst[2] <= 127); + // Middle value (50) should be positive + assert(dst[3] > 0); + } + + printf(" All rescore_quantize_float_to_int8 tests passed.\n"); +} + +void test_rescore_quantized_byte_size() { + printf("Starting %s...\n", __func__); + + // Bit quantizer: dims/8 + assert(_test_rescore_quantized_byte_size_bit(128) == 16); + assert(_test_rescore_quantized_byte_size_bit(8) == 1); + assert(_test_rescore_quantized_byte_size_bit(1024) == 128); + + // Int8 quantizer: dims + assert(_test_rescore_quantized_byte_size_int8(128) == 128); + assert(_test_rescore_quantized_byte_size_int8(8) == 8); + assert(_test_rescore_quantized_byte_size_int8(1024) == 1024); + + printf(" All rescore_quantized_byte_size tests passed.\n"); +} + +void test_vec0_parse_vector_column_rescore() { + printf("Starting %s...\n", __func__); + struct VectorColumnDefinition col; + int rc; + + // Basic bit quantizer + { + const char *input = "emb float[128] indexed by rescore(quantizer=bit)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); + assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT); + assert(col.rescore.oversample == 8); // default + assert(col.dimensions == 128); + sqlite3_free(col.name); + } + + // Int8 quantizer + { + const char *input = "emb float[128] indexed by rescore(quantizer=int8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); + assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + // Bit quantizer with oversample + { + const char *input = "emb float[128] indexed by rescore(quantizer=bit, oversample=16)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); + assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT); + assert(col.rescore.oversample == 16); + sqlite3_free(col.name); + } + + // Error: non-float element type + { + const char *input = "emb int8[128] indexed by rescore(quantizer=bit)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: dims not divisible by 8 for bit quantizer + { + const char *input = "emb float[100] indexed by rescore(quantizer=bit)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: missing quantizer + { + const char *input = "emb float[128] indexed by rescore(oversample=8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // With distance_metric=cosine + { + const char *input = "emb float[128] distance_metric=cosine indexed by rescore(quantizer=int8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE); + assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + printf(" All vec0_parse_vector_column_rescore tests passed.\n"); +} + +#endif /* SQLITE_VEC_ENABLE_RESCORE */ + int main() { printf("Starting unit tests...\n"); #ifdef SQLITE_VEC_ENABLE_AVX @@ -768,6 +964,9 @@ int main() { #ifdef SQLITE_VEC_ENABLE_NEON printf("SQLITE_VEC_ENABLE_NEON=1\n"); #endif +#ifdef SQLITE_VEC_ENABLE_RESCORE + printf("SQLITE_VEC_ENABLE_RESCORE=1\n"); +#endif #if !defined(SQLITE_VEC_ENABLE_AVX) && !defined(SQLITE_VEC_ENABLE_NEON) printf("SIMD: none\n"); #endif @@ -778,5 +977,11 @@ int main() { test_distance_l2_sqr_float(); test_distance_cosine_float(); test_distance_hamming(); +#ifdef SQLITE_VEC_ENABLE_RESCORE + test_rescore_quantize_float_to_bit(); + test_rescore_quantize_float_to_int8(); + test_rescore_quantized_byte_size(); + test_vec0_parse_vector_column_rescore(); +#endif printf("All unit tests passed.\n"); } From 69f7b658e99684876a501a62d79c397e905db3b7 Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Mon, 30 Mar 2026 16:40:44 -0700 Subject: [PATCH 3/3] rm unnecessary TODO --- TODO.md | 73 --------------------------------------------------------- 1 file changed, 73 deletions(-) delete mode 100644 TODO.md diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 4c3cc19..0000000 --- a/TODO.md +++ /dev/null @@ -1,73 +0,0 @@ -# TODO: `ann` base branch + consolidated benchmarks - -## 1. Create `ann` branch with shared code - -### 1.1 Branch setup -- [x] `git checkout -B ann origin/main` -- [x] Cherry-pick `624f998` (vec0_distance_full shared distance dispatch) -- [x] Cherry-pick stdint.h fix for test header -- [ ] Pull NEON cosine optimization from ivf-yolo3 into shared code - - Currently only in ivf branch but is general-purpose (benefits all distance calcs) - - Lives in `distance_cosine_float()` — ~57 lines of ARM NEON vectorized cosine - -### 1.2 Benchmark infrastructure (`benchmarks-ann/`) -- [x] Seed data pipeline (`seed/Makefile`, `seed/build_base_db.py`) -- [x] Ground truth generator (`ground_truth.py`) -- [x] Results schema (`schema.sql`) -- [x] Benchmark runner with `INDEX_REGISTRY` extension point (`bench.py`) - - Baseline configs (float, int8-rescore, bit-rescore) implemented - - Index branches register their types via `INDEX_REGISTRY` dict -- [x] Makefile with baseline targets -- [x] README - -### 1.3 Rebase feature branches onto `ann` -- [x] Rebase `diskann-yolo2` onto `ann` (1 commit: DiskANN implementation) -- [x] Rebase `ivf-yolo3` onto `ann` (1 commit: IVF implementation) -- [x] Rebase `annoy-yolo2` onto `ann` (2 commits: Annoy implementation + schema fix) -- [x] Verify each branch has only its index-specific commits remaining -- [ ] Force-push all 4 branches to origin - ---- - -## 2. Per-branch: register index type in benchmarks - -Each index branch should add to `benchmarks-ann/` when rebased onto `ann`: - -### 2.1 Register in `bench.py` - -Add an `INDEX_REGISTRY` entry. Each entry provides: -- `defaults` — default param values -- `create_table_sql(params)` — CREATE VIRTUAL TABLE with INDEXED BY clause -- `insert_sql(params)` — custom insert SQL, or None for default -- `post_insert_hook(conn, params)` — training/building step, returns time -- `run_query(conn, params, query, k)` — custom query, or None for default MATCH -- `describe(params)` — one-line description for report output - -### 2.2 Add configs to `Makefile` - -Append index-specific config variables and targets. Example pattern: - -```makefile -DISKANN_CONFIGS = \ - "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \ - ... - -ALL_CONFIGS += $(DISKANN_CONFIGS) - -bench-diskann: seed - $(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS) - ... -``` - -### 2.3 Migrate existing benchmark results/docs - -- Move useful results docs (RESULTS.md, etc.) into `benchmarks-ann/results/` -- Delete redundant per-branch benchmark directories once consolidated infra is proven - ---- - -## 3. Future improvements - -- [ ] Reporting script (`report.py`) — query results.db, produce markdown comparison tables -- [ ] Profiling targets in Makefile (lift from ivf-yolo3's Instruments/perf wrappers) -- [ ] Pre-computed ground truth integration (use GT DB files instead of on-the-fly brute-force)