diff --git a/.github/workflows/release-demo.yml b/.github/workflows/release-demo.yml
new file mode 100644
index 0000000..2f4b396
--- /dev/null
+++ b/.github/workflows/release-demo.yml
@@ -0,0 +1,118 @@
+name: "Release Demo (DiskANN)"
+# Runs on every push to the diskann-yolo2 branch.
+on:
+  push:
+    branches: [diskann-yolo2]
+# Token scope applied to every job in this workflow; the dist job needs write
+# access because it creates a release with `gh release create`.
+permissions:
+  contents: write
+jobs:
+  # Linux x86_64 build: runs the `loadable` and `static` make targets, then
+  # uploads everything under dist/ as an artifact that the `dist` job
+  # downloads by name.
+  build-linux-x86_64-extension:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v4
+      # presumably fetches vendored build inputs — see scripts/vendor.sh
+      - run: ./scripts/vendor.sh
+      - run: make loadable static
+      - uses: actions/upload-artifact@v4
+        with:
+          name: sqlite-vec-linux-x86_64-extension
+          path: dist/*
+  # Linux aarch64 build on the ubuntu-22.04-arm runner: runs the `loadable`
+  # and `static` make targets, then uploads everything under dist/ as an
+  # artifact that the `dist` job downloads by name.
+  build-linux-aarch64-extension:
+    runs-on: ubuntu-22.04-arm
+    steps:
+      - uses: actions/checkout@v4
+      # presumably fetches vendored build inputs — see scripts/vendor.sh
+      - run: ./scripts/vendor.sh
+      - run: make loadable static
+      - uses: actions/upload-artifact@v4
+        with:
+          name: sqlite-vec-linux-aarch64-extension
+          path: dist/*
+  # macOS x86_64 build on the macos-15-intel runner: runs the `loadable` and
+  # `static` make targets, then uploads everything under dist/ as an artifact
+  # that the `dist` job downloads by name.
+  build-macos-x86_64-extension:
+    runs-on: macos-15-intel
+    steps:
+      - uses: actions/checkout@v4
+      # presumably fetches vendored build inputs — see scripts/vendor.sh
+      - run: ./scripts/vendor.sh
+      - run: make loadable static
+      - uses: actions/upload-artifact@v4
+        with:
+          name: sqlite-vec-macos-x86_64-extension
+          path: dist/*
+  # macOS aarch64 build: runs the `loadable` and `static` make targets, then
+  # uploads everything under dist/ as an artifact that the `dist` job
+  # downloads by name.
+  # NOTE(review): assumes the macos-14 runner image is arm64 — confirm.
+  build-macos-aarch64-extension:
+    runs-on: macos-14
+    steps:
+      - uses: actions/checkout@v4
+      # presumably fetches vendored build inputs — see scripts/vendor.sh
+      - run: ./scripts/vendor.sh
+      - run: make loadable static
+      - uses: actions/upload-artifact@v4
+        with:
+          name: sqlite-vec-macos-aarch64-extension
+          path: dist/*
+  # Windows x86_64 build: compiles sqlite-vec.c directly with cl.exe instead
+  # of using the `loadable`/`static` make targets, then uploads everything
+  # under dist/ as an artifact that the `dist` job downloads by name.
+  build-windows-x86_64-extension:
+    runs-on: windows-2022
+    steps:
+      - uses: actions/checkout@v4
+      # presumably puts the MSVC toolchain (cl.exe) on PATH for later steps
+      - uses: ilammy/msvc-dev-cmd@v1
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+      # The vendor script is a shell script, so this step asks for bash explicitly.
+      - run: ./scripts/vendor.sh
+        shell: bash
+      - run: make sqlite-vec.h
+      - run: mkdir dist
+      # NOTE(review): /fPIC and -shared are GCC-style flags rather than MSVC
+      # options, and cl.exe normally names its output with /Fe — confirm this
+      # invocation produces dist/vec0.dll as intended.
+      - run: cl.exe /fPIC -shared /W4 /Ivendor/ /O2 /LD sqlite-vec.c -o dist/vec0.dll
+      - uses: actions/upload-artifact@v4
+        with:
+          name: sqlite-vec-windows-x86_64-extension
+          path: dist/*
+  # Gathers every platform artifact, builds the amalgamation, packages with
+  # sqlite-dist, and publishes a prerelease tagged diskann-<first 10 sha chars>.
+  dist:
+    runs-on: ubuntu-latest
+    needs:
+      - build-linux-x86_64-extension
+      - build-linux-aarch64-extension
+      - build-macos-x86_64-extension
+      - build-macos-aarch64-extension
+      - build-windows-x86_64-extension
+    steps:
+      - uses: actions/checkout@v4
+      # Each platform artifact is downloaded into its own dist/<platform> directory.
+      - uses: actions/download-artifact@v4
+        with:
+          name: sqlite-vec-linux-x86_64-extension
+          path: dist/linux-x86_64
+      - uses: actions/download-artifact@v4
+        with:
+          name: sqlite-vec-linux-aarch64-extension
+          path: dist/linux-aarch64
+      - uses: actions/download-artifact@v4
+        with:
+          name: sqlite-vec-macos-x86_64-extension
+          path: dist/macos-x86_64
+      - uses: actions/download-artifact@v4
+        with:
+          name: sqlite-vec-macos-aarch64-extension
+          path: dist/macos-aarch64
+      - uses: actions/download-artifact@v4
+        with:
+          name: sqlite-vec-windows-x86_64-extension
+          path: dist/windows-x86_64
+      - run: make sqlite-vec.h
+      # Copies the amalgamated sqlite-vec.c plus sqlite-vec.h into
+      # amalgamation/ and removes the copy from dist/.
+      - run: |
+          ./scripts/vendor.sh
+          make amalgamation
+          mkdir -p amalgamation
+          cp dist/sqlite-vec.c sqlite-vec.h amalgamation/
+          rm dist/sqlite-vec.c
+      - uses: asg017/setup-sqlite-dist@73e37b2ffb0b51e64a64eb035da38c958b9ff6c6
+      # Quoted so the VERSION contents are always passed as a single argument.
+      - run: sqlite-dist build --set-version "$(cat VERSION)"
+      - name: Create release and upload assets
+        env:
+          GH_TOKEN: ${{ github.token }}
+        # The commit is read from the runner-provided GITHUB_SHA variable
+        # instead of templating the github.sha expression into the script text.
+        run: |
+          SHORT_SHA="${GITHUB_SHA:0:10}"
+          TAG="diskann-${SHORT_SHA}"
+          zip -j "amalgamation/sqlite-vec-amalgamation.zip" amalgamation/sqlite-vec.c amalgamation/sqlite-vec.h
+          gh release create "$TAG" \
+            --title "$TAG" \
+            --target "$GITHUB_SHA" \
+            --prerelease \
+            amalgamation/sqlite-vec-amalgamation.zip \
+            .sqlite-dist/pip/*
diff --git a/Makefile b/Makefile
index 2758ee5..89907fa 100644
--- a/Makefile
+++ b/Makefile
@@ -204,7 +204,7 @@ test-loadable-watch:
watchexec --exts c,py,Makefile --clear -- make test-loadable
test-unit:
- $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit
+ $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_ENABLE_DISKANN=1 tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor $(CFLAGS) -o $(prefix)/test-unit && $(prefix)/test-unit
# Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking,
# profiling (has debug symbols), and scripting without .load_extension.
diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile
index 6081457..ddceb65 100644
--- a/benchmarks-ann/Makefile
+++ b/benchmarks-ann/Makefile
@@ -19,9 +19,16 @@ RESCORE_CONFIGS = \
"rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \
"rescore-int8-os8:type=rescore,quantizer=int8,oversample=8"
-ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS)
+# --- DiskANN configs ---
+# Each spec is name:type=diskann,key=val,...; bench.py passes R as
+# n_neighbors, L as search_list_size and quantizer as neighbor_quantizer.
+DISKANN_CONFIGS = \
+	"diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \
+	"diskann-R72-binary:type=diskann,R=72,L=128,quantizer=binary" \
+	"diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8" \
+	"diskann-R72-L256:type=diskann,R=72,L=256,quantizer=binary"
-.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-10k bench-50k bench-100k bench-all \
+ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) $(DISKANN_CONFIGS)
+
+.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-diskann bench-10k bench-50k bench-100k bench-all \
report clean
# --- Data preparation ---
@@ -37,7 +44,8 @@ ground-truth: seed
bench-smoke: seed
$(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \
"brute-float:type=baseline,variant=float" \
- "ivf-quick:type=ivf,nlist=16,nprobe=4"
+ "ivf-quick:type=ivf,nlist=16,nprobe=4" \
+ "diskann-quick:type=diskann,R=48,L=64,quantizer=binary"
bench-rescore: seed
$(BENCH) --subset-size 10000 -k 10 -o runs/rescore \
@@ -62,6 +70,12 @@ bench-ivf: seed
$(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
$(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS)
+# --- DiskANN across sizes ---
+# Runs the baselines and every DiskANN config at 10k, 50k and 100k vectors;
+# all three invocations write under runs/diskann.
+bench-diskann: seed
+	$(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
+	$(BENCH) --subset-size 50000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
+	$(BENCH) --subset-size 100000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS)
+
# --- Report ---
report:
@echo "Use: sqlite3 runs/
/results.db 'SELECT * FROM bench_results ORDER BY recall DESC'"
diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py
index c640628..520db77 100644
--- a/benchmarks-ann/bench.py
+++ b/benchmarks-ann/bench.py
@@ -6,18 +6,16 @@ across different vec0 configurations.
Config format: name:type=,key=val,key=val
- Baseline (brute-force) keys:
- type=baseline, variant=float|int8|bit, oversample=8
-
- Index-specific types can be registered via INDEX_REGISTRY (see below).
+ Available types: none, vec0-flat, rescore, ivf, diskann
Usage:
python bench.py --subset-size 10000 \
- "brute-float:type=baseline,variant=float" \
- "brute-int8:type=baseline,variant=int8" \
- "brute-bit:type=baseline,variant=bit"
+ "raw:type=none" \
+ "flat:type=vec0-flat,variant=float" \
+ "flat-int8:type=vec0-flat,variant=int8"
"""
import argparse
+from datetime import datetime, timezone
import os
import sqlite3
import statistics
@@ -56,11 +54,118 @@ INDEX_REGISTRY = {}
# ============================================================================
-# Baseline implementation
+# "none" — regular table, no vec0, manual KNN via vec_distance_cosine()
# ============================================================================
-def _baseline_create_table_sql(params):
+def _none_create_table_sql(params):
+    """Return the CREATE TABLE statement for the plain (non-vec0) table.
+
+    The table always has ``id`` and ``embedding``; the ``int8`` and ``bit``
+    variants add one extra NOT NULL BLOB column for the quantized vector.
+    """
+    # Extra column per variant; any other variant gets no extra column.
+    quantized_column = {
+        "int8": "embedding_int8",
+        "bit": "embedding_bq",
+    }.get(params["variant"])
+    if quantized_column is None:
+        return (
+            "CREATE TABLE vec_items ("
+            " id INTEGER PRIMARY KEY,"
+            " embedding BLOB NOT NULL)"
+        )
+    return (
+        "CREATE TABLE vec_items ("
+        " id INTEGER PRIMARY KEY,"
+        " embedding BLOB NOT NULL,"
+        f" {quantized_column} BLOB NOT NULL)"
+    )
+
+
+def _none_insert_sql(params):
+    """Return the INSERT ... SELECT that copies one id range from base.train.
+
+    The statement expects ``:lo`` and ``:hi`` bindings for the id range. The
+    ``int8`` and ``bit`` variants also store a quantized copy of each vector.
+    """
+    id_range = "FROM base.train WHERE id >= :lo AND id < :hi"
+    variant = params["variant"]
+    if variant == "int8":
+        return (
+            "INSERT INTO vec_items(id, embedding, embedding_int8) "
+            "SELECT id, vector, vec_quantize_int8(vector, 'unit') "
+            + id_range
+        )
+    if variant == "bit":
+        return (
+            "INSERT INTO vec_items(id, embedding, embedding_bq) "
+            "SELECT id, vector, vec_quantize_binary(vector) "
+            + id_range
+        )
+    return "INSERT INTO vec_items(id, embedding) SELECT id, vector " + id_range
+
+
+def _none_run_query(conn, params, query, k):
+    """Run one KNN query against the plain table and return (id, distance) rows.
+
+    The default variant scans the table ordered by cosine distance. The
+    ``int8`` and ``bit`` variants first keep ``k * oversample`` candidates
+    ranked on the quantized column, then re-rank those candidates by cosine
+    distance on the full ``embedding`` column.
+    """
+    variant = params["variant"]
+    if variant not in ("int8", "bit"):
+        return conn.execute(
+            "SELECT id, vec_distance_cosine(:query, embedding) as distance "
+            "FROM vec_items ORDER BY 2 LIMIT :k",
+            {"query": query, "k": k},
+        ).fetchall()
+
+    # Per-variant pieces: how to quantize the query, the bind name used for
+    # the quantized query, and the coarse distance expression.
+    if variant == "int8":
+        quantize_sql = "SELECT vec_quantize_int8(:query, 'unit')"
+        bind_name = "q_int8"
+        coarse_dist = "vec_distance_cosine(vec_int8(:q_int8), vec_int8(embedding_int8))"
+    else:
+        quantize_sql = "SELECT vec_quantize_binary(:query)"
+        bind_name = "q_bit"
+        coarse_dist = "vec_distance_hamming(vec_bit(:q_bit), vec_bit(embedding_bq))"
+
+    quantized = conn.execute(quantize_sql, {"query": query}).fetchone()[0]
+    bindings = {
+        bind_name: quantized,
+        "query": query,
+        "k": k,
+        "oversample_k": k * params.get("oversample", 8),
+    }
+    return conn.execute(
+        "WITH coarse AS ("
+        " SELECT id, embedding FROM ("
+        f" SELECT id, embedding, {coarse_dist} as dist "
+        " FROM vec_items ORDER BY dist LIMIT :oversample_k"
+        " )"
+        ") "
+        "SELECT id, vec_distance_cosine(:query, embedding) as distance "
+        "FROM coarse ORDER BY 2 LIMIT :k",
+        bindings,
+    ).fetchall()
+
+
+def _none_describe(params):
+    """Return a short label for a ``none`` config, e.g. ``none int8 (os=8)``."""
+    variant = params["variant"]
+    if variant not in ("int8", "bit"):
+        # Every non-quantized variant is reported as float.
+        return "none float"
+    return f"none {variant} (os={params['oversample']})"
+
+
+# Registry entry for the "none" index type: a regular table queried with the
+# vec_distance_* SQL functions. No post-insert hook is registered.
+INDEX_REGISTRY["none"] = dict(
+    defaults=dict(variant="float", oversample=8),
+    create_table_sql=_none_create_table_sql,
+    insert_sql=_none_insert_sql,
+    post_insert_hook=None,
+    run_query=_none_run_query,
+    describe=_none_describe,
+)
+
+
+# ============================================================================
+# vec0-flat — vec0 virtual table with brute-force MATCH
+# ============================================================================
+
+
+def _vec0flat_create_table_sql(params):
variant = params["variant"]
extra = ""
if variant == "int8":
@@ -76,7 +181,7 @@ def _baseline_create_table_sql(params):
)
-def _baseline_insert_sql(params):
+def _vec0flat_insert_sql(params):
variant = params["variant"]
if variant == "int8":
return (
@@ -93,7 +198,7 @@ def _baseline_insert_sql(params):
return None # use default
-def _baseline_run_query(conn, params, query, k):
+def _vec0flat_run_query(conn, params, query, k):
variant = params["variant"]
oversample = params.get("oversample", 8)
@@ -123,20 +228,20 @@ def _baseline_run_query(conn, params, query, k):
return None # use default MATCH
-def _baseline_describe(params):
+def _vec0flat_describe(params):
v = params["variant"]
if v in ("int8", "bit"):
- return f"baseline {v} (os={params['oversample']})"
- return f"baseline {v}"
+ return f"vec0-flat {v} (os={params['oversample']})"
+ return f"vec0-flat {v}"
-INDEX_REGISTRY["baseline"] = {
+INDEX_REGISTRY["vec0-flat"] = {
     "defaults": {"variant": "float", "oversample": 8},
-    "create_table_sql": _baseline_create_table_sql,
-    "insert_sql": _baseline_insert_sql,
+    "create_table_sql": _vec0flat_create_table_sql,
+    "insert_sql": _vec0flat_insert_sql,
     "post_insert_hook": None,
-    "run_query": _baseline_run_query,
-    "describe": _baseline_describe,
+    "run_query": _vec0flat_run_query,
+    "describe": _vec0flat_describe,
 }
+# Legacy alias: existing config specs still say type=baseline (for example
+# "brute-float:type=baseline,variant=float" in the Makefile's bench-smoke
+# target), and parse_config rejects any type missing from INDEX_REGISTRY.
+INDEX_REGISTRY["baseline"] = INDEX_REGISTRY["vec0-flat"]
@@ -215,12 +320,64 @@ INDEX_REGISTRY["ivf"] = {
}
+# ============================================================================
+# DiskANN implementation
+# ============================================================================
+
+
+def _diskann_create_table_sql(params):
+    """Return the CREATE VIRTUAL TABLE statement for a DiskANN-indexed vec0 table.
+
+    ``quantizer``, ``R`` and ``L`` are passed as ``neighbor_quantizer``,
+    ``n_neighbors`` and ``search_list_size``. ``buffer_threshold`` is only
+    included in the statement when it is greater than zero.
+    """
+    threshold = params["buffer_threshold"]
+    if threshold > 0:
+        threshold_clause = f", buffer_threshold={threshold}"
+    else:
+        threshold_clause = ""
+    pieces = [
+        "CREATE VIRTUAL TABLE vec_items USING vec0(",
+        " id integer primary key,",
+        " embedding float[768] distance_metric=cosine",
+        " INDEXED BY diskann(",
+        f" neighbor_quantizer={params['quantizer']},",
+        f" n_neighbors={params['R']},",
+        f" search_list_size={params['L']}",
+        f" {threshold_clause}",
+        " )",
+        ")",
+    ]
+    return "".join(pieces)
+
+
+def _diskann_pre_query_hook(conn, params):
+    """Apply the optional ``L_search`` override on the given connection.
+
+    When ``L_search`` is present and truthy, inserts the command string
+    ``search_list_size_search=<L_search>`` through the table's id column,
+    commits, and prints a confirmation. Otherwise does nothing.
+    """
+    L_search = params.get("L_search")
+    if not L_search:
+        return
+    command = f"search_list_size_search={L_search}"
+    conn.execute("INSERT INTO vec_items(id) VALUES (?)", (command,))
+    conn.commit()
+    print(f" Set {command}")
+
+
+def _diskann_describe(params):
+    """Return a one-line label for a DiskANN config.
+
+    The quantizer and R fields are left-aligned to fixed widths; the
+    ``L_search`` suffix is appended only when that parameter is set.
+    """
+    quantizer, R, L = params["quantizer"], params["R"], params["L"]
+    parts = [f"diskann q={quantizer:<6} R={R:<3} L={L}"]
+    L_search = params.get("L_search")
+    if L_search:
+        parts.append(f" L_search={L_search}")
+    return "".join(parts)
+
+
+# Registry entry for the "diskann" index type. It registers no custom insert
+# SQL, post-insert hook or query function; its only hook is the pre-query hook.
+INDEX_REGISTRY["diskann"] = dict(
+    defaults=dict(R=72, L=128, quantizer="binary", buffer_threshold=0),
+    create_table_sql=_diskann_create_table_sql,
+    insert_sql=None,
+    post_insert_hook=None,
+    pre_query_hook=_diskann_pre_query_hook,
+    run_query=None,
+    describe=_diskann_describe,
+)
+
+
# ============================================================================
# Config parsing
# ============================================================================
INT_KEYS = {
- "R", "L", "buffer_threshold", "nlist", "nprobe", "oversample",
+ "R", "L", "L_search", "buffer_threshold", "nlist", "nprobe", "oversample",
"n_trees", "search_k",
}
@@ -238,7 +395,7 @@ def parse_config(spec):
k, v = kv.split("=", 1)
raw[k.strip()] = v.strip()
- index_type = raw.pop("type", "baseline")
+ index_type = raw.pop("type", "vec0-flat")
if index_type not in INDEX_REGISTRY:
raise ValueError(
f"Unknown index type: {index_type}. "
@@ -289,7 +446,7 @@ def insert_loop(conn, sql, subset_size, label=""):
return time.perf_counter() - t0
-def open_bench_db(db_path, ext_path, base_db):
+def create_bench_db(db_path, ext_path, base_db):
if os.path.exists(db_path):
os.remove(db_path)
conn = sqlite3.connect(db_path)
@@ -300,6 +457,19 @@ def open_bench_db(db_path, ext_path, base_db):
return conn
+def open_existing_bench_db(db_path, ext_path, base_db):
+    """Open an already-built index DB with the extension loaded and base attached.
+
+    Args:
+        db_path: path of the index database; it must already exist.
+        ext_path: path of the loadable extension to load into the connection.
+        base_db: path of the database to attach under the schema name ``base``.
+
+    Returns:
+        The open ``sqlite3.Connection``.
+
+    Raises:
+        FileNotFoundError: if ``db_path`` does not exist.
+    """
+    if not os.path.exists(db_path):
+        raise FileNotFoundError(
+            f"Index DB not found: {db_path}\n"
+            f"Build it first with: --phase build"
+        )
+    conn = sqlite3.connect(db_path)
+    try:
+        conn.enable_load_extension(True)
+        conn.load_extension(ext_path)
+        # Bind the path instead of formatting it into the SQL text so a quote
+        # character in base_db cannot break the statement.
+        conn.execute("ATTACH DATABASE ? AS base", (base_db,))
+    except Exception:
+        # Do not leak the connection when setup fails part-way through.
+        conn.close()
+        raise
+    return conn
+
+
DEFAULT_INSERT_SQL = (
"INSERT INTO vec_items(id, embedding) "
"SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi"
@@ -313,7 +483,7 @@ DEFAULT_INSERT_SQL = (
def build_index(base_db, ext_path, name, params, subset_size, out_dir):
db_path = os.path.join(out_dir, f"{name}.{subset_size}.db")
- conn = open_bench_db(db_path, ext_path, base_db)
+ conn = create_bench_db(db_path, ext_path, base_db)
reg = INDEX_REGISTRY[params["index_type"]]
@@ -364,12 +534,16 @@ def _default_match_query(conn, query, k):
).fetchall()
-def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50):
+def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50,
+ pre_query_hook=None):
conn = sqlite3.connect(db_path)
conn.enable_load_extension(True)
conn.load_extension(ext_path)
conn.execute(f"ATTACH DATABASE '{base_db}' AS base")
+ if pre_query_hook:
+ pre_query_hook(conn, params)
+
query_vectors = load_query_vectors(base_db, n)
reg = INDEX_REGISTRY[params["index_type"]]
@@ -431,6 +605,34 @@ def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50):
# ============================================================================
+def open_results_db(results_path):
+    """Open the results DB and make sure its schema is current.
+
+    Runs schema.sql from the script directory against the database, then adds
+    the ``phase`` column to ``runs`` when that column is missing.
+
+    Returns:
+        The open ``sqlite3.Connection``.
+    """
+    db = sqlite3.connect(results_path)
+    # Read schema.sql through a context manager so the file handle is closed.
+    with open(os.path.join(_SCRIPT_DIR, "schema.sql")) as schema_file:
+        db.executescript(schema_file.read())
+    # Migrate existing DBs whose runs table predates the phase column.
+    cols = {r[1] for r in db.execute("PRAGMA table_info(runs)").fetchall()}
+    if "phase" not in cols:
+        db.execute("ALTER TABLE runs ADD COLUMN phase TEXT NOT NULL DEFAULT 'both'")
+        db.commit()
+    return db
+
+
+def create_run(db, config_name, index_type, subset_size, phase, k=None, n=None):
+    """Insert a new ``runs`` row with status 'pending', commit, and return its rowid."""
+    row = (config_name, index_type, subset_size, phase, k, n)
+    sql = (
+        "INSERT INTO runs (config_name, index_type, subset_size, phase, status, k, n) "
+        "VALUES (?, ?, ?, ?, 'pending', ?, ?)"
+    )
+    cursor = db.execute(sql, row)
+    db.commit()
+    return cursor.lastrowid
+
+
+def update_run(db, run_id, **kwargs):
+    """Set the given columns on one ``runs`` row and commit.
+
+    Keyword names are used as column names and their values are bound as
+    parameters. Calling this with no keyword arguments is a no-op.
+    """
+    if not kwargs:
+        # An empty SET clause would be a SQL syntax error; nothing to update.
+        return
+    # Column names are interpolated into the SQL text, so callers must pass
+    # trusted identifiers only; the values themselves are bound.
+    assignments = ", ".join(f"{column} = ?" for column in kwargs)
+    values = [*kwargs.values(), run_id]
+    db.execute(f"UPDATE runs SET {assignments} WHERE run_id = ?", values)
+    db.commit()
+
+
def save_results(results_path, rows):
db = sqlite3.connect(results_path)
db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read())
@@ -500,6 +702,8 @@ def main():
parser.add_argument("--subset-size", type=int, required=True)
parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)")
parser.add_argument("-n", type=int, default=50, help="number of queries (default 50)")
+ parser.add_argument("--phase", choices=["build", "query", "both"], default="both",
+ help="build=build only, query=query existing index, both=default")
parser.add_argument("--base-db", default=BASE_DB)
parser.add_argument("--ext", default=EXT_PATH)
parser.add_argument("-o", "--out-dir", default="runs")
@@ -508,55 +712,164 @@ def main():
args = parser.parse_args()
os.makedirs(args.out_dir, exist_ok=True)
- results_db = args.results_db or os.path.join(args.out_dir, "results.db")
+ results_db_path = args.results_db or os.path.join(args.out_dir, "results.db")
configs = [parse_config(c) for c in args.configs]
+ results_db = open_results_db(results_db_path)
all_results = []
for i, (name, params) in enumerate(configs, 1):
reg = INDEX_REGISTRY[params["index_type"]]
desc = reg["describe"](params)
- print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()})")
+ print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()}) [phase={args.phase}]")
- build = build_index(
- args.base_db, args.ext, name, params, args.subset_size, args.out_dir
- )
- train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
- print(
- f" Build: {build['insert_time_s']}s insert{train_str} "
- f"{build['file_size_mb']} MB"
- )
+ db_path = os.path.join(args.out_dir, f"{name}.{args.subset_size}.db")
- print(f" Measuring KNN (k={args.k}, n={args.n})...")
- knn = measure_knn(
- build["db_path"], args.ext, args.base_db,
- params, args.subset_size, k=args.k, n=args.n,
- )
- print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
+ if args.phase == "build":
+ run_id = create_run(results_db, name, params["index_type"],
+ args.subset_size, "build")
+ update_run(results_db, run_id, status="inserting")
- all_results.append({
- "name": name,
- "n_vectors": args.subset_size,
- "index_type": params["index_type"],
- "config_desc": desc,
- "db_path": build["db_path"],
- "insert_time_s": build["insert_time_s"],
- "train_time_s": build["train_time_s"],
- "total_time_s": build["total_time_s"],
- "insert_per_vec_ms": build["insert_per_vec_ms"],
- "rows": build["rows"],
- "file_size_mb": build["file_size_mb"],
- "k": args.k,
- "n_queries": args.n,
- "mean_ms": knn["mean_ms"],
- "median_ms": knn["median_ms"],
- "p99_ms": knn["p99_ms"],
- "total_ms": knn["total_ms"],
- "recall": knn["recall"],
- })
+ build = build_index(
+ args.base_db, args.ext, name, params, args.subset_size, args.out_dir
+ )
+ train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
+ print(
+ f" Build: {build['insert_time_s']}s insert{train_str} "
+ f"{build['file_size_mb']} MB"
+ )
+ update_run(results_db, run_id,
+ status="built",
+ db_path=build["db_path"],
+ insert_time_s=build["insert_time_s"],
+ train_time_s=build["train_time_s"],
+ total_build_time_s=build["total_time_s"],
+ rows=build["rows"],
+ file_size_mb=build["file_size_mb"],
+ finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
+ print(f" Index DB: {build['db_path']}")
- print_report(all_results)
- save_results(results_db, all_results)
- print(f"\nResults saved to {results_db}")
+ elif args.phase == "query":
+ if not os.path.exists(db_path):
+ raise FileNotFoundError(
+ f"Index DB not found: {db_path}\n"
+ f"Build it first with: --phase build"
+ )
+
+ run_id = create_run(results_db, name, params["index_type"],
+ args.subset_size, "query", k=args.k, n=args.n)
+ update_run(results_db, run_id, status="querying")
+
+ pre_hook = reg.get("pre_query_hook")
+ print(f" Measuring KNN (k={args.k}, n={args.n})...")
+ knn = measure_knn(
+ db_path, args.ext, args.base_db,
+ params, args.subset_size, k=args.k, n=args.n,
+ pre_query_hook=pre_hook,
+ )
+ print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
+
+ qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0
+ update_run(results_db, run_id,
+ status="done",
+ db_path=db_path,
+ mean_ms=knn["mean_ms"],
+ median_ms=knn["median_ms"],
+ p99_ms=knn["p99_ms"],
+ total_query_ms=knn["total_ms"],
+ qps=qps,
+ recall=knn["recall"],
+ finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
+
+ file_size_mb = os.path.getsize(db_path) / (1024 * 1024)
+ all_results.append({
+ "name": name,
+ "n_vectors": args.subset_size,
+ "index_type": params["index_type"],
+ "config_desc": desc,
+ "db_path": db_path,
+ "insert_time_s": 0,
+ "train_time_s": 0,
+ "total_time_s": 0,
+ "insert_per_vec_ms": 0,
+ "rows": 0,
+ "file_size_mb": file_size_mb,
+ "k": args.k,
+ "n_queries": args.n,
+ "mean_ms": knn["mean_ms"],
+ "median_ms": knn["median_ms"],
+ "p99_ms": knn["p99_ms"],
+ "total_ms": knn["total_ms"],
+ "recall": knn["recall"],
+ })
+
+ else: # both
+ run_id = create_run(results_db, name, params["index_type"],
+ args.subset_size, "both", k=args.k, n=args.n)
+ update_run(results_db, run_id, status="inserting")
+
+ build = build_index(
+ args.base_db, args.ext, name, params, args.subset_size, args.out_dir
+ )
+ train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else ""
+ print(
+ f" Build: {build['insert_time_s']}s insert{train_str} "
+ f"{build['file_size_mb']} MB"
+ )
+ update_run(results_db, run_id, status="querying",
+ db_path=build["db_path"],
+ insert_time_s=build["insert_time_s"],
+ train_time_s=build["train_time_s"],
+ total_build_time_s=build["total_time_s"],
+ rows=build["rows"],
+ file_size_mb=build["file_size_mb"])
+
+ print(f" Measuring KNN (k={args.k}, n={args.n})...")
+ knn = measure_knn(
+ build["db_path"], args.ext, args.base_db,
+ params, args.subset_size, k=args.k, n=args.n,
+ )
+ print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}")
+
+ qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0
+ update_run(results_db, run_id,
+ status="done",
+ mean_ms=knn["mean_ms"],
+ median_ms=knn["median_ms"],
+ p99_ms=knn["p99_ms"],
+ total_query_ms=knn["total_ms"],
+ qps=qps,
+ recall=knn["recall"],
+ finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"))
+
+ all_results.append({
+ "name": name,
+ "n_vectors": args.subset_size,
+ "index_type": params["index_type"],
+ "config_desc": desc,
+ "db_path": build["db_path"],
+ "insert_time_s": build["insert_time_s"],
+ "train_time_s": build["train_time_s"],
+ "total_time_s": build["total_time_s"],
+ "insert_per_vec_ms": build["insert_per_vec_ms"],
+ "rows": build["rows"],
+ "file_size_mb": build["file_size_mb"],
+ "k": args.k,
+ "n_queries": args.n,
+ "mean_ms": knn["mean_ms"],
+ "median_ms": knn["median_ms"],
+ "p99_ms": knn["p99_ms"],
+ "total_ms": knn["total_ms"],
+ "recall": knn["recall"],
+ })
+
+ results_db.close()
+
+ if all_results:
+ print_report(all_results)
+ save_results(results_db_path, all_results)
+ print(f"\nResults saved to {results_db_path}")
+ elif args.phase == "build":
+ print(f"\nBuild complete. Results tracked in {results_db_path}")
if __name__ == "__main__":
diff --git a/benchmarks-ann/schema.sql b/benchmarks-ann/schema.sql
index 681df4e..ae8acf3 100644
--- a/benchmarks-ann/schema.sql
+++ b/benchmarks-ann/schema.sql
@@ -3,6 +3,31 @@
-- "baseline"; index-specific branches add their own types (registered
-- via INDEX_REGISTRY in bench.py).
+-- One row per config per bench.py invocation. bench.py inserts the row with
+-- status 'pending', moves it to 'inserting' and/or 'querying' while working,
+-- and finishes with 'built' (build phase) or 'done' (query and both phases).
+CREATE TABLE IF NOT EXISTS runs (
+  run_id INTEGER PRIMARY KEY AUTOINCREMENT,
+  config_name TEXT NOT NULL,
+  index_type TEXT NOT NULL,
+  subset_size INTEGER NOT NULL,
+  phase TEXT NOT NULL DEFAULT 'both', -- 'build', 'query', or 'both'
+  status TEXT NOT NULL DEFAULT 'pending',
+  -- KNN k and number of queries; left NULL by build-only runs.
+  k INTEGER,
+  n INTEGER,
+  db_path TEXT,
+  -- Build metrics, written by the build and both phases.
+  insert_time_s REAL,
+  train_time_s REAL,
+  total_build_time_s REAL,
+  rows INTEGER,
+  file_size_mb REAL,
+  -- Query metrics, written by the query and both phases.
+  mean_ms REAL,
+  median_ms REAL,
+  p99_ms REAL,
+  total_query_ms REAL,
+  qps REAL,
+  recall REAL,
+  created_at TEXT NOT NULL DEFAULT (datetime('now')),
+  -- Written by bench.py as a UTC 'YYYY-MM-DD HH:MM:SS' string.
+  finished_at TEXT
+);
+
CREATE TABLE IF NOT EXISTS build_results (
config_name TEXT NOT NULL,
index_type TEXT NOT NULL,
diff --git a/sqlite-vec-diskann.c b/sqlite-vec-diskann.c
new file mode 100644
index 0000000..1a5fd2b
--- /dev/null
+++ b/sqlite-vec-diskann.c
@@ -0,0 +1,1768 @@
+// DiskANN algorithm implementation
+// This file is #include'd into sqlite-vec.c — not compiled separately.
+
+// ============================================================
+// DiskANN node blob encode/decode functions
+// ============================================================
+
+/**
+ * Compute the size in bytes of a node's validity bitmap (one bit per
+ * neighbor slot).
+ *
+ * The result is rounded UP to a whole byte so that every slot index in
+ * [0, n_neighbors) addressed by diskann_validity_get()/diskann_validity_set()
+ * falls inside the bitmap even when n_neighbors is not a multiple of
+ * CHAR_BIT. For multiples of CHAR_BIT the result is unchanged from plain
+ * division. Non-positive counts yield 0.
+ */
+int diskann_validity_byte_size(int n_neighbors) {
+  if (n_neighbors <= 0) {
+    return 0;
+  }
+  // Written as quotient + remainder flag (instead of n + CHAR_BIT - 1) so the
+  // computation cannot overflow for counts close to INT_MAX.
+  return n_neighbors / CHAR_BIT + (n_neighbors % CHAR_BIT != 0 ? 1 : 0);
+}
+
+/**
+ * Byte length of a node's packed neighbor-rowid array: one i64 per
+ * neighbor slot.
+ */
+size_t diskann_neighbor_ids_byte_size(int n_neighbors) {
+  const size_t slot_count = (size_t)n_neighbors;
+  return slot_count * sizeof(i64);
+}
+
+/**
+ * Byte length of a node's packed quantized-neighbor-vector array: one
+ * quantized vector (as sized by diskann_quantized_vector_byte_size()) per
+ * neighbor slot.
+ */
+size_t diskann_neighbor_qvecs_byte_size(
+    int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
+    size_t dimensions) {
+  const size_t per_vector =
+      diskann_quantized_vector_byte_size(quantizer_type, dimensions);
+  const size_t slot_count = (size_t)n_neighbors;
+  return slot_count * per_vector;
+}
+
+/**
+ * Create the three zero-filled blobs of a brand-new DiskANN node: the
+ * validity bitmap, the neighbor-rowid array and the quantized-neighbor-vector
+ * array. Zero-filled means every neighbor slot starts out invalid.
+ *
+ * On success returns SQLITE_OK and hands ownership of the three buffers to
+ * the caller, who must release each with sqlite3_free(). The out-parameters
+ * are only written on success.
+ *
+ * Returns SQLITE_TOOBIG if a blob would not fit in an int (sqlite3_malloc()
+ * and the size out-parameters are int-sized), or SQLITE_NOMEM if an
+ * allocation fails.
+ */
+int diskann_node_init(
+    int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
+    size_t dimensions,
+    u8 **outValidity, int *outValiditySize,
+    u8 **outNeighborIds, int *outNeighborIdsSize,
+    u8 **outNeighborQvecs, int *outNeighborQvecsSize) {
+
+  int validitySize = diskann_validity_byte_size(n_neighbors);
+  size_t idsSize = diskann_neighbor_ids_byte_size(n_neighbors);
+  size_t qvecsSize = diskann_neighbor_qvecs_byte_size(
+      n_neighbors, quantizer_type, dimensions);
+
+  // The sizes are computed as size_t but both sqlite3_malloc() and the
+  // out-parameters take int. Reject anything that would be silently
+  // truncated instead of allocating a short buffer and memset()-ing past it.
+  if (idsSize > (size_t)INT_MAX || qvecsSize > (size_t)INT_MAX) {
+    return SQLITE_TOOBIG;
+  }
+
+  u8 *validity = sqlite3_malloc(validitySize);
+  u8 *ids = sqlite3_malloc((int)idsSize);
+  u8 *qvecs = sqlite3_malloc((int)qvecsSize);
+
+  if (!validity || !ids || !qvecs) {
+    // sqlite3_free(NULL) is a no-op, so partial allocations clean up safely.
+    sqlite3_free(validity);
+    sqlite3_free(ids);
+    sqlite3_free(qvecs);
+    return SQLITE_NOMEM;
+  }
+
+  memset(validity, 0, (size_t)validitySize);
+  memset(ids, 0, idsSize);
+  memset(qvecs, 0, qvecsSize);
+
+  *outValidity = validity;
+  *outValiditySize = validitySize;
+  *outNeighborIds = ids;
+  *outNeighborIdsSize = (int)idsSize;
+  *outNeighborQvecs = qvecs;
+  *outNeighborQvecsSize = (int)qvecsSize;
+  return SQLITE_OK;
+}
+
+/**
+ * Read the validity bit of neighbor slot i.
+ * Returns 1 when the slot holds a neighbor, 0 otherwise.
+ */
+int diskann_validity_get(const u8 *validity, int i) {
+  const u8 byte = validity[i / CHAR_BIT];
+  const int bit = i % CHAR_BIT;
+  return (byte >> bit) & 1;
+}
+
+/**
+ * Write the validity bit of neighbor slot i: any non-zero `value` marks the
+ * slot valid, zero marks it invalid. Other bits of the bitmap are untouched.
+ */
+void diskann_validity_set(u8 *validity, int i, int value) {
+  u8 *byte = &validity[i / CHAR_BIT];
+  const int mask = 1 << (i % CHAR_BIT);
+
+  if (!value) {
+    *byte &= ~mask;
+    return;
+  }
+  *byte |= mask;
+}
+
+/**
+ * Count how many of the first n_neighbors slots are marked valid in the
+ * bitmap.
+ */
+int diskann_validity_count(const u8 *validity, int n_neighbors) {
+  int valid = 0;
+  int slot = 0;
+  while (slot < n_neighbors) {
+    valid += diskann_validity_get(validity, slot);
+    slot++;
+  }
+  return valid;
+}
+
+/**
+ * Fetch the rowid stored in neighbor slot i of a packed neighbor-id blob.
+ * memcpy is used so the blob does not need to be i64-aligned.
+ */
+i64 diskann_neighbor_id_get(const u8 *neighbor_ids, int i) {
+  i64 rowid;
+  const u8 *slot = neighbor_ids + i * sizeof(i64);
+  memcpy(&rowid, slot, sizeof(i64));
+  return rowid;
+}
+
+/**
+ * Store `rowid` into neighbor slot i of a packed neighbor-id blob.
+ * memcpy is used so the blob does not need to be i64-aligned.
+ */
+void diskann_neighbor_id_set(u8 *neighbor_ids, int i, i64 rowid) {
+  u8 *slot = neighbor_ids + i * sizeof(i64);
+  memcpy(slot, &rowid, sizeof(i64));
+}
+
+/**
+ * Locate the quantized vector of neighbor slot i inside a packed
+ * quantized-vectors blob. The returned pointer aliases `qvecs` (read-only,
+ * nothing is copied).
+ */
+const u8 *diskann_neighbor_qvec_get(
+    const u8 *qvecs, int i,
+    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) {
+  const size_t stride =
+      diskann_quantized_vector_byte_size(quantizer_type, dimensions);
+  const size_t offset = (size_t)i * stride;
+  return qvecs + offset;
+}
+
+/**
+ * Copy one quantized vector (`src_qvec`) into neighbor slot i of a packed
+ * quantized-vectors blob.
+ */
+void diskann_neighbor_qvec_set(
+    u8 *qvecs, int i, const u8 *src_qvec,
+    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) {
+  const size_t stride =
+      diskann_quantized_vector_byte_size(quantizer_type, dimensions);
+  u8 *dest = qvecs + (size_t)i * stride;
+  memcpy(dest, src_qvec, stride);
+}
+
+/**
+ * Fill neighbor slot i of a node: flag the slot as valid, then record the
+ * neighbor's rowid and its quantized vector.
+ */
+void diskann_node_set_neighbor(
+    u8 *validity, u8 *neighbor_ids, u8 *qvecs, int i,
+    i64 neighbor_rowid, const u8 *neighbor_qvec,
+    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) {
+  // Validity bit on.
+  validity[i / CHAR_BIT] |= (1 << (i % CHAR_BIT));
+
+  // Rowid into the packed id array (memcpy: the blob may be unaligned).
+  memcpy(neighbor_ids + i * sizeof(i64), &neighbor_rowid, sizeof(i64));
+
+  // Quantized vector into the packed vector array.
+  diskann_neighbor_qvec_set(qvecs, i, neighbor_qvec, quantizer_type,
+                            dimensions);
+}
+
+/**
+ * Empty neighbor slot i of a node: flag the slot as invalid and zero both
+ * its rowid and its quantized vector.
+ */
+void diskann_node_clear_neighbor(
+    u8 *validity, u8 *neighbor_ids, u8 *qvecs, int i,
+    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) {
+  const i64 no_rowid = 0;
+  const size_t stride =
+      diskann_quantized_vector_byte_size(quantizer_type, dimensions);
+
+  // Validity bit off.
+  validity[i / CHAR_BIT] &= ~(1 << (i % CHAR_BIT));
+
+  // Rowid zeroed (memcpy: the blob may be unaligned).
+  memcpy(neighbor_ids + i * sizeof(i64), &no_rowid, sizeof(i64));
+
+  // Quantized vector zeroed.
+  memset(qvecs + (size_t)i * stride, 0, stride);
+}
+
+/**
+ * Quantize a full-precision float32 vector into the target quantizer format.
+ *
+ * `out` must be pre-allocated by the caller with
+ * diskann_quantized_vector_byte_size(quantizer_type, dimensions) bytes.
+ *
+ *  - BINARY: one bit per dimension, set when the component is > 0.
+ *  - INT8:   maps the range [-1, 1] linearly onto [-128, 127]. Components
+ *            outside that range are clamped to the nearest bound and NaN is
+ *            mapped to -128; converting such values straight to i8 would be
+ *            undefined behaviour.
+ *
+ * Returns SQLITE_OK, or SQLITE_ERROR for an unrecognised quantizer type (in
+ * which case `out` is left untouched).
+ */
+int diskann_quantize_vector(
+    const f32 *src, size_t dimensions,
+    enum Vec0DiskannQuantizerType quantizer_type,
+    u8 *out) {
+
+  switch (quantizer_type) {
+  case VEC0_DISKANN_QUANTIZER_BINARY: {
+    // NOTE(review): assumes `dimensions` is a multiple of CHAR_BIT; otherwise
+    // the trailing partial byte is not cleared here. Confirm the column
+    // definition enforces this.
+    memset(out, 0, dimensions / CHAR_BIT);
+    for (size_t i = 0; i < dimensions; i++) {
+      if (src[i] > 0.0f) {
+        out[i / CHAR_BIT] |= (1 << (i % CHAR_BIT));
+      }
+    }
+    return SQLITE_OK;
+  }
+  case VEC0_DISKANN_QUANTIZER_INT8: {
+    const f32 step = (1.0f - (-1.0f)) / 255.0f;
+    i8 *dst = (i8 *)out;
+    for (size_t i = 0; i < dimensions; i++) {
+      f32 scaled = ((src[i] - (-1.0f)) / step) - 128.0f;
+      // Clamp before the narrowing cast. The negated comparison also
+      // catches NaN, for which every ordered comparison is false.
+      if (!(scaled > -128.0f)) {
+        scaled = -128.0f;
+      } else if (scaled > 127.0f) {
+        scaled = 127.0f;
+      }
+      dst[i] = (i8)scaled;
+    }
+    return SQLITE_OK;
+  }
+  }
+  return SQLITE_ERROR;
+}
+
+/**
+ * Approximate distance between two vectors that are BOTH already in the
+ * quantized format; used for neighbor comparisons during graph traversal.
+ *
+ * The caller quantizes the query once (see diskann_quantize_query()) and
+ * passes the result here for every neighbor it compares against.
+ *
+ *  - BINARY quantizer: Hamming distance, regardless of `distance_metric`.
+ *  - INT8 quantizer:   the int8 variant of the requested metric
+ *                      (L2, cosine or L1).
+ *
+ * Returns FLT_MAX for any quantizer/metric combination not listed above.
+ */
+static f32 diskann_distance_quantized_precomputed(
+    const u8 *query_quantized, const u8 *quantized_neighbor,
+    size_t dimensions,
+    enum Vec0DiskannQuantizerType quantizer_type,
+    enum Vec0DistanceMetrics distance_metric) {
+
+  if (quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY) {
+    return distance_hamming(query_quantized, quantized_neighbor, &dimensions);
+  }
+
+  if (quantizer_type == VEC0_DISKANN_QUANTIZER_INT8) {
+    if (distance_metric == VEC0_DISTANCE_METRIC_L2) {
+      return distance_l2_sqr_int8(query_quantized, quantized_neighbor,
+                                  &dimensions);
+    }
+    if (distance_metric == VEC0_DISTANCE_METRIC_COSINE) {
+      return distance_cosine_int8(query_quantized, quantized_neighbor,
+                                  &dimensions);
+    }
+    if (distance_metric == VEC0_DISTANCE_METRIC_L1) {
+      return (f32)distance_l1_int8(query_quantized, quantized_neighbor,
+                                   &dimensions);
+    }
+  }
+
+  // Unknown quantizer, or INT8 with an unknown metric.
+  return FLT_MAX;
+}
+
+/**
+ * Quantize a float32 query vector into a freshly allocated buffer.
+ *
+ * Returns the buffer (the caller must release it with sqlite3_free()), or
+ * NULL when the buffer cannot be allocated, when its size does not fit the
+ * int-sized sqlite3_malloc() argument, or when the quantizer type is not
+ * supported. Returning NULL in the last case keeps an uninitialised buffer
+ * from ever being used as a query.
+ */
+static u8 *diskann_quantize_query(
+    const f32 *query_vector, size_t dimensions,
+    enum Vec0DiskannQuantizerType quantizer_type) {
+  size_t qsize = diskann_quantized_vector_byte_size(quantizer_type, dimensions);
+  if (qsize > (size_t)INT_MAX) return NULL;
+
+  u8 *buf = sqlite3_malloc((int)qsize);
+  if (!buf) return NULL;
+
+  if (diskann_quantize_vector(query_vector, dimensions, quantizer_type, buf) !=
+      SQLITE_OK) {
+    sqlite3_free(buf);
+    return NULL;
+  }
+  return buf;
+}
+
+/**
+ * Legacy convenience wrapper around diskann_distance_quantized_precomputed()
+ * for callers that hold a float32 query rather than a pre-quantized one.
+ *
+ * The query is quantized into a temporary buffer on every call, compared
+ * against `quantized_neighbor`, and the buffer is released again. Returns
+ * FLT_MAX if the temporary buffer cannot be obtained.
+ */
+f32 diskann_distance_quantized(
+    const void *query_vector, const u8 *quantized_neighbor,
+    size_t dimensions,
+    enum Vec0DiskannQuantizerType quantizer_type,
+    enum Vec0DistanceMetrics distance_metric) {
+
+  const f32 *query_f32 = (const f32 *)query_vector;
+  u8 *scratch = diskann_quantize_query(query_f32, dimensions, quantizer_type);
+  if (scratch == NULL) {
+    return FLT_MAX;
+  }
+
+  const f32 approx = diskann_distance_quantized_precomputed(
+      scratch, quantized_neighbor, dimensions, quantizer_type,
+      distance_metric);
+  sqlite3_free(scratch);
+  return approx;
+}
+
+// ============================================================
+// DiskANN medoid / entry point management
+// ============================================================
+
+/**
+ * Look up the medoid (graph entry point) of one vector column's DiskANN
+ * index. The value lives in the shadow info table under the key
+ * "diskann_medoid_NN", NN being the zero-padded column index.
+ *
+ * On SQLITE_OK:
+ *   - *outIsEmpty = 1 when the stored value is NULL (empty graph);
+ *     *outMedoid is left untouched in that case.
+ *   - otherwise *outIsEmpty = 0 and *outMedoid holds the medoid rowid.
+ *
+ * Returns SQLITE_NOMEM if a string cannot be built, the prepare error if the
+ * statement cannot be compiled, and SQLITE_ERROR if the step yields no row.
+ */
+static int diskann_medoid_get(vec0_vtab *p, int vec_col_idx,
+                              i64 *outMedoid, int *outIsEmpty) {
+  sqlite3_stmt *stmt = NULL;
+  char *infoKey = sqlite3_mprintf("diskann_medoid_%02d", vec_col_idx);
+  char *sql = sqlite3_mprintf(
+      "SELECT value FROM " VEC0_SHADOW_INFO_NAME " WHERE key = ?",
+      p->schemaName, p->tableName);
+  if (infoKey == NULL || sql == NULL) {
+    sqlite3_free(infoKey);
+    sqlite3_free(sql);
+    return SQLITE_NOMEM;
+  }
+
+  const int prepRc = sqlite3_prepare_v2(p->db, sql, -1, &stmt, NULL);
+  sqlite3_free(sql);
+  if (prepRc != SQLITE_OK) {
+    sqlite3_free(infoKey);
+    return prepRc;
+  }
+
+  // The statement takes ownership of infoKey (sqlite3_free destructor).
+  sqlite3_bind_text(stmt, 1, infoKey, -1, sqlite3_free);
+
+  if (sqlite3_step(stmt) != SQLITE_ROW) {
+    sqlite3_finalize(stmt);
+    return SQLITE_ERROR;
+  }
+
+  const int valueIsNull = (sqlite3_column_type(stmt, 0) == SQLITE_NULL);
+  *outIsEmpty = valueIsNull ? 1 : 0;
+  if (!valueIsNull) {
+    *outMedoid = sqlite3_column_int64(stmt, 0);
+  }
+
+  sqlite3_finalize(stmt);
+  return SQLITE_OK;
+}
+
+/**
+ * Store the medoid (graph entry point) of one vector column's DiskANN index
+ * by updating the "diskann_medoid_NN" row of the shadow info table.
+ *
+ * With isEmpty != 0 the value is set to NULL (empty graph) and medoidRowid
+ * is ignored; otherwise medoidRowid is written.
+ *
+ * Returns SQLITE_OK when the UPDATE steps to completion, SQLITE_NOMEM if a
+ * string cannot be built, the prepare error if the statement cannot be
+ * compiled, and SQLITE_ERROR for any other step result.
+ */
+static int diskann_medoid_set(vec0_vtab *p, int vec_col_idx,
+                              i64 medoidRowid, int isEmpty) {
+  sqlite3_stmt *stmt = NULL;
+  char *infoKey = sqlite3_mprintf("diskann_medoid_%02d", vec_col_idx);
+  char *sql = sqlite3_mprintf(
+      "UPDATE " VEC0_SHADOW_INFO_NAME " SET value = ?2 WHERE key = ?1",
+      p->schemaName, p->tableName);
+  if (infoKey == NULL || sql == NULL) {
+    sqlite3_free(infoKey);
+    sqlite3_free(sql);
+    return SQLITE_NOMEM;
+  }
+
+  const int prepRc = sqlite3_prepare_v2(p->db, sql, -1, &stmt, NULL);
+  sqlite3_free(sql);
+  if (prepRc != SQLITE_OK) {
+    sqlite3_free(infoKey);
+    return prepRc;
+  }
+
+  // The statement takes ownership of infoKey (sqlite3_free destructor).
+  sqlite3_bind_text(stmt, 1, infoKey, -1, sqlite3_free);
+  if (!isEmpty) {
+    sqlite3_bind_int64(stmt, 2, medoidRowid);
+  } else {
+    sqlite3_bind_null(stmt, 2);
+  }
+
+  const int stepRc = sqlite3_step(stmt);
+  sqlite3_finalize(stmt);
+  if (stepRc != SQLITE_DONE) {
+    return SQLITE_ERROR;
+  }
+  return SQLITE_OK;
+}
+
+
+/**
+ * Called when deleting a vector. If the deleted vector was the medoid, pick
+ * a replacement: any other row of the column's vectors shadow table becomes
+ * the new medoid, or the graph is marked empty when no other row exists.
+ *
+ * Returns SQLITE_OK when nothing had to change or the medoid was updated.
+ * A failure while looking for a replacement row is reported to the caller
+ * instead of being mistaken for "no rows left", which would wrongly mark a
+ * non-empty graph as empty.
+ */
+static int diskann_medoid_handle_delete(vec0_vtab *p, int vec_col_idx,
+                                        i64 deletedRowid) {
+  i64 currentMedoid;
+  int isEmpty;
+  int rc = diskann_medoid_get(p, vec_col_idx, &currentMedoid, &isEmpty);
+  if (rc != SQLITE_OK) return rc;
+
+  // Nothing to do unless the row being deleted is the current entry point.
+  if (isEmpty || currentMedoid != deletedRowid) {
+    return SQLITE_OK;
+  }
+
+  sqlite3_stmt *stmt = NULL;
+  char *zSql = sqlite3_mprintf(
+      "SELECT rowid FROM " VEC0_SHADOW_VECTORS_N_NAME " WHERE rowid != ?1 LIMIT 1",
+      p->schemaName, p->tableName, vec_col_idx);
+  if (!zSql) return SQLITE_NOMEM;
+
+  rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
+  sqlite3_free(zSql);
+  if (rc != SQLITE_OK) return rc;
+
+  sqlite3_bind_int64(stmt, 1, deletedRowid);
+  int stepRc = sqlite3_step(stmt);
+  if (stepRc == SQLITE_ROW) {
+    i64 newMedoid = sqlite3_column_int64(stmt, 0);
+    sqlite3_finalize(stmt);
+    return diskann_medoid_set(p, vec_col_idx, newMedoid, 0);
+  }
+
+  // sqlite3_finalize() reports the error of a failed step, if there was one.
+  rc = sqlite3_finalize(stmt);
+  if (stepRc != SQLITE_DONE) {
+    return rc != SQLITE_OK ? rc : SQLITE_ERROR;
+  }
+
+  // SQLITE_DONE: no other vector exists, so the graph is now empty.
+  return diskann_medoid_set(p, vec_col_idx, -1, 1);
+}
+
+// ============================================================
+// DiskANN database I/O helpers
+// ============================================================
+
+/**
+ * Read a node's full data (validity bitmap, neighbor ids, quantized neighbor
+ * vectors) from the column's _diskann_nodes shadow table.
+ *
+ * On SQLITE_OK the three out-buffers are heap copies that the caller must
+ * release with sqlite3_free(); on any failure the out-parameters are not
+ * written.
+ *
+ * Returns:
+ *   SQLITE_ERROR   - no row for `rowid` (or the step failed)
+ *   SQLITE_CORRUPT - a blob's size disagrees with the column's DiskANN config
+ *   SQLITE_NOMEM   - an allocation failed
+ *   prepare error  - the cached statement could not be compiled
+ *
+ * The cached statement is reset on every exit path so it never stays active
+ * (positioned on a row) between calls.
+ */
+static int diskann_node_read(vec0_vtab *p, int vec_col_idx, i64 rowid,
+                             u8 **outValidity, int *outValiditySize,
+                             u8 **outNeighborIds, int *outNeighborIdsSize,
+                             u8 **outQvecs, int *outQvecsSize) {
+  int rc;
+  // Lazily compile the per-column SELECT and cache it on the vtab.
+  if (!p->stmtDiskannNodeRead[vec_col_idx]) {
+    char *zSql = sqlite3_mprintf(
+        "SELECT neighbors_validity, neighbor_ids, neighbor_quantized_vectors "
+        "FROM " VEC0_SHADOW_DISKANN_NODES_N_NAME " WHERE rowid = ?",
+        p->schemaName, p->tableName, vec_col_idx);
+    if (!zSql) return SQLITE_NOMEM;
+    rc = sqlite3_prepare_v2(p->db, zSql, -1,
+                            &p->stmtDiskannNodeRead[vec_col_idx], NULL);
+    sqlite3_free(zSql);
+    if (rc != SQLITE_OK) return rc;
+  }
+
+  sqlite3_stmt *stmt = p->stmtDiskannNodeRead[vec_col_idx];
+  u8 *v = NULL;
+  u8 *ids = NULL;
+  u8 *qv = NULL;
+  int vs = 0;
+  int is = 0;
+  int qs = 0;
+
+  sqlite3_reset(stmt);
+  sqlite3_bind_int64(stmt, 1, rowid);
+
+  rc = sqlite3_step(stmt);
+  if (rc != SQLITE_ROW) {
+    rc = SQLITE_ERROR;
+    goto done;
+  }
+
+  vs = sqlite3_column_bytes(stmt, 0);
+  is = sqlite3_column_bytes(stmt, 1);
+  qs = sqlite3_column_bytes(stmt, 2);
+
+  // Validate blob sizes against config expectations to detect truncated /
+  // corrupt data before any caller iterates using cfg->n_neighbors.
+  {
+    struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+    struct Vec0DiskannConfig *cfg = &col->diskann;
+    int expectedVs = diskann_validity_byte_size(cfg->n_neighbors);
+    int expectedIs = (int)diskann_neighbor_ids_byte_size(cfg->n_neighbors);
+    int expectedQs = (int)diskann_neighbor_qvecs_byte_size(
+        cfg->n_neighbors, cfg->quantizer_type, col->dimensions);
+    if (vs != expectedVs || is != expectedIs || qs != expectedQs) {
+      rc = SQLITE_CORRUPT;
+      goto done;
+    }
+  }
+
+  v = sqlite3_malloc(vs);
+  ids = sqlite3_malloc(is);
+  qv = sqlite3_malloc(qs);
+  if (!v || !ids || !qv) {
+    sqlite3_free(v);
+    sqlite3_free(ids);
+    sqlite3_free(qv);
+    rc = SQLITE_NOMEM;
+    goto done;
+  }
+
+  // Copy out of the statement before it is reset: the column pointers are
+  // only valid while the statement stays on this row.
+  memcpy(v, sqlite3_column_blob(stmt, 0), vs);
+  memcpy(ids, sqlite3_column_blob(stmt, 1), is);
+  memcpy(qv, sqlite3_column_blob(stmt, 2), qs);
+
+  *outValidity = v;
+  *outValiditySize = vs;
+  *outNeighborIds = ids;
+  *outNeighborIdsSize = is;
+  *outQvecs = qv;
+  *outQvecsSize = qs;
+  rc = SQLITE_OK;
+
+done:
+  // Release the row so the cached statement does not remain active.
+  sqlite3_reset(stmt);
+  return rc;
+}
+
+/**
+ * Persist a node's three blobs to the column's _diskann_nodes shadow table
+ * with INSERT OR REPLACE, so an existing row for `rowid` is overwritten.
+ *
+ * The statement is compiled on first use and cached on the vtab. The blobs
+ * are bound with SQLITE_TRANSIENT, i.e. SQLite takes its own copies and the
+ * caller keeps ownership of its buffers.
+ *
+ * Returns SQLITE_OK on success, SQLITE_NOMEM / the prepare error if the
+ * statement cannot be built, and SQLITE_ERROR if the step does not complete.
+ */
+static int diskann_node_write(vec0_vtab *p, int vec_col_idx, i64 rowid,
+                              const u8 *validity, int validitySize,
+                              const u8 *neighborIds, int neighborIdsSize,
+                              const u8 *qvecs, int qvecsSize) {
+  if (p->stmtDiskannNodeWrite[vec_col_idx] == NULL) {
+    char *sql = sqlite3_mprintf(
+        "INSERT OR REPLACE INTO " VEC0_SHADOW_DISKANN_NODES_N_NAME
+        " (rowid, neighbors_validity, neighbor_ids, neighbor_quantized_vectors) "
+        "VALUES (?, ?, ?, ?)",
+        p->schemaName, p->tableName, vec_col_idx);
+    if (sql == NULL) {
+      return SQLITE_NOMEM;
+    }
+    const int prepRc = sqlite3_prepare_v2(
+        p->db, sql, -1, &p->stmtDiskannNodeWrite[vec_col_idx], NULL);
+    sqlite3_free(sql);
+    if (prepRc != SQLITE_OK) {
+      return prepRc;
+    }
+  }
+
+  sqlite3_stmt *insert = p->stmtDiskannNodeWrite[vec_col_idx];
+  sqlite3_reset(insert);
+  sqlite3_bind_int64(insert, 1, rowid);
+  sqlite3_bind_blob(insert, 2, validity, validitySize, SQLITE_TRANSIENT);
+  sqlite3_bind_blob(insert, 3, neighborIds, neighborIdsSize, SQLITE_TRANSIENT);
+  sqlite3_bind_blob(insert, 4, qvecs, qvecsSize, SQLITE_TRANSIENT);
+
+  if (sqlite3_step(insert) != SQLITE_DONE) {
+    return SQLITE_ERROR;
+  }
+  return SQLITE_OK;
+}
+
+/**
+ * Read the full-precision vector for a given rowid from the column's
+ * _vectors shadow table.
+ *
+ * On SQLITE_OK, *outVector is a heap copy of the stored blob (the caller
+ * must release it with sqlite3_free()) and *outVectorSize is its length in
+ * bytes. On failure the out-parameters are not written.
+ *
+ * Returns SQLITE_ERROR if there is no row for `rowid` (or the step failed),
+ * SQLITE_NOMEM if the copy cannot be allocated, or the prepare error if the
+ * cached statement cannot be compiled.
+ *
+ * The cached statement is reset on every exit path so it never stays active
+ * (positioned on a row) between calls.
+ */
+static int diskann_vector_read(vec0_vtab *p, int vec_col_idx, i64 rowid,
+                               void **outVector, int *outVectorSize) {
+  int rc;
+  // Lazily compile the per-column SELECT and cache it on the vtab.
+  if (!p->stmtVectorsRead[vec_col_idx]) {
+    char *zSql = sqlite3_mprintf(
+        "SELECT vector FROM " VEC0_SHADOW_VECTORS_N_NAME " WHERE rowid = ?",
+        p->schemaName, p->tableName, vec_col_idx);
+    if (!zSql) return SQLITE_NOMEM;
+    rc = sqlite3_prepare_v2(p->db, zSql, -1,
+                            &p->stmtVectorsRead[vec_col_idx], NULL);
+    sqlite3_free(zSql);
+    if (rc != SQLITE_OK) return rc;
+  }
+
+  sqlite3_stmt *stmt = p->stmtVectorsRead[vec_col_idx];
+  void *vec = NULL;
+  int sz = 0;
+
+  sqlite3_reset(stmt);
+  sqlite3_bind_int64(stmt, 1, rowid);
+
+  rc = sqlite3_step(stmt);
+  if (rc != SQLITE_ROW) {
+    rc = SQLITE_ERROR;
+    goto done;
+  }
+
+  sz = sqlite3_column_bytes(stmt, 0);
+  vec = sqlite3_malloc(sz);
+  if (!vec) {
+    rc = SQLITE_NOMEM;
+    goto done;
+  }
+  // Copy out of the statement before it is reset: the column pointer is
+  // only valid while the statement stays on this row.
+  memcpy(vec, sqlite3_column_blob(stmt, 0), sz);
+
+  *outVector = vec;
+  *outVectorSize = sz;
+  rc = SQLITE_OK;
+
+done:
+  // Release the row so the cached statement does not remain active.
+  sqlite3_reset(stmt);
+  return rc;
+}
+
+/**
+ * Persist a full-precision vector to the column's _vectors shadow table with
+ * INSERT OR REPLACE, so an existing row for `rowid` is overwritten.
+ *
+ * The statement is compiled on first use and cached on the vtab. The vector
+ * is bound with SQLITE_TRANSIENT, i.e. SQLite takes its own copy.
+ *
+ * Returns SQLITE_OK on success, SQLITE_NOMEM / the prepare error if the
+ * statement cannot be built, and SQLITE_ERROR if the step does not complete.
+ */
+static int diskann_vector_write(vec0_vtab *p, int vec_col_idx, i64 rowid,
+                                const void *vector, int vectorSize) {
+  if (p->stmtVectorsInsert[vec_col_idx] == NULL) {
+    char *sql = sqlite3_mprintf(
+        "INSERT OR REPLACE INTO " VEC0_SHADOW_VECTORS_N_NAME
+        " (rowid, vector) VALUES (?, ?)",
+        p->schemaName, p->tableName, vec_col_idx);
+    if (sql == NULL) {
+      return SQLITE_NOMEM;
+    }
+    const int prepRc = sqlite3_prepare_v2(
+        p->db, sql, -1, &p->stmtVectorsInsert[vec_col_idx], NULL);
+    sqlite3_free(sql);
+    if (prepRc != SQLITE_OK) {
+      return prepRc;
+    }
+  }
+
+  sqlite3_stmt *insert = p->stmtVectorsInsert[vec_col_idx];
+  sqlite3_reset(insert);
+  sqlite3_bind_int64(insert, 1, rowid);
+  sqlite3_bind_blob(insert, 2, vector, vectorSize, SQLITE_TRANSIENT);
+
+  if (sqlite3_step(insert) != SQLITE_DONE) {
+    return SQLITE_ERROR;
+  }
+  return SQLITE_OK;
+}
+
+// ============================================================
+// DiskANN search data structures
+// ============================================================
+
+/**
+ * A sorted candidate list for greedy beam search.
+ *
+ * diskann_candidate_list_insert() keeps `items` ordered by ascending
+ * distance and never lets `count` exceed `capacity`; when the list is full
+ * the entry with the largest distance is dropped to make room.
+ */
+struct DiskannCandidateList {
+ struct Vec0DiskannCandidate *items; // heap array of `capacity` entries, owned by the list (see _init / _free)
+ int count; // number of entries currently in use
+ int capacity; // maximum number of entries `items` can hold
+};
+
+/**
+ * Prepare an empty candidate list with room for `capacity` entries.
+ * Returns SQLITE_NOMEM (leaving list->items NULL) if the backing array
+ * cannot be allocated, SQLITE_OK otherwise. Release the list with
+ * diskann_candidate_list_free().
+ */
+static int diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity) {
+  list->items = sqlite3_malloc(capacity * sizeof(struct Vec0DiskannCandidate));
+  if (list->items == NULL) {
+    return SQLITE_NOMEM;
+  }
+  list->capacity = capacity;
+  list->count = 0;
+  return SQLITE_OK;
+}
+
+/**
+ * Release a candidate list's backing array and reset the list to the empty
+ * state (NULL items, zero count and capacity).
+ */
+static void diskann_candidate_list_free(struct DiskannCandidateList *list) {
+  sqlite3_free(list->items);
+  list->capacity = 0;
+  list->count = 0;
+  list->items = NULL;
+}
+
+/**
+ * Offer a (rowid, distance) pair to the sorted candidate list.
+ *
+ *  - If `rowid` is already listed, its distance is lowered to `distance`
+ *    when that is an improvement (and the entry is moved forward to keep the
+ *    ordering); a worse distance leaves the entry unchanged. Its `visited`
+ *    flag is preserved. Returns 1 either way.
+ *  - Otherwise the pair is inserted at its sorted position with
+ *    visited = 0. When the list is full, the current worst entry is dropped
+ *    to make room, unless the new distance is not strictly better than the
+ *    worst, in which case nothing changes and 0 is returned.
+ *
+ * Returns 1 if the rowid is in the list after the call via one of the paths
+ * above, 0 if the candidate was discarded.
+ */
+static int diskann_candidate_list_insert(
+    struct DiskannCandidateList *list, i64 rowid, f32 distance) {
+
+  struct Vec0DiskannCandidate *items = list->items;
+
+  // Deduplicate: locate an existing entry for this rowid, if any.
+  int existing = -1;
+  for (int idx = 0; idx < list->count; idx++) {
+    if (items[idx].rowid == rowid) {
+      existing = idx;
+      break;
+    }
+  }
+
+  if (existing >= 0) {
+    if (distance < items[existing].distance) {
+      // Improved distance: slide the entry towards the front until the
+      // ascending order is restored.
+      struct Vec0DiskannCandidate moved = items[existing];
+      moved.distance = distance;
+      int dest = existing;
+      while (dest > 0 && items[dest - 1].distance > moved.distance) {
+        items[dest] = items[dest - 1];
+        dest--;
+      }
+      items[dest] = moved;
+    }
+    return 1;
+  }
+
+  // Full list: only accept candidates that beat the current worst entry,
+  // which is then dropped to free a slot.
+  if (list->count >= list->capacity) {
+    if (distance >= items[list->count - 1].distance) {
+      return 0;
+    }
+    list->count--;
+  }
+
+  // Lower-bound binary search: first position whose distance is not smaller
+  // than the new one.
+  int first = 0;
+  int last = list->count;
+  while (first < last) {
+    const int mid = first + (last - first) / 2;
+    if (items[mid].distance < distance) {
+      first = mid + 1;
+    } else {
+      last = mid;
+    }
+  }
+
+  // Open a gap at the insertion point.
+  memmove(&items[first + 1], &items[first],
+          (list->count - first) * sizeof(struct Vec0DiskannCandidate));
+
+  items[first].rowid = rowid;
+  items[first].distance = distance;
+  items[first].visited = 0;
+  list->count++;
+  return 1;
+}
+
+/**
+ * Index of the first entry whose `visited` flag is clear, or -1 when every
+ * entry has been visited. Because the list is kept sorted by distance, this
+ * is the closest unvisited candidate.
+ */
+static int diskann_candidate_list_next_unvisited(
+    const struct DiskannCandidateList *list) {
+  int idx = 0;
+  while (idx < list->count && list->items[idx].visited) {
+    idx++;
+  }
+  return (idx < list->count) ? idx : -1;
+}
+
+
+
+/**
+ * Simple hash set for tracking visited rowids during search.
+ * Uses open addressing with linear probing.
+ *
+ * A slot value of 0 means "empty", so rowid 0 can never be stored: the
+ * insert and contains helpers both return 0 for it. The table has a fixed
+ * size chosen at init time and is never grown.
+ */
+struct DiskannVisitedSet {
+ i64 *slots; // heap array of `capacity` slots, zero-filled at init; 0 marks a free slot
+ int capacity; // slot count; init rounds it up to a power of two (minimum 16) so `capacity - 1` works as a mask
+ int count; // number of rowids currently stored
+};
+
+/**
+ * Prepare an empty visited set. The slot count is the smallest power of two
+ * that is at least 16 and at least `capacity`, so slot indexes can be
+ * computed with a bit mask. Returns SQLITE_NOMEM (leaving set->slots NULL)
+ * if the table cannot be allocated, SQLITE_OK otherwise.
+ */
+static int diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity) {
+  int slotCount = 16;
+  for (; slotCount < capacity; slotCount *= 2) {
+    // keep doubling until the requested capacity fits
+  }
+
+  set->slots = sqlite3_malloc(slotCount * sizeof(i64));
+  if (set->slots == NULL) {
+    return SQLITE_NOMEM;
+  }
+  // All-zero table: 0 is the "empty slot" marker.
+  memset(set->slots, 0, slotCount * sizeof(i64));
+  set->count = 0;
+  set->capacity = slotCount;
+  return SQLITE_OK;
+}
+
+/**
+ * Release a visited set's slot table and reset the set to the empty state
+ * (NULL slots, zero capacity and count).
+ */
+static void diskann_visited_set_free(struct DiskannVisitedSet *set) {
+  sqlite3_free(set->slots);
+  set->count = 0;
+  set->capacity = 0;
+  set->slots = NULL;
+}
+
+/**
+ * Test whether `rowid` is in the visited set. Returns 1 if present, 0 if
+ * not. Rowid 0 is the empty-slot marker and is therefore always reported as
+ * absent.
+ */
+static int diskann_visited_set_contains(const struct DiskannVisitedSet *set, i64 rowid) {
+  if (rowid == 0) {
+    return 0; // 0 marks an empty slot, so it is never a member
+  }
+
+  const int mask = set->capacity - 1;
+  // Multiplicative hash; the upper 32 bits of the product pick the home slot.
+  const int home = (int)(((u64)rowid * 0x9E3779B97F4A7C15ULL) >> 32) & mask;
+
+  // Linear probing: walk forward until the rowid or a free slot turns up.
+  for (int probe = 0; probe < set->capacity; probe++) {
+    const i64 occupant = set->slots[(home + probe) & mask];
+    if (occupant == rowid) {
+      return 1;
+    }
+    if (occupant == 0) {
+      return 0;
+    }
+  }
+  return 0; // table completely full and the rowid is not in it
+}
+
+/**
+ * Add `rowid` to the visited set. Returns 1 if it was newly added, 0 if it
+ * was already present, if it is 0 (the empty-slot marker, which cannot be
+ * stored), or if the table has no free slot left.
+ */
+static int diskann_visited_set_insert(struct DiskannVisitedSet *set, i64 rowid) {
+  if (rowid == 0) {
+    return 0;
+  }
+
+  const int mask = set->capacity - 1;
+  // Multiplicative hash; the upper 32 bits of the product pick the home slot.
+  const int home = (int)(((u64)rowid * 0x9E3779B97F4A7C15ULL) >> 32) & mask;
+
+  // Linear probing: claim the first free slot unless the rowid is found first.
+  for (int probe = 0; probe < set->capacity; probe++) {
+    i64 *cell = &set->slots[(home + probe) & mask];
+    if (*cell == rowid) {
+      return 0; // already a member
+    }
+    if (*cell == 0) {
+      *cell = rowid;
+      set->count++;
+      return 1;
+    }
+  }
+  return 0; // no free slot (not expected when the set is sized generously)
+}
+
+// ============================================================
+// DiskANN greedy beam search (LM-Search)
+// ============================================================
+
+/**
+ * LM-Search: greedy beam search over one vector column's DiskANN graph,
+ * following Algorithm 1 of the LM-DiskANN paper.
+ *
+ * Starting from the medoid, the closest not-yet-expanded candidate is
+ * expanded repeatedly: its valid neighbors are scored with the quantized
+ * (approximate) distance and offered to a bounded, sorted candidate list,
+ * and the expanded node itself is then re-scored with its full-precision
+ * vector. The search ends when every listed candidate has been expanded.
+ *
+ * Parameters:
+ *   queryVector / dimensions / elementType - the query and its layout.
+ *   k              - number of results wanted.
+ *   searchListSize - beam width; <= 0 selects the column's configured value
+ *                    (search_list_size_search if positive, otherwise
+ *                    search_list_size). It is raised to k when smaller.
+ *   outRowids / outDistances - receive up to k results, nearest first; both
+ *                    must have room for k entries.
+ *   outCount       - receives the number of results written (0 for an empty
+ *                    graph).
+ *
+ * Returns SQLITE_OK, or the error from reading the medoid / its vector or
+ * from allocating the search structures. A node whose row cannot be read
+ * during traversal is skipped rather than treated as an error.
+ */
+static int diskann_search(
+    vec0_vtab *p, int vec_col_idx,
+    const void *queryVector, size_t dimensions,
+    enum VectorElementType elementType,
+    int k, int searchListSize,
+    i64 *outRowids, f32 *outDistances, int *outCount) {
+
+  struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+  struct Vec0DiskannConfig *cfg = &col->diskann;
+  int rc;
+
+  // Resolve the beam width: caller value, else the configured search-time
+  // width, else the configured default; never narrower than k.
+  if (searchListSize <= 0) {
+    if (cfg->search_list_size_search > 0) {
+      searchListSize = cfg->search_list_size_search;
+    } else {
+      searchListSize = cfg->search_list_size;
+    }
+  }
+  if (k > searchListSize) {
+    searchListSize = k;
+  }
+
+  // Step 1: the medoid is the entry point; an empty graph has no results.
+  i64 entryRowid;
+  int graphIsEmpty;
+  rc = diskann_medoid_get(p, vec_col_idx, &entryRowid, &graphIsEmpty);
+  if (rc != SQLITE_OK) {
+    return rc;
+  }
+  if (graphIsEmpty) {
+    *outCount = 0;
+    return SQLITE_OK;
+  }
+
+  // Step 2: exact (full-precision) distance from the query to the medoid.
+  void *entryVector = NULL;
+  int entryVectorSize;
+  rc = diskann_vector_read(p, vec_col_idx, entryRowid, &entryVector,
+                           &entryVectorSize);
+  if (rc != SQLITE_OK) {
+    return rc;
+  }
+  const f32 entryDist = vec0_distance_full(queryVector, entryVector,
+                                           dimensions, elementType,
+                                           col->distance_metric);
+  sqlite3_free(entryVector);
+
+  // Step 3: bounded candidate list (the beam) plus the set of expanded nodes.
+  struct DiskannCandidateList beam;
+  rc = diskann_candidate_list_init(&beam, searchListSize);
+  if (rc != SQLITE_OK) {
+    return rc;
+  }
+
+  struct DiskannVisitedSet expanded;
+  rc = diskann_visited_set_init(&expanded, searchListSize * 4);
+  if (rc != SQLITE_OK) {
+    diskann_candidate_list_free(&beam);
+    return rc;
+  }
+
+  diskann_candidate_list_insert(&beam, entryRowid, entryDist);
+
+  // Quantize a float32 query once up front so every neighbor comparison can
+  // reuse it. For other element types (or if this fails) the per-neighbor
+  // wrapper is used instead.
+  u8 *queryQuantized = NULL;
+  if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
+    queryQuantized = diskann_quantize_query(
+        (const f32 *)queryVector, dimensions, cfg->quantizer_type);
+  }
+
+  // Step 4: expand the closest unexpanded candidate until none is left.
+  for (;;) {
+    const int pick = diskann_candidate_list_next_unvisited(&beam);
+    if (pick < 0) {
+      break;
+    }
+
+    // Remember the rowid now: later inserts may shift list entries around.
+    beam.items[pick].visited = 1;
+    const i64 nodeRowid = beam.items[pick].rowid;
+
+    u8 *validity = NULL, *neighborIds = NULL, *qvecs = NULL;
+    int validitySize, neighborIdsSize, qvecsSize;
+    rc = diskann_node_read(p, vec_col_idx, nodeRowid,
+                           &validity, &validitySize,
+                           &neighborIds, &neighborIdsSize,
+                           &qvecs, &qvecsSize);
+    if (rc != SQLITE_OK) {
+      continue; // node row unavailable: skip it
+    }
+
+    // Offer every valid, not-yet-expanded neighbor with its approximate
+    // (quantized) distance.
+    for (int slot = 0; slot < cfg->n_neighbors; slot++) {
+      if (!diskann_validity_get(validity, slot)) {
+        continue;
+      }
+
+      const i64 neighborRowid = diskann_neighbor_id_get(neighborIds, slot);
+      if (diskann_visited_set_contains(&expanded, neighborRowid)) {
+        continue;
+      }
+
+      const u8 *neighborQvec = diskann_neighbor_qvec_get(
+          qvecs, slot, cfg->quantizer_type, dimensions);
+
+      const f32 approxDist =
+          queryQuantized
+              ? diskann_distance_quantized_precomputed(
+                    queryQuantized, neighborQvec, dimensions,
+                    cfg->quantizer_type, col->distance_metric)
+              : diskann_distance_quantized(
+                    queryVector, neighborQvec, dimensions,
+                    cfg->quantizer_type, col->distance_metric);
+
+      diskann_candidate_list_insert(&beam, neighborRowid, approxDist);
+    }
+
+    sqlite3_free(validity);
+    sqlite3_free(neighborIds);
+    sqlite3_free(qvecs);
+
+    diskann_visited_set_insert(&expanded, nodeRowid);
+
+    // Paper line 13: re-rank the expanded node with its full-precision
+    // distance. The list insert only ever lowers an existing distance.
+    void *nodeVector = NULL;
+    int nodeVectorSize;
+    rc = diskann_vector_read(p, vec_col_idx, nodeRowid, &nodeVector,
+                             &nodeVectorSize);
+    if (rc == SQLITE_OK) {
+      const f32 exactDist = vec0_distance_full(queryVector, nodeVector,
+                                               dimensions, elementType,
+                                               col->distance_metric);
+      sqlite3_free(nodeVector);
+      diskann_candidate_list_insert(&beam, nodeRowid, exactDist);
+    }
+  }
+
+  // Step 5: the beam is sorted by distance, so its head is the result set.
+  const int resultCount = (beam.count < k) ? beam.count : k;
+  *outCount = resultCount;
+  for (int r = 0; r < resultCount; r++) {
+    outRowids[r] = beam.items[r].rowid;
+    outDistances[r] = beam.items[r].distance;
+  }
+
+  sqlite3_free(queryQuantized);
+  diskann_candidate_list_free(&beam);
+  diskann_visited_set_free(&expanded);
+  return SQLITE_OK;
+}
+
+// ============================================================
+// DiskANN RobustPrune (Algorithm 4 from LM-DiskANN paper)
+// ============================================================
+
+/**
+ * RobustPrune: Select up to max_neighbors neighbors for node p from a
+ * candidate set, using alpha-pruning for diversity.
+ *
+ * Following Algorithm 4 (LM-Prune):
+ * C = C union N_out(p) \ {p}
+ * N_out(p) = empty
+ * while C not empty:
+ * p* = argmin d(p, c) for c in C
+ * N_out(p).insert(p*)
+ * if |N_out(p)| >= R: break
+ * for each p' in C:
+ * if alpha * d(p*, p') <= d(p, p'): remove p' from C
+ */
+/**
+ * Alpha-pruning selection on precomputed distances (no database access).
+ *
+ * Inputs:
+ *   p_distances[i]     = d(p, candidate_i); candidates must already be
+ *                        sorted by this distance, ascending.
+ *   inter_distances    = flattened num_candidates x num_candidates matrix,
+ *                        inter_distances[i * num_candidates + j]
+ *                        = d(candidate_i, candidate_j).
+ *   alpha              = pruning factor.
+ *   max_neighbors      = upper bound on the number of selections.
+ *
+ * Each round selects the first still-active candidate (the closest one,
+ * given the sort order) and then deactivates every remaining candidate c
+ * for which alpha * d(selected, c) <= d(p, c).
+ *
+ * Outputs: outSelected[i] is set to 1 for selected candidates and 0 for the
+ * rest; *outCount receives the number selected. With num_candidates == 0,
+ * outSelected is not touched.
+ *
+ * Returns SQLITE_OK, or SQLITE_NOMEM if the scratch array cannot be
+ * allocated.
+ */
+int diskann_prune_select(
+    const f32 *inter_distances, const f32 *p_distances,
+    int num_candidates, f32 alpha, int max_neighbors,
+    int *outSelected, int *outCount) {
+
+  if (num_candidates == 0) {
+    *outCount = 0;
+    return SQLITE_OK;
+  }
+
+  u8 *alive = sqlite3_malloc(num_candidates);
+  if (alive == NULL) {
+    return SQLITE_NOMEM;
+  }
+  memset(alive, 1, num_candidates);
+  memset(outSelected, 0, num_candidates * sizeof(int));
+
+  int picked = 0;
+  int round = 0;
+  while (round < num_candidates && picked < max_neighbors) {
+    // Closest remaining candidate = first one still alive.
+    int head = 0;
+    while (head < num_candidates && !alive[head]) {
+      head++;
+    }
+    if (head == num_candidates) {
+      break;
+    }
+
+    outSelected[head] = 1;
+    picked++;
+    alive[head] = 0;
+
+    // Drop every remaining candidate that the new selection already covers.
+    const f32 *fromHead = inter_distances + head * num_candidates;
+    for (int other = 0; other < num_candidates; other++) {
+      if (alive[other] && alpha * fromHead[other] <= p_distances[other]) {
+        alive[other] = 0;
+      }
+    }
+
+    round++;
+  }
+
+  *outCount = picked;
+  sqlite3_free(alive);
+  return SQLITE_OK;
+}
+
+/**
+ * RobustPrune against the database: choose up to max_neighbors out-neighbors
+ * for node `p_rowid` from a candidate set, using alpha-pruning on
+ * full-precision vectors read from the _vectors shadow table.
+ *
+ * `candidates` / `candidate_distances` are parallel arrays (rowid and its
+ * distance to p). Both are MODIFIED in place: the first entry equal to
+ * p_rowid is removed and the remainder is sorted by ascending distance.
+ *
+ * Each round picks the closest still-active candidate, appends it to
+ * outNeighborRowids, and deactivates every remaining candidate c with
+ * alpha * d(picked, c) <= d(p, c). Candidate vectors are loaded lazily and
+ * cached for the duration of the call. If the picked candidate's vector
+ * cannot be read the pruning step of that round is skipped; if another
+ * candidate's vector cannot be read that candidate simply stays active.
+ *
+ * outNeighborRowids must have room for max_neighbors entries;
+ * *outNeighborCount receives the number written. `p_vector` is not used by
+ * this function.
+ *
+ * Returns SQLITE_OK, or SQLITE_NOMEM if a scratch array cannot be allocated.
+ */
+static int diskann_robust_prune(
+    vec0_vtab *p, int vec_col_idx,
+    i64 p_rowid, const void *p_vector,
+    i64 *candidates, f32 *candidate_distances, int num_candidates,
+    f32 alpha, int max_neighbors,
+    i64 *outNeighborRowids, int *outNeighborCount) {
+
+  struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+  int rc;
+
+  // A node is never its own neighbor: drop p (first occurrence) by moving
+  // the last candidate into its place.
+  for (int c = 0; c < num_candidates; c++) {
+    if (candidates[c] != p_rowid) {
+      continue;
+    }
+    const int lastIdx = num_candidates - 1;
+    candidates[c] = candidates[lastIdx];
+    candidate_distances[c] = candidate_distances[lastIdx];
+    num_candidates--;
+    break;
+  }
+
+  if (num_candidates == 0) {
+    *outNeighborCount = 0;
+    return SQLITE_OK;
+  }
+
+  // Insertion sort of both parallel arrays by ascending distance to p.
+  for (int next = 1; next < num_candidates; next++) {
+    const f32 heldDist = candidate_distances[next];
+    const i64 heldRowid = candidates[next];
+    int hole = next;
+    while (hole > 0 && candidate_distances[hole - 1] > heldDist) {
+      candidate_distances[hole] = candidate_distances[hole - 1];
+      candidates[hole] = candidates[hole - 1];
+      hole--;
+    }
+    candidate_distances[hole] = heldDist;
+    candidates[hole] = heldRowid;
+  }
+
+  // alive[i] != 0 while candidate i is still eligible for selection.
+  u8 *alive = sqlite3_malloc(num_candidates);
+  if (alive == NULL) {
+    return SQLITE_NOMEM;
+  }
+  memset(alive, 1, num_candidates);
+
+  // Lazily filled cache of full-precision candidate vectors.
+  void **vectors = sqlite3_malloc(num_candidates * sizeof(void *));
+  if (vectors == NULL) {
+    sqlite3_free(alive);
+    return SQLITE_NOMEM;
+  }
+  memset(vectors, 0, num_candidates * sizeof(void *));
+
+  int picked = 0;
+
+  for (int round = 0; round < num_candidates && picked < max_neighbors; round++) {
+    // Closest remaining candidate = first one still alive.
+    int head = -1;
+    for (int c = 0; c < num_candidates; c++) {
+      if (alive[c]) {
+        head = c;
+        break;
+      }
+    }
+    if (head < 0) {
+      break;
+    }
+
+    outNeighborRowids[picked] = candidates[head];
+    picked++;
+    alive[head] = 0;
+
+    // The picked candidate's vector is needed to measure coverage.
+    if (vectors[head] == NULL) {
+      int headVecSize;
+      rc = diskann_vector_read(p, vec_col_idx, candidates[head],
+                               &vectors[head], &headVecSize);
+      if (rc != SQLITE_OK) {
+        continue; // cannot prune against an unreadable vector
+      }
+    }
+
+    // Alpha-prune: deactivate candidates the picked neighbor already covers.
+    for (int c = 0; c < num_candidates; c++) {
+      if (!alive[c]) {
+        continue;
+      }
+
+      if (vectors[c] == NULL) {
+        int otherVecSize;
+        rc = diskann_vector_read(p, vec_col_idx, candidates[c],
+                                 &vectors[c], &otherVecSize);
+        if (rc != SQLITE_OK) {
+          continue; // leave this candidate active
+        }
+      }
+
+      const f32 headToOther = vec0_distance_full(
+          vectors[head], vectors[c],
+          col->dimensions, col->element_type, col->distance_metric);
+
+      if (alpha * headToOther <= candidate_distances[c]) {
+        alive[c] = 0;
+      }
+    }
+  }
+
+  *outNeighborCount = picked;
+
+  for (int c = 0; c < num_candidates; c++) {
+    sqlite3_free(vectors[c]);
+  }
+  sqlite3_free(vectors);
+  sqlite3_free(alive);
+
+  return SQLITE_OK;
+}
+
+/**
+ * After RobustPrune selects neighbors, build the node blobs and write them
+ * to the database.
+ *
+ * A fresh, all-invalid node is created; then for each of the first
+ * min(neighborCount, n_neighbors) rowids the neighbor's full-precision
+ * vector is read, quantized (float32 columns) or copied as-is (other
+ * element types), and packed into slot i together with the rowid. A
+ * neighbor whose vector row cannot be read is skipped, leaving its slot
+ * invalid. The finished node replaces any existing row for nodeRowid.
+ *
+ * Returns SQLITE_OK on success, SQLITE_NOMEM / SQLITE_TOOBIG-style errors
+ * from node creation, SQLITE_CORRUPT if a float32 neighbor vector is shorter
+ * than `dimensions` floats (quantizing it would read past the blob; nothing
+ * is written in that case), or the error from writing the node.
+ */
+static int diskann_write_pruned_neighbors(
+    vec0_vtab *p, int vec_col_idx, i64 nodeRowid,
+    const i64 *neighborRowids, int neighborCount) {
+
+  struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+  struct Vec0DiskannConfig *cfg = &col->diskann;
+  int rc;
+
+  u8 *validity, *neighborIds, *qvecs;
+  int validitySize, neighborIdsSize, qvecsSize;
+  rc = diskann_node_init(cfg->n_neighbors, cfg->quantizer_type,
+                         col->dimensions,
+                         &validity, &validitySize,
+                         &neighborIds, &neighborIdsSize,
+                         &qvecs, &qvecsSize);
+  if (rc != SQLITE_OK) return rc;
+
+  // Scratch buffer holding one quantized vector at a time.
+  size_t qvecSize = diskann_quantized_vector_byte_size(
+      cfg->quantizer_type, col->dimensions);
+  u8 *qvec = sqlite3_malloc(qvecSize);
+  if (!qvec) {
+    sqlite3_free(validity);
+    sqlite3_free(neighborIds);
+    sqlite3_free(qvecs);
+    return SQLITE_NOMEM;
+  }
+
+  int corrupt = 0;
+  for (int i = 0; i < neighborCount && i < cfg->n_neighbors; i++) {
+    void *neighborVec = NULL;
+    int neighborVecSize;
+    rc = diskann_vector_read(p, vec_col_idx, neighborRowids[i],
+                             &neighborVec, &neighborVecSize);
+    if (rc != SQLITE_OK) continue; // unreadable neighbor: slot stays invalid
+
+    if (col->element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
+      // The quantizer reads `dimensions` floats; refuse a stored blob that
+      // is too short rather than reading beyond it.
+      if ((size_t)neighborVecSize < col->dimensions * sizeof(f32)) {
+        sqlite3_free(neighborVec);
+        corrupt = 1;
+        break;
+      }
+      diskann_quantize_vector((const f32 *)neighborVec, col->dimensions,
+                              cfg->quantizer_type, qvec);
+    } else {
+      // Non-float32 columns: store the raw bytes, bounded by both sizes.
+      memcpy(qvec, neighborVec,
+             qvecSize < (size_t)neighborVecSize ? qvecSize : (size_t)neighborVecSize);
+    }
+
+    diskann_node_set_neighbor(validity, neighborIds, qvecs, i,
+                              neighborRowids[i], qvec,
+                              cfg->quantizer_type, col->dimensions);
+
+    sqlite3_free(neighborVec);
+  }
+  sqlite3_free(qvec);
+
+  if (corrupt) {
+    rc = SQLITE_CORRUPT;
+  } else {
+    rc = diskann_node_write(p, vec_col_idx, nodeRowid,
+                            validity, validitySize,
+                            neighborIds, neighborIdsSize,
+                            qvecs, qvecsSize);
+  }
+
+  sqlite3_free(validity);
+  sqlite3_free(neighborIds);
+  sqlite3_free(qvecs);
+  return rc;
+}
+
+// ============================================================
+// DiskANN insert (Algorithm 2 from LM-DiskANN paper)
+// ============================================================
+
+/**
+ * Produce the per-neighbor (compressed) form of `src` in `dst`.
+ *
+ * float32 columns are run through diskann_quantize_vector(). Any other
+ * element type is copied verbatim: at most `srcSize` bytes are read, at most
+ * `dstSize` bytes are written, and whatever the copy does not cover is
+ * zero-filled so no uninitialised bytes end up in a stored node.
+ */
+static void diskann_reverse_edge_encode(
+    const struct VectorColumnDefinition *col,
+    const void *src, size_t srcSize, u8 *dst, size_t dstSize) {
+  if (col->element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
+    diskann_quantize_vector((const f32 *)src, col->dimensions,
+                            col->diskann.quantizer_type, dst);
+    return;
+  }
+  size_t nCopy = dstSize < srcSize ? dstSize : srcSize;
+  memcpy(dst, src, nCopy);
+  if (nCopy < dstSize) memset(dst + nCopy, 0, dstSize - nCopy);
+}
+
+/**
+ * Add a reverse edge: make target_rowid a neighbor of node_rowid.
+ *
+ *  - If target_rowid is already listed, nothing is written.
+ *  - If the node has a free slot, the target takes the first one.
+ *  - If the node is full, no RobustPrune is run. Instead the existing
+ *    neighbor that is farthest from the node (by quantized distance) is
+ *    located and replaced, but only when the target is strictly closer than
+ *    it. Otherwise the node is left untouched.
+ *
+ * `target_vector` must hold one full vector of this column
+ * (vector_column_byte_size() bytes).
+ *
+ * Returns SQLITE_OK when the edge was added, already present, or skipped;
+ * SQLITE_NOMEM on allocation failure; otherwise the error reported by
+ * diskann_node_read(), diskann_vector_read() or diskann_node_write().
+ */
+static int diskann_add_reverse_edge(
+    vec0_vtab *p, int vec_col_idx,
+    i64 node_rowid, i64 target_rowid, const void *target_vector) {
+
+  struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+  struct Vec0DiskannConfig *cfg = &col->diskann;
+  u8 *validity = NULL, *neighborIds = NULL, *qvecs = NULL;
+  int validitySize, neighborIdsSize, qvecsSize;
+  void *nodeVector = NULL;
+  u8 *targetQ = NULL, *nodeQ = NULL;
+  size_t qvecSize;
+  int slot = -1; // slot that will receive the target; -1 = leave node as is
+  int rc;
+
+  rc = diskann_node_read(p, vec_col_idx, node_rowid,
+                         &validity, &validitySize,
+                         &neighborIds, &neighborIdsSize,
+                         &qvecs, &qvecsSize);
+  if (rc != SQLITE_OK) return rc;
+
+  // Nothing to do if the edge already exists.
+  for (int i = 0; i < cfg->n_neighbors; i++) {
+    if (diskann_validity_get(validity, i) &&
+        diskann_neighbor_id_get(neighborIds, i) == target_rowid) {
+      rc = SQLITE_OK;
+      goto cleanup;
+    }
+  }
+
+  // Both paths below need the target in its compressed form.
+  qvecSize = diskann_quantized_vector_byte_size(cfg->quantizer_type,
+                                                col->dimensions);
+  targetQ = sqlite3_malloc(qvecSize);
+  if (!targetQ) {
+    rc = SQLITE_NOMEM;
+    goto cleanup;
+  }
+  diskann_reverse_edge_encode(col, target_vector,
+                              vector_column_byte_size(*col),
+                              targetQ, qvecSize);
+
+  if (diskann_validity_count(validity, cfg->n_neighbors) < cfg->n_neighbors) {
+    // Room available: take the first empty slot.
+    for (int i = 0; i < cfg->n_neighbors; i++) {
+      if (!diskann_validity_get(validity, i)) {
+        slot = i;
+        break;
+      }
+    }
+  } else {
+    // Full: lazy replacement. Quantized distances locate the worst existing
+    // neighbor, which avoids reading every neighbor's full vector.
+    int nodeVecSize;
+    rc = diskann_vector_read(p, vec_col_idx, node_rowid,
+                             &nodeVector, &nodeVecSize);
+    if (rc != SQLITE_OK) goto cleanup;
+
+    nodeQ = sqlite3_malloc(qvecSize);
+    if (!nodeQ) {
+      rc = SQLITE_NOMEM;
+      goto cleanup;
+    }
+    // Clamped to the bytes actually read; the previous unconditional
+    // memcpy(qvecSize) could read past the end of a shorter vector.
+    diskann_reverse_edge_encode(col, nodeVector, (size_t)nodeVecSize,
+                                nodeQ, qvecSize);
+
+    f32 targetDist = diskann_distance_quantized_precomputed(
+        nodeQ, targetQ, col->dimensions,
+        cfg->quantizer_type, col->distance_metric);
+
+    int worstIdx = -1;
+    f32 worstDist = -1.0f;
+    for (int i = 0; i < cfg->n_neighbors; i++) {
+      if (!diskann_validity_get(validity, i)) continue;
+      const u8 *nqvec = diskann_neighbor_qvec_get(
+          qvecs, i, cfg->quantizer_type, col->dimensions);
+      f32 d = diskann_distance_quantized_precomputed(
+          nodeQ, nqvec, col->dimensions,
+          cfg->quantizer_type, col->distance_metric);
+      if (d > worstDist) {
+        worstDist = d;
+        worstIdx = i;
+      }
+    }
+
+    // Replace only when the target is strictly closer than the worst one.
+    if (worstIdx >= 0 && targetDist < worstDist) {
+      slot = worstIdx;
+    }
+  }
+
+  if (slot >= 0) {
+    diskann_node_set_neighbor(validity, neighborIds, qvecs, slot,
+                              target_rowid, targetQ,
+                              cfg->quantizer_type, col->dimensions);
+    rc = diskann_node_write(p, vec_col_idx, node_rowid,
+                            validity, validitySize,
+                            neighborIds, neighborIdsSize,
+                            qvecs, qvecsSize);
+  } else {
+    rc = SQLITE_OK; // target is no closer than any existing neighbor: skip
+  }
+
+cleanup:
+  sqlite3_free(targetQ);
+  sqlite3_free(nodeQ);
+  sqlite3_free(nodeVector);
+  sqlite3_free(validity);
+  sqlite3_free(neighborIds);
+  sqlite3_free(qvecs);
+  return rc;
+}
+
+// ============================================================
+// DiskANN buffer helpers (for batched inserts)
+// ============================================================
+
+/**
+ * Add one (rowid, vector) row to this column's _diskann_buffer shadow table.
+ *
+ * `vector` is bound with SQLITE_STATIC, so it only has to stay valid for the
+ * duration of this call. Returns SQLITE_OK on success, SQLITE_NOMEM if the
+ * SQL text cannot be built, the sqlite3_prepare_v2() error if preparation
+ * fails, and SQLITE_ERROR if the statement does not run to completion.
+ */
+static int diskann_buffer_write(vec0_vtab *p, int vec_col_idx,
+                                i64 rowid, const void *vector, int vectorSize) {
+  sqlite3_stmt *insertStmt = NULL;
+  int prepRc, stepRc;
+  char *sql = sqlite3_mprintf(
+      "INSERT INTO " VEC0_SHADOW_DISKANN_BUFFER_N_NAME
+      " (rowid, vector) VALUES (?, ?)",
+      p->schemaName, p->tableName, vec_col_idx);
+
+  if (sql == NULL) return SQLITE_NOMEM;
+
+  prepRc = sqlite3_prepare_v2(p->db, sql, -1, &insertStmt, NULL);
+  sqlite3_free(sql);
+  if (prepRc != SQLITE_OK) return prepRc;
+
+  sqlite3_bind_int64(insertStmt, 1, rowid);
+  sqlite3_bind_blob(insertStmt, 2, vector, vectorSize, SQLITE_STATIC);
+
+  stepRc = sqlite3_step(insertStmt);
+  sqlite3_finalize(insertStmt);
+
+  if (stepRc != SQLITE_DONE) return SQLITE_ERROR;
+  return SQLITE_OK;
+}
+
+/**
+ * Remove the row with the given rowid from this column's _diskann_buffer
+ * shadow table.
+ *
+ * Returns SQLITE_OK when the DELETE runs to completion, SQLITE_NOMEM if the
+ * SQL text cannot be built, the sqlite3_prepare_v2() error if preparation
+ * fails, and SQLITE_ERROR for any other step result.
+ */
+static int diskann_buffer_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) {
+  sqlite3_stmt *deleteStmt = NULL;
+  int prepRc, stepRc;
+  char *sql = sqlite3_mprintf(
+      "DELETE FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME " WHERE rowid = ?",
+      p->schemaName, p->tableName, vec_col_idx);
+
+  if (sql == NULL) return SQLITE_NOMEM;
+
+  prepRc = sqlite3_prepare_v2(p->db, sql, -1, &deleteStmt, NULL);
+  sqlite3_free(sql);
+  if (prepRc != SQLITE_OK) return prepRc;
+
+  sqlite3_bind_int64(deleteStmt, 1, rowid);
+  stepRc = sqlite3_step(deleteStmt);
+  sqlite3_finalize(deleteStmt);
+
+  if (stepRc != SQLITE_DONE) return SQLITE_ERROR;
+  return SQLITE_OK;
+}
+
+/**
+ * Check whether a rowid is present in this column's _diskann_buffer table.
+ *
+ * On SQLITE_OK, *exists is 1 if a row was found and 0 if not. On any error
+ * *exists is 0 and the error code is returned: SQLITE_NOMEM if the SQL text
+ * cannot be built, otherwise the code reported by sqlite3_prepare_v2() or
+ * sqlite3_step(). A failed lookup is therefore never mistaken for
+ * "not buffered".
+ */
+static int diskann_buffer_exists(vec0_vtab *p, int vec_col_idx,
+                                 i64 rowid, int *exists) {
+  sqlite3_stmt *stmt = NULL;
+  *exists = 0; // defined value on every return path
+
+  char *zSql = sqlite3_mprintf(
+      "SELECT 1 FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME " WHERE rowid = ?",
+      p->schemaName, p->tableName, vec_col_idx);
+  if (!zSql) return SQLITE_NOMEM;
+  int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
+  sqlite3_free(zSql);
+  if (rc != SQLITE_OK) return rc;
+
+  sqlite3_bind_int64(stmt, 1, rowid);
+  rc = sqlite3_step(stmt);
+  sqlite3_finalize(stmt);
+
+  if (rc == SQLITE_ROW) {
+    *exists = 1;
+    return SQLITE_OK;
+  }
+  // SQLITE_DONE means "no such row"; anything else is a genuine failure
+  // (statements from sqlite3_prepare_v2() return the specific error code).
+  return (rc == SQLITE_DONE) ? SQLITE_OK : rc;
+}
+
+/**
+ * Store the number of rows currently in this column's _diskann_buffer shadow
+ * table in *count.
+ *
+ * Returns SQLITE_OK when the count was read, SQLITE_NOMEM if the SQL text
+ * cannot be built, the sqlite3_prepare_v2() error if preparation fails, and
+ * SQLITE_ERROR if the query yields no row. *count is written only on success.
+ */
+static int diskann_buffer_count(vec0_vtab *p, int vec_col_idx, i64 *count) {
+  sqlite3_stmt *countStmt = NULL;
+  int result = SQLITE_ERROR;
+  char *sql = sqlite3_mprintf(
+      "SELECT count(*) FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME,
+      p->schemaName, p->tableName, vec_col_idx);
+
+  if (sql == NULL) return SQLITE_NOMEM;
+
+  int prepRc = sqlite3_prepare_v2(p->db, sql, -1, &countStmt, NULL);
+  sqlite3_free(sql);
+  if (prepRc != SQLITE_OK) return prepRc;
+
+  if (sqlite3_step(countStmt) == SQLITE_ROW) {
+    *count = sqlite3_column_int64(countStmt, 0);
+    result = SQLITE_OK;
+  }
+  sqlite3_finalize(countStmt);
+  return result;
+}
+
+// Forward declaration: diskann_insert_graph does the actual graph insertion
+static int diskann_insert_graph(vec0_vtab *p, int vec_col_idx,
+ i64 rowid, const void *vector);
+
+/**
+ * Flush all buffered vectors into the DiskANN graph.
+ *
+ * Iterates over the _diskann_buffer rows, calls diskann_insert_graph() for
+ * each, and then empties the buffer table.
+ *
+ * The buffer is cleared only if the scan reached its natural end. If a graph
+ * insert fails, a buffered blob is missing or has the wrong size, or
+ * sqlite3_step() reports an error, the function returns that failure and
+ * leaves the buffer rows in place, so nothing is dropped without having been
+ * inserted.
+ *
+ * Returns SQLITE_OK on success, SQLITE_NOMEM if SQL text cannot be built, or
+ * the first error encountered.
+ */
+static int diskann_flush_buffer(vec0_vtab *p, int vec_col_idx) {
+  sqlite3_stmt *stmt = NULL;
+  // Every row was written by diskann_insert() with exactly this many bytes.
+  const int expectedBytes =
+      (int)vector_column_byte_size(p->vector_columns[vec_col_idx]);
+
+  char *zSql = sqlite3_mprintf(
+      "SELECT rowid, vector FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME,
+      p->schemaName, p->tableName, vec_col_idx);
+  if (!zSql) return SQLITE_NOMEM;
+  int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
+  sqlite3_free(zSql);
+  if (rc != SQLITE_OK) return rc;
+
+  while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
+    i64 rowid = sqlite3_column_int64(stmt, 0);
+    // Fetch the blob pointer before asking for its size.
+    const void *vector = sqlite3_column_blob(stmt, 1);
+    if (!vector || sqlite3_column_bytes(stmt, 1) != expectedBytes) {
+      // NULL, empty or wrongly sized blob: diskann_insert_graph() would read
+      // past the end of it.
+      sqlite3_finalize(stmt);
+      return SQLITE_ERROR;
+    }
+    // Note: the vector is already written to the _vectors table, so only the
+    // graph-only insert path is needed here.
+    int insertRc = diskann_insert_graph(p, vec_col_idx, rowid, vector);
+    if (insertRc != SQLITE_OK) {
+      sqlite3_finalize(stmt);
+      return insertRc;
+    }
+  }
+  sqlite3_finalize(stmt);
+  // Anything other than SQLITE_DONE means the scan was cut short by an
+  // error; keep the buffer so the remaining rows are not lost.
+  if (rc != SQLITE_DONE) return rc;
+
+  // Clear the buffer
+  zSql = sqlite3_mprintf(
+      "DELETE FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME,
+      p->schemaName, p->tableName, vec_col_idx);
+  if (!zSql) return SQLITE_NOMEM;
+  rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL);
+  sqlite3_free(zSql);
+  if (rc != SQLITE_OK) return rc;
+  rc = sqlite3_step(stmt);
+  sqlite3_finalize(stmt);
+  return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR;
+}
+
+/**
+ * Empty-graph case of diskann_insert_graph(): write a node for `rowid` whose
+ * neighbor list is exactly what diskann_node_init() produces, then hand the
+ * rowid to diskann_medoid_set().
+ */
+static int diskann_insert_first_node(vec0_vtab *p, int vec_col_idx,
+                                     i64 rowid) {
+  struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+  struct Vec0DiskannConfig *cfg = &col->diskann;
+  u8 *validity, *neighborIds, *qvecs;
+  int validitySize, neighborIdsSize, qvecsSize;
+
+  int rc = diskann_node_init(cfg->n_neighbors, cfg->quantizer_type,
+                             col->dimensions,
+                             &validity, &validitySize,
+                             &neighborIds, &neighborIdsSize,
+                             &qvecs, &qvecsSize);
+  if (rc != SQLITE_OK) return rc;
+
+  rc = diskann_node_write(p, vec_col_idx, rowid,
+                          validity, validitySize,
+                          neighborIds, neighborIdsSize,
+                          qvecs, qvecsSize);
+  sqlite3_free(validity);
+  sqlite3_free(neighborIds);
+  sqlite3_free(qvecs);
+  if (rc != SQLITE_OK) return rc;
+
+  return diskann_medoid_set(p, vec_col_idx, rowid, 0);
+}
+
+/**
+ * Graph-only insertion of one vector (Algorithm 2: LM-Insert).
+ * The caller is expected to have stored the vector in the _vectors table.
+ *
+ * Steps:
+ *   1. If diskann_medoid_get() reports an empty graph, create the first node
+ *      and stop.
+ *   2. Run diskann_search() with list size L, where L is
+ *      search_list_size_insert when that is > 0 and search_list_size
+ *      otherwise.
+ *   3. Pick up to n_neighbors of the hits with diskann_robust_prune().
+ *   4. Write the new node with those neighbors.
+ *   5. Ask every chosen neighbor to link back via diskann_add_reverse_edge().
+ *      The results of those calls are not inspected.
+ *
+ * Returns SQLITE_OK, SQLITE_NOMEM, or the first error from steps 1-4.
+ */
+static int diskann_insert_graph(vec0_vtab *p, int vec_col_idx,
+                                i64 rowid, const void *vector) {
+  struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+  struct Vec0DiskannConfig *cfg = &col->diskann;
+  i64 *hitRowids = NULL;
+  f32 *hitDistances = NULL;
+  i64 *chosen = NULL;
+  int nHits;
+  int nChosen = 0;
+  i64 medoid;
+  int graphIsEmpty;
+
+  int rc = diskann_medoid_get(p, vec_col_idx, &medoid, &graphIsEmpty);
+  if (rc != SQLITE_OK) return rc;
+  if (graphIsEmpty) {
+    return diskann_insert_first_node(p, vec_col_idx, rowid);
+  }
+
+  // Insert-specific list size wins when configured.
+  int L = cfg->search_list_size;
+  if (cfg->search_list_size_insert > 0) {
+    L = cfg->search_list_size_insert;
+  }
+
+  hitRowids = sqlite3_malloc(L * sizeof(i64));
+  hitDistances = sqlite3_malloc(L * sizeof(f32));
+  if (hitRowids == NULL || hitDistances == NULL) {
+    rc = SQLITE_NOMEM;
+    goto finish;
+  }
+
+  rc = diskann_search(p, vec_col_idx, vector, col->dimensions,
+                      col->element_type, L, L,
+                      hitRowids, hitDistances, &nHits);
+  if (rc != SQLITE_OK) goto finish;
+
+  chosen = sqlite3_malloc(cfg->n_neighbors * sizeof(i64));
+  if (chosen == NULL) {
+    rc = SQLITE_NOMEM;
+    goto finish;
+  }
+
+  rc = diskann_robust_prune(p, vec_col_idx, rowid, vector,
+                            hitRowids, hitDistances, nHits,
+                            cfg->alpha, cfg->n_neighbors,
+                            chosen, &nChosen);
+  if (rc != SQLITE_OK) goto finish;
+
+  rc = diskann_write_pruned_neighbors(p, vec_col_idx, rowid,
+                                      chosen, nChosen);
+  if (rc != SQLITE_OK) goto finish;
+
+  // Back-links; the return value of each call is deliberately left unread,
+  // so rc is still SQLITE_OK when this loop ends.
+  for (int k = 0; k < nChosen; k++) {
+    diskann_add_reverse_edge(p, vec_col_idx, chosen[k], rowid, vector);
+  }
+
+finish:
+  sqlite3_free(hitRowids);
+  sqlite3_free(hitDistances);
+  sqlite3_free(chosen);
+  return rc;
+}
+
+/**
+ * Entry point for adding a vector to a DiskANN-indexed column.
+ *
+ * The full-precision vector is always written to the _vectors table first.
+ * What happens next depends on buffer_threshold:
+ *   - 0 (or less): the vector goes straight into the graph through
+ *     diskann_insert_graph().
+ *   - > 0: the vector is appended to the _diskann_buffer table, and once the
+ *     buffer holds at least buffer_threshold rows it is flushed with
+ *     diskann_flush_buffer().
+ *
+ * Returns SQLITE_OK or the first error reported by any of those steps.
+ */
+static int diskann_insert(vec0_vtab *p, int vec_col_idx,
+                          i64 rowid, const void *vector) {
+  struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+  const int threshold = col->diskann.buffer_threshold;
+  const int nBytes = (int)vector_column_byte_size(*col);
+  i64 buffered;
+  int rc;
+
+  rc = diskann_vector_write(p, vec_col_idx, rowid, vector, nBytes);
+  if (rc != SQLITE_OK) return rc;
+
+  if (threshold <= 0) {
+    // Buffering disabled: per-row insert directly into the graph.
+    return diskann_insert_graph(p, vec_col_idx, rowid, vector);
+  }
+
+  rc = diskann_buffer_write(p, vec_col_idx, rowid, vector, nBytes);
+  if (rc != SQLITE_OK) return rc;
+
+  rc = diskann_buffer_count(p, vec_col_idx, &buffered);
+  if (rc != SQLITE_OK) return rc;
+
+  if (buffered < threshold) return SQLITE_OK;
+  return diskann_flush_buffer(p, vec_col_idx);
+}
+
+/*
+ * Note: "returns 1 if ALL vector columns in this table are DiskANN-indexed"
+ * describes vec0_all_columns_diskann(), which is defined after diskann_delete().
+ */
+// ============================================================
+// DiskANN delete (Algorithm 3 from LM-DiskANN paper)
+// ============================================================
+
+/**
+ * Delete the row for `rowid` from this column's _diskann_nodes shadow table.
+ *
+ * Returns SQLITE_OK when the DELETE runs to completion, SQLITE_NOMEM if the
+ * SQL text cannot be built, the sqlite3_prepare_v2() error if preparation
+ * fails, and SQLITE_ERROR for any other step result.
+ */
+static int diskann_node_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) {
+  sqlite3_stmt *deleteStmt = NULL;
+  int prepRc, stepRc;
+  char *sql = sqlite3_mprintf(
+      "DELETE FROM " VEC0_SHADOW_DISKANN_NODES_N_NAME " WHERE rowid = ?",
+      p->schemaName, p->tableName, vec_col_idx);
+
+  if (sql == NULL) return SQLITE_NOMEM;
+
+  prepRc = sqlite3_prepare_v2(p->db, sql, -1, &deleteStmt, NULL);
+  sqlite3_free(sql);
+  if (prepRc != SQLITE_OK) return prepRc;
+
+  sqlite3_bind_int64(deleteStmt, 1, rowid);
+  stepRc = sqlite3_step(deleteStmt);
+  sqlite3_finalize(deleteStmt);
+
+  if (stepRc != SQLITE_DONE) return SQLITE_ERROR;
+  return SQLITE_OK;
+}
+
+/**
+ * Delete the row for `rowid` from this column's _vectors shadow table.
+ *
+ * Returns SQLITE_OK when the DELETE runs to completion, SQLITE_NOMEM if the
+ * SQL text cannot be built, the sqlite3_prepare_v2() error if preparation
+ * fails, and SQLITE_ERROR for any other step result.
+ */
+static int diskann_vector_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) {
+  sqlite3_stmt *deleteStmt = NULL;
+  int prepRc, stepRc;
+  char *sql = sqlite3_mprintf(
+      "DELETE FROM " VEC0_SHADOW_VECTORS_N_NAME " WHERE rowid = ?",
+      p->schemaName, p->tableName, vec_col_idx);
+
+  if (sql == NULL) return SQLITE_NOMEM;
+
+  prepRc = sqlite3_prepare_v2(p->db, sql, -1, &deleteStmt, NULL);
+  sqlite3_free(sql);
+  if (prepRc != SQLITE_OK) return prepRc;
+
+  sqlite3_bind_int64(deleteStmt, 1, rowid);
+  stepRc = sqlite3_step(deleteStmt);
+  sqlite3_finalize(deleteStmt);
+
+  if (stepRc != SQLITE_DONE) return SQLITE_ERROR;
+  return SQLITE_OK;
+}
+
+/**
+ * Repair the graph around a node that is about to be deleted, following
+ * Algorithm 3 (LM-Delete) with simple slot replacement instead of a full
+ * RobustPrune.
+ *
+ * For every rowid n in `deleted_neighbors`:
+ *   - n's node is read; if that fails, n is skipped (best effort).
+ *   - If n lists `deleted_rowid`, that slot is cleared.
+ *   - The cleared slot is refilled with the first other entry of
+ *     `deleted_neighbors` that n does not already list and whose vector can
+ *     be read. If none qualifies the slot stays empty.
+ *   - n's node is written back.
+ *
+ * Returns SQLITE_OK, SQLITE_NOMEM if the scratch buffer cannot be allocated,
+ * or the error from the first diskann_node_write() that fails. A failed
+ * write would leave n pointing at the deleted node, so it stops the repair
+ * and is reported to the caller.
+ */
+static int diskann_repair_reverse_edges(
+    vec0_vtab *p, int vec_col_idx, i64 deleted_rowid,
+    const i64 *deleted_neighbors, int deleted_neighbor_count) {
+
+  struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx];
+  struct Vec0DiskannConfig *cfg = &col->diskann;
+  int rc = SQLITE_OK;
+
+  if (deleted_neighbor_count <= 0) return SQLITE_OK;
+
+  // One scratch buffer for every replacement instead of one per candidate.
+  size_t qvecSize = diskann_quantized_vector_byte_size(
+      cfg->quantizer_type, col->dimensions);
+  u8 *qvec = sqlite3_malloc(qvecSize);
+  if (!qvec) return SQLITE_NOMEM;
+
+  for (int dn = 0; rc == SQLITE_OK && dn < deleted_neighbor_count; dn++) {
+    i64 nodeRowid = deleted_neighbors[dn];
+
+    u8 *validity = NULL, *neighborIds = NULL, *qvecs = NULL;
+    int vs, nis, qs;
+    if (diskann_node_read(p, vec_col_idx, nodeRowid,
+                          &validity, &vs, &neighborIds, &nis,
+                          &qvecs, &qs) != SQLITE_OK) {
+      continue; // unreadable neighbor: nothing to repair here
+    }
+
+    // Find and clear the deleted node's slot
+    int clearedSlot = -1;
+    for (int i = 0; i < cfg->n_neighbors; i++) {
+      if (diskann_validity_get(validity, i) &&
+          diskann_neighbor_id_get(neighborIds, i) == deleted_rowid) {
+        diskann_node_clear_neighbor(validity, neighborIds, qvecs, i,
+                                    cfg->quantizer_type, col->dimensions);
+        clearedSlot = i;
+        break;
+      }
+    }
+
+    if (clearedSlot >= 0) {
+      // Try to fill the cleared slot with one of the deleted node's other
+      // neighbors.
+      for (int di = 0; di < deleted_neighbor_count; di++) {
+        i64 candidate = deleted_neighbors[di];
+        if (candidate == nodeRowid || candidate == deleted_rowid) continue;
+
+        // Check not already a neighbor
+        int alreadyNeighbor = 0;
+        for (int ni = 0; ni < cfg->n_neighbors; ni++) {
+          if (diskann_validity_get(validity, ni) &&
+              diskann_neighbor_id_get(neighborIds, ni) == candidate) {
+            alreadyNeighbor = 1;
+            break;
+          }
+        }
+        if (alreadyNeighbor) continue;
+
+        // Load, quantize, and set
+        void *candidateVec = NULL;
+        int cvs;
+        if (diskann_vector_read(p, vec_col_idx, candidate,
+                                &candidateVec, &cvs) != SQLITE_OK) {
+          continue; // try the next candidate
+        }
+
+        if (col->element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) {
+          diskann_quantize_vector((const f32 *)candidateVec, col->dimensions,
+                                  cfg->quantizer_type, qvec);
+        } else {
+          size_t nCopy = qvecSize < (size_t)cvs ? qvecSize : (size_t)cvs;
+          memcpy(qvec, candidateVec, nCopy);
+          // The buffer is reused, so blank whatever this copy did not cover.
+          if (nCopy < qvecSize) memset(qvec + nCopy, 0, qvecSize - nCopy);
+        }
+        diskann_node_set_neighbor(validity, neighborIds, qvecs, clearedSlot,
+                                  candidate, qvec,
+                                  cfg->quantizer_type, col->dimensions);
+        sqlite3_free(candidateVec);
+        break;
+      }
+
+      // Persist the cleared (and possibly refilled) slot. A failure ends the
+      // outer loop through its rc check.
+      rc = diskann_node_write(p, vec_col_idx, nodeRowid,
+                              validity, vs, neighborIds, nis, qvecs, qs);
+    }
+
+    sqlite3_free(validity);
+    sqlite3_free(neighborIds);
+    sqlite3_free(qvecs);
+  }
+
+  sqlite3_free(qvec);
+  return rc;
+}
+
+/**
+ * Remove a vector from a DiskANN-indexed column (Algorithm 3: LM-Delete).
+ *
+ * With buffering enabled (buffer_threshold > 0), a rowid that is still in the
+ * _diskann_buffer table is removed from the buffer and from _vectors; the
+ * graph is not touched.
+ *
+ * Otherwise the node is read to collect its neighbor rowids, those
+ * neighbors are repaired with diskann_repair_reverse_edges(), the node and
+ * vector rows are deleted, and diskann_medoid_handle_delete() is called.
+ * Each step runs only if the previous one returned SQLITE_OK.
+ *
+ * If the node cannot be read, the function returns SQLITE_OK without
+ * changing anything.
+ */
+static int diskann_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) {
+  struct Vec0DiskannConfig *cfg = &p->vector_columns[vec_col_idx].diskann;
+  int rc;
+
+  // Not yet flushed into the graph? Then no graph repair is needed.
+  if (cfg->buffer_threshold > 0) {
+    int buffered = 0;
+    rc = diskann_buffer_exists(p, vec_col_idx, rowid, &buffered);
+    if (rc != SQLITE_OK) return rc;
+    if (buffered) {
+      rc = diskann_buffer_delete(p, vec_col_idx, rowid);
+      if (rc != SQLITE_OK) return rc;
+      return diskann_vector_delete(p, vec_col_idx, rowid);
+    }
+  }
+
+  // 1. Snapshot the node's neighbor rowids.
+  u8 *slotBits = NULL, *slotIds = NULL, *slotQvecs = NULL;
+  int bitsLen, idsLen, qvecsLen;
+  if (diskann_node_read(p, vec_col_idx, rowid,
+                        &slotBits, &bitsLen, &slotIds, &idsLen,
+                        &slotQvecs, &qvecsLen) != SQLITE_OK) {
+    return SQLITE_OK; // treated as "node doesn't exist": nothing to do
+  }
+
+  int nOutEdges = 0;
+  i64 *outEdges = sqlite3_malloc(cfg->n_neighbors * sizeof(i64));
+  if (outEdges != NULL) {
+    for (int slot = 0; slot < cfg->n_neighbors; slot++) {
+      if (!diskann_validity_get(slotBits, slot)) continue;
+      outEdges[nOutEdges] = diskann_neighbor_id_get(slotIds, slot);
+      nOutEdges++;
+    }
+  }
+  sqlite3_free(slotBits);
+  sqlite3_free(slotIds);
+  sqlite3_free(slotQvecs);
+  if (outEdges == NULL) return SQLITE_NOMEM;
+
+  // 2. Repair reverse edges.
+  rc = diskann_repair_reverse_edges(p, vec_col_idx, rowid,
+                                    outEdges, nOutEdges);
+  sqlite3_free(outEdges);
+  if (rc != SQLITE_OK) return rc;
+
+  // 3. Delete node and vector.
+  rc = diskann_node_delete(p, vec_col_idx, rowid);
+  if (rc != SQLITE_OK) return rc;
+  rc = diskann_vector_delete(p, vec_col_idx, rowid);
+  if (rc != SQLITE_OK) return rc;
+
+  // 4. Handle medoid deletion.
+  return diskann_medoid_handle_delete(p, vec_col_idx, rowid);
+}
+
+/**
+ * Returns 1 if the table has at least one vector column and every vector
+ * column uses VEC0_INDEX_TYPE_DISKANN; returns 0 otherwise.
+ */
+static int vec0_all_columns_diskann(vec0_vtab *p) {
+  const int nCols = p->numVectorColumns;
+  int col = 0;
+
+  if (nCols <= 0) return 0;
+  while (col < nCols) {
+    if (p->vector_columns[col].index_type != VEC0_INDEX_TYPE_DISKANN) {
+      return 0;
+    }
+    col++;
+  }
+  return 1;
+}
+
+// ============================================================================
+// Command dispatch
+// ============================================================================
+
+/**
+ * Parse the integer at the start of `zValue` with the same syntax atoi()
+ * accepts (optional leading whitespace and sign, parsing stops at the first
+ * non-digit) but with defined behaviour for values that do not fit in an
+ * int. atoi() is undefined in that case; strtoll() saturates instead.
+ *
+ * Returns 1 and stores the value in *out when it lies in [1, INT_MAX];
+ * returns 0 and leaves *out untouched otherwise.
+ */
+static int diskann_command_positive_int(const char *zValue, int *out) {
+  long long parsed = strtoll(zValue, NULL, 10);
+  if (parsed < 1 || parsed > INT_MAX) return 0;
+  *out = (int)parsed;
+  return 1;
+}
+
+/**
+ * Apply a runtime tuning command of the form "<name>=<integer>".
+ *
+ * Recognised names: search_list_size_search, search_list_size_insert and
+ * search_list_size. The new value is stored in the config of the first
+ * DiskANN-indexed vector column only.
+ *
+ * Returns:
+ *   SQLITE_OK     the value was stored
+ *   SQLITE_ERROR  the value is < 1 or does not fit in an int (an error
+ *                 message is set on the vtab)
+ *   SQLITE_EMPTY  the table has no DiskANN column, or the command is not one
+ *                 of the names above
+ */
+static int diskann_handle_command(vec0_vtab *p, const char *command) {
+  int col_idx = -1;
+  for (int i = 0; i < p->numVectorColumns; i++) {
+    if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
+      col_idx = i;
+      break;
+    }
+  }
+  if (col_idx < 0) return SQLITE_EMPTY;
+
+  struct Vec0DiskannConfig *cfg = &p->vector_columns[col_idx].diskann;
+  int val;
+
+  if (strncmp(command, "search_list_size_search=", 24) == 0) {
+    if (!diskann_command_positive_int(command + 24, &val)) {
+      vtab_set_error(&p->base, "search_list_size_search must be >= 1");
+      return SQLITE_ERROR;
+    }
+    cfg->search_list_size_search = val;
+    return SQLITE_OK;
+  }
+  if (strncmp(command, "search_list_size_insert=", 24) == 0) {
+    if (!diskann_command_positive_int(command + 24, &val)) {
+      vtab_set_error(&p->base, "search_list_size_insert must be >= 1");
+      return SQLITE_ERROR;
+    }
+    cfg->search_list_size_insert = val;
+    return SQLITE_OK;
+  }
+  if (strncmp(command, "search_list_size=", 17) == 0) {
+    if (!diskann_command_positive_int(command + 17, &val)) {
+      vtab_set_error(&p->base, "search_list_size must be >= 1");
+      return SQLITE_ERROR;
+    }
+    cfg->search_list_size = val;
+    return SQLITE_OK;
+  }
+  return SQLITE_EMPTY;
+}
+
+#ifdef SQLITE_VEC_TEST
+// Test-only entry points, compiled only when SQLITE_VEC_TEST is defined. Each one
+// forwards to an internal DiskANN helper or reads/writes one field of the list; rowids and distances cross as long long/float.
+int _test_diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity) {
+ return diskann_candidate_list_init(list, capacity);
+}
+void _test_diskann_candidate_list_free(struct DiskannCandidateList *list) {
+ diskann_candidate_list_free(list);
+}
+int _test_diskann_candidate_list_insert(struct DiskannCandidateList *list, long long rowid, float distance) {
+ return diskann_candidate_list_insert(list, (i64)rowid, (f32)distance); // casts to the internal i64/f32 types
+}
+int _test_diskann_candidate_list_next_unvisited(const struct DiskannCandidateList *list) {
+ return diskann_candidate_list_next_unvisited(list);
+}
+int _test_diskann_candidate_list_count(const struct DiskannCandidateList *list) {
+ return list->count; // direct field read, no helper involved
+}
+long long _test_diskann_candidate_list_rowid(const struct DiskannCandidateList *list, int i) {
+ return (long long)list->items[i].rowid; // no bounds check on i
+}
+float _test_diskann_candidate_list_distance(const struct DiskannCandidateList *list, int i) {
+ return (float)list->items[i].distance; // no bounds check on i
+}
+void _test_diskann_candidate_list_set_visited(struct DiskannCandidateList *list, int i) {
+ list->items[i].visited = 1; // flags entry i as visited; no bounds check on i
+}
+// Visited-set wrappers.
+int _test_diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity) {
+ return diskann_visited_set_init(set, capacity);
+}
+void _test_diskann_visited_set_free(struct DiskannVisitedSet *set) {
+ diskann_visited_set_free(set);
+}
+int _test_diskann_visited_set_contains(const struct DiskannVisitedSet *set, long long rowid) {
+ return diskann_visited_set_contains(set, (i64)rowid);
+}
+int _test_diskann_visited_set_insert(struct DiskannVisitedSet *set, long long rowid) {
+ return diskann_visited_set_insert(set, (i64)rowid);
+}
+#endif /* SQLITE_VEC_TEST */
+
diff --git a/sqlite-vec-rescore.c b/sqlite-vec-rescore.c
index a45f52f..ef4e692 100644
--- a/sqlite-vec-rescore.c
+++ b/sqlite-vec-rescore.c
@@ -156,21 +156,11 @@ static void rescore_quantize_float_to_bit(const float *src, uint8_t *dst,
static void rescore_quantize_float_to_int8(const float *src, int8_t *dst,
size_t dimensions) {
- float vmin = src[0], vmax = src[0];
- for (size_t i = 1; i < dimensions; i++) {
- if (src[i] < vmin) vmin = src[i];
- if (src[i] > vmax) vmax = src[i];
- }
- float range = vmax - vmin;
- if (range == 0.0f) {
- memset(dst, 0, dimensions);
- return;
- }
- float scale = 255.0f / range;
+ float step = 2.0f / 255.0f; // fixed [-1, 1] input range split into 255 steps (no per-vector min/max any more)
for (size_t i = 0; i < dimensions; i++) {
- float v = (src[i] - vmin) * scale - 128.0f;
- if (v < -128.0f) v = -128.0f;
- if (v > 127.0f) v = 127.0f;
+ float v = (src[i] - (-1.0f)) / step - 128.0f; // -1.0 maps to -128, +1.0 maps to 127
+ if (!(v <= 127.0f)) v = 127.0f; // negated comparison is also true for NaN, which therefore becomes 127
+ if (!(v >= -128.0f)) v = -128.0f; // clamp inputs below -1.0
dst[i] = (int8_t)v;
}
}
diff --git a/sqlite-vec.c b/sqlite-vec.c
index 015792b..5ca7834 100644
--- a/sqlite-vec.c
+++ b/sqlite-vec.c
@@ -61,6 +61,10 @@ SQLITE_EXTENSION_INIT1
#define LONGDOUBLE_TYPE long double
#endif
+#ifndef SQLITE_VEC_ENABLE_DISKANN
+#define SQLITE_VEC_ENABLE_DISKANN 1
+#endif
+
#ifndef _WIN32
#ifndef __EMSCRIPTEN__
#ifndef __COSMOPOLITAN__
@@ -2544,6 +2548,7 @@ enum Vec0IndexType {
VEC0_INDEX_TYPE_RESCORE = 2,
#endif
VEC0_INDEX_TYPE_IVF = 3,
+ VEC0_INDEX_TYPE_DISKANN = 4,
};
#if SQLITE_VEC_ENABLE_RESCORE
@@ -2575,6 +2580,75 @@ struct Vec0IvfConfig {
struct Vec0IvfConfig { char _unused; };
#endif
+// ============================================================
+// DiskANN types and constants
+// ============================================================
+
+#define VEC0_DISKANN_DEFAULT_N_NEIGHBORS 72
+#define VEC0_DISKANN_MAX_N_NEIGHBORS 256
+#define VEC0_DISKANN_DEFAULT_SEARCH_LIST_SIZE 128
+#define VEC0_DISKANN_DEFAULT_ALPHA 1.2f
+
+/**
+ * Quantizer type used for compressing neighbor vectors in the DiskANN graph.
+ *
+ * Selected with the `neighbor_quantizer=binary|int8` index option. The value
+ * determines how many bytes each stored neighbor vector occupies (see
+ * diskann_quantized_vector_byte_size()).
+ */
+enum Vec0DiskannQuantizerType {
+ VEC0_DISKANN_QUANTIZER_BINARY = 1, // 1 bit per dimension (1/32 the size of a float32 vector)
+ VEC0_DISKANN_QUANTIZER_INT8 = 2, // 1 byte per dimension (1/4 the size of a float32 vector)
+};
+
+/**
+ * Configuration for a DiskANN index on a single vector column.
+ * Parsed from `INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=72)`
+ * by vec0_parse_diskann_options(), which fills in the defaults below for
+ * every option that is not given. `alpha` has no corresponding option there
+ * and always starts at VEC0_DISKANN_DEFAULT_ALPHA.
+ */
+struct Vec0DiskannConfig {
+ // Quantizer type for neighbor vectors (required option, no default).
+ enum Vec0DiskannQuantizerType quantizer_type;
+
+ // Maximum neighbors per node (R in the paper). Multiple of 8, at most VEC0_DISKANN_MAX_N_NEIGHBORS; default 72.
+ int n_neighbors;
+
+ // Search list size (L in the paper): unified default for both insert and query; default 128.
+ int search_list_size;
+
+ // Per-path overrides (0 = fall back to search_list_size). Cannot be combined with the search_list_size option.
+ int search_list_size_search;
+ int search_list_size_insert;
+
+ // Alpha parameter for RobustPrune (distance scaling factor, typically 1.0-1.5); default 1.2.
+ f32 alpha;
+
+ // Buffer threshold for batched inserts. When > 0, inserts go into a flat
+ // buffer table and are flushed into the graph when the buffer reaches this
+ // size. 0 = disabled (per-row insert straight into the graph); default 0.
+ int buffer_threshold;
+};
+
+/**
+ * Represents a single candidate during greedy beam search: the rowid of a
+ * graph node together with its distance and an exploration flag.
+ * Used in priority queues / sorted arrays during LM-Search.
+ */
+struct Vec0DiskannCandidate {
+ i64 rowid;
+ f32 distance;
+ int visited; // 1 if this candidate's neighbors have been explored, 0 otherwise
+};
+
+/**
+ * Number of bytes one quantized neighbor vector occupies.
+ *
+ *   binary: dimensions / CHAR_BIT (integer division, so a remainder of
+ *           dimensions % CHAR_BIT bits is not counted; presumably callers
+ *           only pass multiples of CHAR_BIT. TODO confirm)
+ *   int8:   dimensions * sizeof(i8)
+ *
+ * Any other quantizer value yields 0.
+ */
+size_t diskann_quantized_vector_byte_size(
+    enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) {
+  if (quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY) {
+    return dimensions / CHAR_BIT;
+  }
+  if (quantizer_type == VEC0_DISKANN_QUANTIZER_INT8) {
+    return dimensions * sizeof(i8);
+  }
+  return 0;
+}
+
struct VectorColumnDefinition {
char *name;
int name_length;
@@ -2586,6 +2660,7 @@ struct VectorColumnDefinition {
struct Vec0RescoreConfig rescore;
#endif
struct Vec0IvfConfig ivf;
+ struct Vec0DiskannConfig diskann;
};
struct Vec0PartitionColumnDefinition {
@@ -2743,6 +2818,126 @@ static int vec0_parse_ivf_options(struct Vec0Scanner *scanner,
struct Vec0IvfConfig *config);
#endif
+/**
+ * Case-insensitively test whether the token [z, z+n) spells exactly `lit`.
+ * A bare sqlite3_strnicmp(z, lit, n) compares only the first n bytes and so
+ * also accepts every prefix of `lit`; the length check rules that out.
+ */
+static int vec0_diskann_token_is(const char *z, int n, const char *lit) {
+  return n == (int)strlen(lit) && sqlite3_strnicmp(z, lit, n) == 0;
+}
+
+/**
+ * Parse the token [z, z+n) as a non-negative decimal integer.
+ * Only 1 to 9 ASCII digits are accepted, so the result always fits in an
+ * int and nothing outside the token is read. Returns 1 and stores the value
+ * in *out on success, 0 otherwise.
+ */
+static int vec0_diskann_token_int(const char *z, int n, int *out) {
+  int value = 0;
+  if (n < 1 || n > 9) return 0;
+  for (int i = 0; i < n; i++) {
+    if (z[i] < '0' || z[i] > '9') return 0;
+    value = value * 10 + (z[i] - '0');
+  }
+  *out = value;
+  return 1;
+}
+
+/**
+ * Parse the options inside diskann(...) parentheses.
+ * Scanner should be positioned right before the '(' token.
+ *
+ * Recognized options (names and quantizer values are case-insensitive and
+ * must be spelled out in full):
+ *   neighbor_quantizer      = binary | int8   (required)
+ *   n_neighbors             = <int>  (optional, default 72; multiple of 8,
+ *                                     at most VEC0_DISKANN_MAX_N_NEIGHBORS)
+ *   search_list_size        = <int>  (optional, default 128; must be > 0)
+ *   search_list_size_search = <int>  (optional, must be > 0)
+ *   search_list_size_insert = <int>  (optional, must be > 0)
+ *   buffer_threshold        = <int>  (optional, default 0 = no buffering)
+ *
+ * search_list_size cannot be combined with either of the two per-path
+ * variants. Options are comma separated; `)` ends the list.
+ *
+ * Returns SQLITE_OK with *config filled in, or SQLITE_ERROR for any syntax
+ * error, unknown option, invalid value or missing neighbor_quantizer (in
+ * which case *config may be partially updated).
+ */
+static int vec0_parse_diskann_options(struct Vec0Scanner *scanner,
+                                      struct Vec0DiskannConfig *config) {
+  int rc;
+  struct Vec0Token token;
+  int hasQuantizer = 0;
+  int hasSearchListSize = 0;
+  int hasSearchListSizeSplit = 0;
+
+  // Set defaults
+  config->n_neighbors = VEC0_DISKANN_DEFAULT_N_NEIGHBORS;
+  config->search_list_size = VEC0_DISKANN_DEFAULT_SEARCH_LIST_SIZE;
+  config->search_list_size_search = 0;
+  config->search_list_size_insert = 0;
+  config->alpha = VEC0_DISKANN_DEFAULT_ALPHA;
+  config->buffer_threshold = 0;
+
+  // Expect '('
+  rc = vec0_scanner_next(scanner, &token);
+  if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_LPAREN) {
+    return SQLITE_ERROR;
+  }
+
+  while (1) {
+    // key
+    rc = vec0_scanner_next(scanner, &token);
+    if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) {
+      break; // empty parens or trailing comma
+    }
+    if (rc != VEC0_TOKEN_RESULT_SOME ||
+        token.token_type != TOKEN_TYPE_IDENTIFIER) {
+      return SQLITE_ERROR;
+    }
+    char *optKey = token.start;
+    int optKeyLen = token.end - token.start;
+
+    // '='
+    rc = vec0_scanner_next(scanner, &token);
+    if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ) {
+      return SQLITE_ERROR;
+    }
+
+    // value (identifier or digit)
+    rc = vec0_scanner_next(scanner, &token);
+    if (rc != VEC0_TOKEN_RESULT_SOME) {
+      return SQLITE_ERROR;
+    }
+    char *optVal = token.start;
+    int optValLen = token.end - token.start;
+
+    if (vec0_diskann_token_is(optKey, optKeyLen, "neighbor_quantizer")) {
+      if (vec0_diskann_token_is(optVal, optValLen, "binary")) {
+        config->quantizer_type = VEC0_DISKANN_QUANTIZER_BINARY;
+      } else if (vec0_diskann_token_is(optVal, optValLen, "int8")) {
+        config->quantizer_type = VEC0_DISKANN_QUANTIZER_INT8;
+      } else {
+        return SQLITE_ERROR; // unknown quantizer
+      }
+      hasQuantizer = 1;
+    } else {
+      // Every remaining option takes an integer; a non-numeric value is an
+      // error rather than being silently read as 0.
+      int n;
+      if (!vec0_diskann_token_int(optVal, optValLen, &n)) {
+        return SQLITE_ERROR;
+      }
+
+      if (vec0_diskann_token_is(optKey, optKeyLen, "n_neighbors")) {
+        if (n <= 0 || (n % 8) != 0 || n > VEC0_DISKANN_MAX_N_NEIGHBORS) {
+          return SQLITE_ERROR;
+        }
+        config->n_neighbors = n;
+      } else if (vec0_diskann_token_is(optKey, optKeyLen,
+                                       "search_list_size_search")) {
+        if (n <= 0) return SQLITE_ERROR;
+        config->search_list_size_search = n;
+        hasSearchListSizeSplit = 1;
+      } else if (vec0_diskann_token_is(optKey, optKeyLen,
+                                       "search_list_size_insert")) {
+        if (n <= 0) return SQLITE_ERROR;
+        config->search_list_size_insert = n;
+        hasSearchListSizeSplit = 1;
+      } else if (vec0_diskann_token_is(optKey, optKeyLen,
+                                       "search_list_size")) {
+        if (n <= 0) return SQLITE_ERROR;
+        config->search_list_size = n;
+        hasSearchListSize = 1;
+      } else if (vec0_diskann_token_is(optKey, optKeyLen,
+                                       "buffer_threshold")) {
+        config->buffer_threshold = n; // n >= 0 is guaranteed by the parser
+      } else {
+        return SQLITE_ERROR; // unknown option
+      }
+    }
+
+    // Expect ',' or ')'
+    rc = vec0_scanner_next(scanner, &token);
+    if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) {
+      break;
+    }
+    if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_COMMA) {
+      return SQLITE_ERROR;
+    }
+  }
+
+  if (!hasQuantizer) {
+    return SQLITE_ERROR; // neighbor_quantizer is required
+  }
+
+  if (hasSearchListSize && hasSearchListSizeSplit) {
+    // cannot mix search_list_size with search_list_size_search/insert
+    return SQLITE_ERROR;
+  }
+
+  return SQLITE_OK;
+}
+
int vec0_parse_vector_column(const char *source, int source_length,
struct VectorColumnDefinition *outColumn) {
// parses a vector column definition like so:
@@ -2763,8 +2958,9 @@ int vec0_parse_vector_column(const char *source, int source_length,
#endif
struct Vec0IvfConfig ivfConfig;
memset(&ivfConfig, 0, sizeof(ivfConfig));
+ struct Vec0DiskannConfig diskannConfig;
+ memset(&diskannConfig, 0, sizeof(diskannConfig));
int dimensions;
-
vec0_scanner_init(&scanner, source, source_length);
// starts with an identifier
@@ -2931,6 +3127,16 @@ int vec0_parse_vector_column(const char *source, int source_length,
}
#else
return SQLITE_ERROR; // IVF not compiled in
+#endif
+ } else if (sqlite3_strnicmp(token.start, "diskann", indexNameLen) == 0) {
+#if SQLITE_VEC_ENABLE_DISKANN
+ indexType = VEC0_INDEX_TYPE_DISKANN;
+ rc = vec0_parse_diskann_options(&scanner, &diskannConfig);
+ if (rc != SQLITE_OK) {
+ return rc;
+ }
+#else
+ return SQLITE_ERROR;
#endif
} else {
// unknown index type
@@ -2956,6 +3162,7 @@ int vec0_parse_vector_column(const char *source, int source_length,
outColumn->rescore = rescoreConfig;
#endif
outColumn->ivf = ivfConfig;
+ outColumn->diskann = diskannConfig;
return SQLITE_OK;
}
@@ -3154,6 +3361,7 @@ static sqlite3_module vec_eachModule = {
#pragma endregion
+
#pragma region vec0 virtual table
#define VEC0_COLUMN_ID 0
@@ -3214,6 +3422,9 @@ static sqlite3_module vec_eachModule = {
#define VEC0_SHADOW_AUXILIARY_NAME "\"%w\".\"%w_auxiliary\""
#define VEC0_SHADOW_METADATA_N_NAME "\"%w\".\"%w_metadatachunks%02d\""
+#define VEC0_SHADOW_VECTORS_N_NAME "\"%w\".\"%w_vectors%02d\""
+#define VEC0_SHADOW_DISKANN_NODES_N_NAME "\"%w\".\"%w_diskann_nodes%02d\""
+#define VEC0_SHADOW_DISKANN_BUFFER_N_NAME "\"%w\".\"%w_diskann_buffer%02d\""
#define VEC0_SHADOW_METADATA_TEXT_DATA_NAME "\"%w\".\"%w_metadatatext%02d\""
#define VEC_INTERAL_ERROR "Internal sqlite-vec error: "
@@ -3388,6 +3599,24 @@ struct vec0_vtab {
* Must be cleaned up with sqlite3_finalize().
*/
sqlite3_stmt *stmtRowidsGetChunkPosition;
+
+ // === DiskANN additions ===
+#if SQLITE_VEC_ENABLE_DISKANN
+ // Shadow table names for DiskANN, per vector column
+ // e.g., "{schema}"."{table}_vectors{00..15}"
+ char *shadowVectorsNames[VEC0_MAX_VECTOR_COLUMNS];
+
+ // e.g., "{schema}"."{table}_diskann_nodes{00..15}"
+ char *shadowDiskannNodesNames[VEC0_MAX_VECTOR_COLUMNS];
+
+ // Prepared statements for DiskANN operations (per vector column)
+ // These will be lazily prepared on first use.
+ sqlite3_stmt *stmtDiskannNodeRead[VEC0_MAX_VECTOR_COLUMNS];
+ sqlite3_stmt *stmtDiskannNodeWrite[VEC0_MAX_VECTOR_COLUMNS];
+ sqlite3_stmt *stmtDiskannNodeInsert[VEC0_MAX_VECTOR_COLUMNS];
+ sqlite3_stmt *stmtVectorsRead[VEC0_MAX_VECTOR_COLUMNS];
+ sqlite3_stmt *stmtVectorsInsert[VEC0_MAX_VECTOR_COLUMNS];
+#endif
};
#if SQLITE_VEC_ENABLE_RESCORE
@@ -3427,6 +3656,13 @@ void vec0_free_resources(vec0_vtab *p) {
sqlite3_finalize(p->stmtIvfRowidMapLookup[i]); p->stmtIvfRowidMapLookup[i] = NULL;
sqlite3_finalize(p->stmtIvfRowidMapDelete[i]); p->stmtIvfRowidMapDelete[i] = NULL;
sqlite3_finalize(p->stmtIvfCentroidsAll[i]); p->stmtIvfCentroidsAll[i] = NULL;
+#if SQLITE_VEC_ENABLE_DISKANN
+ sqlite3_finalize(p->stmtDiskannNodeRead[i]); p->stmtDiskannNodeRead[i] = NULL;
+ sqlite3_finalize(p->stmtDiskannNodeWrite[i]); p->stmtDiskannNodeWrite[i] = NULL;
+ sqlite3_finalize(p->stmtDiskannNodeInsert[i]); p->stmtDiskannNodeInsert[i] = NULL;
+ sqlite3_finalize(p->stmtVectorsRead[i]); p->stmtVectorsRead[i] = NULL;
+ sqlite3_finalize(p->stmtVectorsInsert[i]); p->stmtVectorsInsert[i] = NULL;
+#endif
}
#endif
}
@@ -3464,6 +3700,13 @@ void vec0_free(vec0_vtab *p) {
p->shadowRescoreVectorsNames[i] = NULL;
#endif
+#if SQLITE_VEC_ENABLE_DISKANN
+ sqlite3_free(p->shadowVectorsNames[i]);
+ p->shadowVectorsNames[i] = NULL;
+ sqlite3_free(p->shadowDiskannNodesNames[i]);
+ p->shadowDiskannNodesNames[i] = NULL;
+#endif
+
sqlite3_free(p->vector_columns[i].name);
p->vector_columns[i].name = NULL;
}
@@ -3484,6 +3727,12 @@ void vec0_free(vec0_vtab *p) {
}
}
+#if SQLITE_VEC_ENABLE_DISKANN
+#include "sqlite-vec-diskann.c"
+#else /* DiskANN support compiled out */
+static int vec0_all_columns_diskann(vec0_vtab *p) { (void)p; return 0; } /* stub: always 0, so the insert/delete paths that call this unconditionally keep using chunk storage */
+#endif
+
int vec0_num_defined_user_columns(vec0_vtab *p) {
return p->numVectorColumns + p->numPartitionColumns + p->numAuxiliaryColumns + p->numMetadataColumns;
}
@@ -3753,6 +4002,25 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx,
void **outVector, int *outVectorSize) {
vec0_vtab *p = pVtab;
int rc, brc;
+
+#if SQLITE_VEC_ENABLE_DISKANN
+ // DiskANN fast path: read from _vectors table
+ if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_DISKANN) {
+ void *vec = NULL;
+ int vecSize;
+ rc = diskann_vector_read(p, vector_column_idx, rowid, &vec, &vecSize);
+ if (rc != SQLITE_OK) {
+ vtab_set_error(&pVtab->base,
+ "Could not fetch vector data for %lld from DiskANN vectors table",
+ rowid);
+ return SQLITE_ERROR;
+ }
+ *outVector = vec;
+ if (outVectorSize) *outVectorSize = vecSize;
+ return SQLITE_OK;
+ }
+#endif
+
i64 chunk_id;
i64 chunk_offset;
@@ -4653,6 +4921,26 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
(i64)vecColumn.dimensions, SQLITE_VEC_VEC0_MAX_DIMENSIONS);
goto error;
}
+
+ // DiskANN validation
+ if (vecColumn.index_type == VEC0_INDEX_TYPE_DISKANN) {
+ if (vecColumn.element_type == SQLITE_VEC_ELEMENT_TYPE_BIT) {
+ sqlite3_free(vecColumn.name);
+ *pzErr = sqlite3_mprintf(
+ VEC_CONSTRUCTOR_ERROR
+ "DiskANN index is not supported on bit vector columns");
+ goto error;
+ }
+ if (vecColumn.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY &&
+ (vecColumn.dimensions % CHAR_BIT) != 0) {
+ sqlite3_free(vecColumn.name);
+ *pzErr = sqlite3_mprintf(
+ VEC_CONSTRUCTOR_ERROR
+ "DiskANN with binary quantizer requires dimensions divisible by 8");
+ goto error;
+ }
+ }
+
pNew->user_column_kinds[user_column_idx] = SQLITE_VEC0_USER_COLUMN_KIND_VECTOR;
pNew->user_column_idxs[user_column_idx] = numVectorColumns;
memcpy(&pNew->vector_columns[numVectorColumns], &vecColumn, sizeof(vecColumn));
@@ -4881,6 +5169,31 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
}
}
+ // DiskANN columns cannot coexist with aux/metadata/partition columns
+ for (int i = 0; i < numVectorColumns; i++) {
+ if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
+ if (numAuxiliaryColumns > 0) {
+ *pzErr = sqlite3_mprintf(
+ VEC_CONSTRUCTOR_ERROR
+ "Auxiliary columns are not supported with DiskANN-indexed vector columns");
+ goto error;
+ }
+ if (numMetadataColumns > 0) {
+ *pzErr = sqlite3_mprintf(
+ VEC_CONSTRUCTOR_ERROR
+ "Metadata columns are not supported with DiskANN-indexed vector columns");
+ goto error;
+ }
+ if (numPartitionColumns > 0) {
+ *pzErr = sqlite3_mprintf(
+ VEC_CONSTRUCTOR_ERROR
+ "Partition key columns are not supported with DiskANN-indexed vector columns");
+ goto error;
+ }
+ break;
+ }
+ }
+
sqlite3_str *createStr = sqlite3_str_new(NULL);
sqlite3_str_appendall(createStr, "CREATE TABLE x(");
if (pkColumnName) {
@@ -4984,6 +5297,20 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
goto error;
}
}
+#endif
+#if SQLITE_VEC_ENABLE_DISKANN
+ if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
+ pNew->shadowVectorsNames[i] =
+ sqlite3_mprintf("%s_vectors%02d", tableName, i);
+ if (!pNew->shadowVectorsNames[i]) {
+ goto error;
+ }
+ pNew->shadowDiskannNodesNames[i] =
+ sqlite3_mprintf("%s_diskann_nodes%02d", tableName, i);
+ if (!pNew->shadowDiskannNodesNames[i]) {
+ goto error;
+ }
+ }
#endif
}
#if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE
@@ -5060,7 +5387,32 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
}
sqlite3_finalize(stmt);
-
+#if SQLITE_VEC_ENABLE_DISKANN
+ // Seed medoid entries for DiskANN-indexed columns
+ for (int i = 0; i < pNew->numVectorColumns; i++) {
+ if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) {
+ continue;
+ }
+ char *key = sqlite3_mprintf("diskann_medoid_%02d", i);
+ char *zInsert = sqlite3_mprintf(
+ "INSERT INTO " VEC0_SHADOW_INFO_NAME "(key, value) VALUES (?1, ?2)",
+ pNew->schemaName, pNew->tableName);
+ rc = sqlite3_prepare_v2(db, zInsert, -1, &stmt, NULL);
+ sqlite3_free(zInsert);
+ if (rc != SQLITE_OK) {
+ sqlite3_free(key);
+ sqlite3_finalize(stmt);
+ goto error;
+ }
+ sqlite3_bind_text(stmt, 1, key, -1, sqlite3_free);
+ sqlite3_bind_null(stmt, 2); // NULL means empty graph
+ if (sqlite3_step(stmt) != SQLITE_DONE) {
+ sqlite3_finalize(stmt);
+ goto error;
+ }
+ sqlite3_finalize(stmt);
+ }
+#endif
// create the _chunks shadow table
char *zCreateShadowChunks = NULL;
@@ -5118,7 +5470,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
for (int i = 0; i < pNew->numVectorColumns; i++) {
#if SQLITE_VEC_ENABLE_RESCORE
- // Rescore and IVF columns don't use _vector_chunks
+ // Non-FLAT columns don't use _vector_chunks
if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
continue;
#endif
@@ -5159,6 +5511,84 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
}
#endif
+#if SQLITE_VEC_ENABLE_DISKANN
+ // Create DiskANN shadow tables for indexed vector columns
+ for (int i = 0; i < pNew->numVectorColumns; i++) {
+ if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) {
+ continue;
+ }
+
+ // Create _vectors{NN} table
+ {
+ char *zSql = sqlite3_mprintf(
+ "CREATE TABLE " VEC0_SHADOW_VECTORS_N_NAME
+ " (rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL);",
+ pNew->schemaName, pNew->tableName, i);
+ if (!zSql) {
+ goto error;
+ }
+ rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0);
+ sqlite3_free(zSql);
+ if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
+ sqlite3_finalize(stmt);
+ *pzErr = sqlite3_mprintf(
+ "Could not create '_vectors%02d' shadow table: %s", i,
+ sqlite3_errmsg(db));
+ goto error;
+ }
+ sqlite3_finalize(stmt);
+ }
+
+ // Create _diskann_nodes{NN} table
+ {
+ char *zSql = sqlite3_mprintf(
+ "CREATE TABLE " VEC0_SHADOW_DISKANN_NODES_N_NAME " ("
+ "rowid INTEGER PRIMARY KEY, "
+ "neighbors_validity BLOB NOT NULL, "
+ "neighbor_ids BLOB NOT NULL, "
+ "neighbor_quantized_vectors BLOB NOT NULL"
+ ");",
+ pNew->schemaName, pNew->tableName, i);
+ if (!zSql) {
+ goto error;
+ }
+ rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0);
+ sqlite3_free(zSql);
+ if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
+ sqlite3_finalize(stmt);
+ *pzErr = sqlite3_mprintf(
+ "Could not create '_diskann_nodes%02d' shadow table: %s", i,
+ sqlite3_errmsg(db));
+ goto error;
+ }
+ sqlite3_finalize(stmt);
+ }
+
+ // Create _diskann_buffer{NN} table (for batched inserts)
+ {
+ char *zSql = sqlite3_mprintf(
+ "CREATE TABLE " VEC0_SHADOW_DISKANN_BUFFER_N_NAME " ("
+ "rowid INTEGER PRIMARY KEY, "
+ "vector BLOB NOT NULL"
+ ");",
+ pNew->schemaName, pNew->tableName, i);
+ if (!zSql) {
+ goto error;
+ }
+ rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0);
+ sqlite3_free(zSql);
+ if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
+ sqlite3_finalize(stmt);
+ *pzErr = sqlite3_mprintf(
+ "Could not create '_diskann_buffer%02d' shadow table: %s", i,
+ sqlite3_errmsg(db));
+ goto error;
+ }
+ sqlite3_finalize(stmt);
+ }
+ }
+#endif
+
// See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY"
// without INTEGER type issue applies here.
for (int i = 0; i < pNew->numMetadataColumns; i++) {
@@ -5293,6 +5723,45 @@ static int vec0Destroy(sqlite3_vtab *pVtab) {
sqlite3_finalize(stmt);
for (int i = 0; i < p->numVectorColumns; i++) {
+#if SQLITE_VEC_ENABLE_DISKANN
+ if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) {
+ // Drop DiskANN shadow tables
+ zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_VECTORS_N_NAME,
+ p->schemaName, p->tableName, i);
+ if (zSql) {
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
+ sqlite3_free((void *)zSql);
+ if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
+ rc = SQLITE_ERROR;
+ goto done;
+ }
+ sqlite3_finalize(stmt);
+ }
+ zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_DISKANN_NODES_N_NAME,
+ p->schemaName, p->tableName, i);
+ if (zSql) {
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
+ sqlite3_free((void *)zSql);
+ if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
+ rc = SQLITE_ERROR;
+ goto done;
+ }
+ sqlite3_finalize(stmt);
+ }
+ zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_DISKANN_BUFFER_N_NAME,
+ p->schemaName, p->tableName, i);
+ if (zSql) {
+ rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0);
+ sqlite3_free((void *)zSql);
+ if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) {
+ rc = SQLITE_ERROR;
+ goto done;
+ }
+ sqlite3_finalize(stmt);
+ }
+ continue;
+ }
+#endif
#if SQLITE_VEC_ENABLE_RESCORE
if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT)
continue;
@@ -7088,6 +7557,215 @@ cleanup:
#include "sqlite-vec-rescore.c"
#endif
+#if SQLITE_VEC_ENABLE_DISKANN
+/**
+ * Handle a KNN query on a DiskANN-indexed vector column.
+ *
+ * Runs diskann_search() for up to k graph results, then scans the column's
+ * _diskann_buffer shadow table so buffered (unflushed) rows also compete for
+ * a slot. Merged results are sorted by ascending distance and attached to
+ * the cursor as a VEC0_QUERY_PLAN_KNN plan.
+ *
+ * idxNum is used as the vector column index. idxStr is scanned, 4 chars per
+ * argv entry after a 1-char prefix, for the MATCH and k argument slots.
+ *
+ * Returns SQLITE_OK (k <= 0 yields an empty result), SQLITE_NOMEM, the code
+ * from diskann_search() or from reading the buffer table, or SQLITE_ERROR
+ * for a bad query vector, a k above INT_MAX, or a mis-sized buffered vector.
+ */
+static int vec0Filter_knn_diskann(
+    vec0_cursor *pCur, vec0_vtab *p, int idxNum,
+    const char *idxStr, int argc, sqlite3_value **argv) {
+
+  int rc;
+  int vectorColumnIdx = idxNum;
+  struct VectorColumnDefinition *vector_column = &p->vector_columns[vectorColumnIdx];
+  struct vec0_query_knn_data *knn_data;
+  // Locals used after the first goto are declared here so that no jump
+  // skips an initializer; those freed at `cleanup` start out NULL.
+  i64 *resultRowids = NULL;
+  f32 *resultDistances = NULL;
+  sqlite3_stmt *bufStmt = NULL;
+  char *zSql;
+  int resultCount = 0;
+  int expectedBytes;
+  i64 k;
+
+  knn_data = sqlite3_malloc(sizeof(*knn_data));
+  if (!knn_data) return SQLITE_NOMEM;
+  memset(knn_data, 0, sizeof(*knn_data));
+
+  // Parse query_idx and k_idx from idxStr
+  int query_idx = -1;
+  int k_idx = -1;
+  for (int i = 0; i < argc; i++) {
+    if (idxStr[1 + (i * 4)] == VEC0_IDXSTR_KIND_KNN_MATCH) {
+      query_idx = i;
+    }
+    if (idxStr[1 + (i * 4)] == VEC0_IDXSTR_KIND_KNN_K) {
+      k_idx = i;
+    }
+  }
+  assert(query_idx >= 0);
+  assert(k_idx >= 0);
+
+  // Extract query vector
+  void *queryVector;
+  size_t dimensions;
+  enum VectorElementType elementType;
+  vector_cleanup queryVectorCleanup = vector_cleanup_noop;
+  char *pzError;
+
+  rc = vector_from_value(argv[query_idx], &queryVector, &dimensions,
+                         &elementType, &queryVectorCleanup, &pzError);
+  if (rc != SQLITE_OK) {
+    vtab_set_error(&p->base, "Invalid query vector: %z", pzError);
+    sqlite3_free(knn_data);
+    return SQLITE_ERROR;
+  }
+
+  if (elementType != vector_column->element_type ||
+      dimensions != vector_column->dimensions) {
+    vtab_set_error(&p->base, "Query vector type/dimension mismatch");
+    rc = SQLITE_ERROR;
+    goto cleanup;
+  }
+
+  k = sqlite3_value_int64(argv[k_idx]);
+  if (k <= 0) {
+    // Nothing to search for: hand the cursor an empty result set.
+    knn_data->k = 0;
+    knn_data->k_used = 0;
+    pCur->knn_data = knn_data;
+    pCur->query_plan = VEC0_QUERY_PLAN_KNN;
+    queryVectorCleanup(queryVector);
+    return SQLITE_OK;
+  }
+  if (k > INT_MAX) {
+    // k is passed to diskann_search() as an int; a truncated value would no
+    // longer match the size of the result arrays allocated below.
+    vtab_set_error(&p->base,
+                   "k value in knn query too large, provided %lld", k);
+    rc = SQLITE_ERROR;
+    goto cleanup;
+  }
+
+  // Result arrays. sqlite3_malloc64() keeps the byte count in 64 bits,
+  // whereas sqlite3_malloc() takes an int and would truncate a large
+  // k * sizeof(i64).
+  resultRowids = sqlite3_malloc64((sqlite3_uint64)k * sizeof(i64));
+  resultDistances = sqlite3_malloc64((sqlite3_uint64)k * sizeof(f32));
+  if (!resultRowids || !resultDistances) {
+    rc = SQLITE_NOMEM;
+    goto cleanup;
+  }
+
+  // Run DiskANN search
+  rc = diskann_search(p, vectorColumnIdx, queryVector, dimensions,
+                      elementType, (int)k, 0,
+                      resultRowids, resultDistances, &resultCount);
+  if (rc != SQLITE_OK) {
+    goto cleanup;
+  }
+
+  // Scan _diskann_buffer for any buffered (unflushed) vectors and merge
+  // with graph results. This ensures no recall loss for buffered vectors.
+  zSql = sqlite3_mprintf(
+      "SELECT rowid, vector FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME,
+      p->schemaName, p->tableName, vectorColumnIdx);
+  if (!zSql) {
+    rc = SQLITE_NOMEM;
+    goto cleanup;
+  }
+  rc = sqlite3_prepare_v2(p->db, zSql, -1, &bufStmt, NULL);
+  sqlite3_free(zSql);
+  if (rc != SQLITE_OK) {
+    // Skipping the buffer would silently drop unflushed rows from the result.
+    goto cleanup;
+  }
+
+  expectedBytes = (int)vector_column_byte_size(*vector_column);
+  while ((rc = sqlite3_step(bufStmt)) == SQLITE_ROW) {
+    i64 bufRowid = sqlite3_column_int64(bufStmt, 0);
+    const void *bufVec = sqlite3_column_blob(bufStmt, 1);
+    // The distance function reads a full vector from bufVec, so a NULL or
+    // mis-sized blob must be rejected before it is used.
+    if (!bufVec || sqlite3_column_bytes(bufStmt, 1) != expectedBytes) {
+      vtab_set_error(&p->base,
+                     VEC_INTERAL_ERROR
+                     "malformed buffered DiskANN vector for rowid %lld",
+                     bufRowid);
+      rc = SQLITE_ERROR;
+      goto cleanup;
+    }
+    f32 dist = vec0_distance_full(
+        queryVector, bufVec, dimensions, elementType,
+        vector_column->distance_metric);
+
+    // Check if this buffer vector should replace the worst graph result
+    if (resultCount < (int)k) {
+      // Still have room, just add it
+      resultRowids[resultCount] = bufRowid;
+      resultDistances[resultCount] = dist;
+      resultCount++;
+    } else {
+      // Find worst (largest distance) in results
+      int worstIdx = 0;
+      for (int wi = 1; wi < resultCount; wi++) {
+        if (resultDistances[wi] > resultDistances[worstIdx]) {
+          worstIdx = wi;
+        }
+      }
+      if (dist < resultDistances[worstIdx]) {
+        resultRowids[worstIdx] = bufRowid;
+        resultDistances[worstIdx] = dist;
+      }
+    }
+  }
+  if (rc != SQLITE_DONE) {
+    // Stepping the buffer table failed; rc carries SQLite's error code.
+    goto cleanup;
+  }
+
+  // Sort results by distance (ascending)
+  for (int si = 0; si < resultCount - 1; si++) {
+    for (int sj = si + 1; sj < resultCount; sj++) {
+      if (resultDistances[sj] < resultDistances[si]) {
+        f32 tmpD = resultDistances[si];
+        resultDistances[si] = resultDistances[sj];
+        resultDistances[sj] = tmpD;
+        i64 tmpR = resultRowids[si];
+        resultRowids[si] = resultRowids[sj];
+        resultRowids[sj] = tmpR;
+      }
+    }
+  }
+
+  // Hand the result arrays to the cursor; NULL the locals so the shared
+  // cleanup below does not free what knn_data now owns.
+  knn_data->k = resultCount;
+  knn_data->k_used = resultCount;
+  knn_data->rowids = resultRowids;
+  knn_data->distances = resultDistances;
+  knn_data->current_idx = 0;
+  resultRowids = NULL;
+  resultDistances = NULL;
+
+  pCur->knn_data = knn_data;
+  pCur->query_plan = VEC0_QUERY_PLAN_KNN;
+  knn_data = NULL;
+  rc = SQLITE_OK;
+
+cleanup:
+  sqlite3_finalize(bufStmt);
+  queryVectorCleanup(queryVector);
+  sqlite3_free(resultRowids);
+  sqlite3_free(resultDistances);
+  sqlite3_free(knn_data);
+  return rc;
+}
+#endif /* SQLITE_VEC_ENABLE_DISKANN */
+
int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
const char *idxStr, int argc, sqlite3_value **argv) {
assert(argc == (strlen(idxStr)-1) / 4);
@@ -7098,6 +7732,13 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
struct VectorColumnDefinition *vector_column =
&p->vector_columns[vectorColumnIdx];
+#if SQLITE_VEC_ENABLE_DISKANN
+ // DiskANN dispatch
+ if (vector_column->index_type == VEC0_INDEX_TYPE_DISKANN) {
+ return vec0Filter_knn_diskann(pCur, p, idxNum, idxStr, argc, argv);
+ }
+#endif
+
struct Array *arrayRowidsIn = NULL;
sqlite3_stmt *stmtChunks = NULL;
void *queryVector;
@@ -8567,24 +9208,37 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
goto cleanup;
}
- // Step #2: Find the next "available" position in the _chunks table for this
- // row.
- rc = vec0Update_InsertNextAvailableStep(p, partitionKeyValues,
- &chunk_rowid, &chunk_offset,
- &blobChunksValidity,
- &bufferChunksValidity);
- if (rc != SQLITE_OK) {
- goto cleanup;
+ if (!vec0_all_columns_diskann(p)) {
+ // Step #2: Find the next "available" position in the _chunks table for this
+ // row.
+ rc = vec0Update_InsertNextAvailableStep(p, partitionKeyValues,
+ &chunk_rowid, &chunk_offset,
+ &blobChunksValidity,
+ &bufferChunksValidity);
+ if (rc != SQLITE_OK) {
+ goto cleanup;
+ }
+
+ // Step #3: With the next available chunk position, write out all the vectors
+ // to their specified location.
+ rc = vec0Update_InsertWriteFinalStep(p, chunk_rowid, chunk_offset, rowid,
+ vectorDatas, blobChunksValidity,
+ bufferChunksValidity);
+ if (rc != SQLITE_OK) {
+ goto cleanup;
+ }
}
- // Step #3: With the next available chunk position, write out all the vectors
- // to their specified location.
- rc = vec0Update_InsertWriteFinalStep(p, chunk_rowid, chunk_offset, rowid,
- vectorDatas, blobChunksValidity,
- bufferChunksValidity);
- if (rc != SQLITE_OK) {
- goto cleanup;
+#if SQLITE_VEC_ENABLE_DISKANN
+ // Step #4: Insert into DiskANN graph for indexed vector columns
+ for (int i = 0; i < p->numVectorColumns; i++) {
+ if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) continue;
+ rc = diskann_insert(p, i, rowid, vectorDatas[i]);
+ if (rc != SQLITE_OK) {
+ goto cleanup;
+ }
}
+#endif
#if SQLITE_VEC_ENABLE_RESCORE
rc = rescore_on_insert(p, chunk_rowid, chunk_offset, rowid, vectorDatas);
@@ -9126,29 +9780,43 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) {
// 4. Zero out vector data in all vector column chunks
// 5. Delete value in _rowids table
- // 1. get chunk_id and chunk_offset from _rowids
- rc = vec0_get_chunk_position(p, rowid, NULL, &chunk_id, &chunk_offset);
- if (rc != SQLITE_OK) {
- return rc;
+#if SQLITE_VEC_ENABLE_DISKANN
+ // DiskANN graph deletion for indexed columns
+ for (int i = 0; i < p->numVectorColumns; i++) {
+ if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) continue;
+ rc = diskann_delete(p, i, rowid);
+ if (rc != SQLITE_OK) {
+ return rc;
+ }
+ }
+#endif
+
+ if (!vec0_all_columns_diskann(p)) {
+ // 1. get chunk_id and chunk_offset from _rowids
+ rc = vec0_get_chunk_position(p, rowid, NULL, &chunk_id, &chunk_offset);
+ if (rc != SQLITE_OK) {
+ return rc;
+ }
+
+ // 2. clear validity bit
+ rc = vec0Update_Delete_ClearValidity(p, chunk_id, chunk_offset);
+ if (rc != SQLITE_OK) {
+ return rc;
+ }
+
+ // 3. zero out rowid in chunks.rowids
+ rc = vec0Update_Delete_ClearRowid(p, chunk_id, chunk_offset);
+ if (rc != SQLITE_OK) {
+ return rc;
+ }
+
+ // 4. zero out any data in vector chunks tables
+ rc = vec0Update_Delete_ClearVectors(p, chunk_id, chunk_offset);
+ if (rc != SQLITE_OK) {
+ return rc;
+ }
}
- // 2. clear validity bit
- rc = vec0Update_Delete_ClearValidity(p, chunk_id, chunk_offset);
- if (rc != SQLITE_OK) {
- return rc;
- }
-
- // 3. zero out rowid in chunks.rowids
- rc = vec0Update_Delete_ClearRowid(p, chunk_id, chunk_offset);
- if (rc != SQLITE_OK) {
- return rc;
- }
-
- // 4. zero out any data in vector chunks tables
- rc = vec0Update_Delete_ClearVectors(p, chunk_id, chunk_offset);
- if (rc != SQLITE_OK) {
- return rc;
- }
#if SQLITE_VEC_ENABLE_RESCORE
// 4b. zero out quantized data in rescore chunk tables, delete from rescore vectors
@@ -9172,20 +9840,22 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) {
}
}
- // 7. delete metadata
- for(int i = 0; i < p->numMetadataColumns; i++) {
- rc = vec0Update_Delete_ClearMetadata(p, i, rowid, chunk_id, chunk_offset);
- if (rc != SQLITE_OK) {
- return rc;
+ // 7. delete metadata and reclaim chunk (only when using chunk-based storage)
+ if (!vec0_all_columns_diskann(p)) {
+ for(int i = 0; i < p->numMetadataColumns; i++) {
+ rc = vec0Update_Delete_ClearMetadata(p, i, rowid, chunk_id, chunk_offset);
+ if (rc != SQLITE_OK) {
+ return rc;
+ }
}
- }
- // 8. reclaim chunk if fully empty
- {
- int chunkDeleted;
- rc = vec0Update_Delete_DeleteChunkIfEmpty(p, chunk_id, &chunkDeleted);
- if (rc != SQLITE_OK) {
- return rc;
+ // 8. reclaim chunk if fully empty
+ {
+ int chunkDeleted;
+ rc = vec0Update_Delete_DeleteChunkIfEmpty(p, chunk_id, &chunkDeleted);
+ if (rc != SQLITE_OK) {
+ return rc;
+ }
}
}
@@ -9481,8 +10151,12 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
const char *cmd = (const char *)sqlite3_value_text(idVal);
vec0_vtab *p = (vec0_vtab *)pVTab;
int cmdRc = ivf_handle_command(p, cmd, argc, argv);
+#if SQLITE_VEC_ENABLE_DISKANN
+ if (cmdRc == SQLITE_EMPTY)
+ cmdRc = diskann_handle_command(p, cmd);
+#endif
if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error)
- // SQLITE_EMPTY means not an IVF command — fall through to normal insert
+ // SQLITE_EMPTY means not a recognized command — fall through to normal insert
}
#endif
return vec0Update_Insert(pVTab, argc, argv, pRowid);
@@ -9638,9 +10312,16 @@ static sqlite3_module vec0Module = {
#define SQLITE_VEC_DEBUG_BUILD_IVF ""
#endif
+#if SQLITE_VEC_ENABLE_DISKANN
+#define SQLITE_VEC_DEBUG_BUILD_DISKANN "diskann"
+#else
+#define SQLITE_VEC_DEBUG_BUILD_DISKANN ""
+#endif
+
#define SQLITE_VEC_DEBUG_BUILD \
SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \
- SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF
+ SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF " " \
+ SQLITE_VEC_DEBUG_BUILD_DISKANN
#define SQLITE_VEC_DEBUG_STRING \
"Version: " SQLITE_VEC_VERSION "\n" \
diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile
index a3405a4..202dc2b 100644
--- a/tests/fuzz/Makefile
+++ b/tests/fuzz/Makefile
@@ -26,7 +26,7 @@ FUZZ_LDFLAGS ?= $(shell \
echo "-Wl,-ld_classic"; \
fi)
-FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -g $(FUZZ_LDFLAGS)
+FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -DSQLITE_VEC_ENABLE_DISKANN=1 -g $(FUZZ_LDFLAGS)
FUZZ_SRCS = ../../vendor/sqlite3.c ../../sqlite-vec.c
TARGET_DIR = ./targets
@@ -115,6 +115,37 @@ $(TARGET_DIR)/ivf_cell_overflow: ivf-cell-overflow.c $(FUZZ_SRCS) | $(TARGET_DIR
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
$(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR)
$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+# DiskANN fuzz targets: each links one harness with the sources in $(FUZZ_SRCS).
+$(TARGET_DIR)/diskann_operations: diskann-operations.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_create: diskann-create.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_graph_corrupt: diskann-graph-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_deep_search: diskann-deep-search.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_blob_truncate: diskann-blob-truncate.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_delete_stress: diskann-delete-stress.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_buffer_flush: diskann-buffer-flush.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_int8_quant: diskann-int8-quant.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_prune_direct: diskann-prune-direct.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
+
+$(TARGET_DIR)/diskann_command_inject: diskann-command-inject.c $(FUZZ_SRCS) | $(TARGET_DIR)
+	$(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@
FUZZ_TARGETS = vec0_create exec json numpy \
@@ -127,6 +158,11 @@ FUZZ_TARGETS = vec0_create exec json numpy \
ivf_create ivf_operations \
ivf_quantize ivf_kmeans ivf_shadow_corrupt \
- ivf_knn_deep ivf_cell_overflow ivf_rescore
+ ivf_knn_deep ivf_cell_overflow ivf_rescore \
+ diskann_operations diskann_create diskann_graph_corrupt \
+ diskann_deep_search diskann_blob_truncate \
+ diskann_delete_stress diskann_buffer_flush \
+ diskann_int8_quant diskann_prune_direct \
+ diskann_command_inject
all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS))
diff --git a/tests/fuzz/diskann-blob-truncate.c b/tests/fuzz/diskann-blob-truncate.c
new file mode 100644
index 0000000..903a0d7
--- /dev/null
+++ b/tests/fuzz/diskann-blob-truncate.c
@@ -0,0 +1,250 @@
+/**
+ * Fuzz target for DiskANN shadow table blob size mismatches.
+ *
+ * The critical vulnerability: diskann_node_read() copies whatever blob size
+ * SQLite returns, but diskann_search/insert/delete index into those blobs
+ * using cfg->n_neighbors * sizeof(i64) etc. If the blob is truncated,
+ * extended, or has wrong size, this causes out-of-bounds reads/writes.
+ *
+ * This fuzzer:
+ * 1. Creates a valid DiskANN graph with several nodes
+ * 2. Uses fuzz data to directly write malformed blobs to shadow tables:
+ * - Truncated neighbor_ids (fewer bytes than n_neighbors * 8)
+ * - Truncated validity bitmaps
+ * - Oversized blobs with garbage trailing data
+ * - Zero-length blobs
+ * - Blobs with valid headers but corrupted neighbor rowids
+ * 3. Runs INSERT, DELETE, and KNN operations that traverse the corrupted graph
+ *
+ * Key code paths targeted:
+ * - diskann_node_read with mismatched blob sizes
+ * - diskann_validity_get / diskann_neighbor_id_get on truncated blobs
+ * - diskann_add_reverse_edge reading corrupted neighbor data
+ * - diskann_repair_reverse_edges traversing corrupted neighbor lists
+ * - diskann_search iterating neighbors from corrupted blobs
+ */
+#include <stdint.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include <assert.h>
+
+/* Pop the next fuzz byte, or hand back `def` once the input is used up. */
+static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
+  if (*size == 0)
+    return def;
+  *size -= 1;
+  return *(*data)++;
+}
+
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+ if (size < 32) return 0; /* inputs shorter than 32 bytes are ignored */
+
+ int rc;
+ sqlite3 *db;
+
+ rc = sqlite3_open(":memory:", &db);
+ assert(rc == SQLITE_OK);
+ rc = sqlite3_vec_init(db, NULL, NULL);
+ assert(rc == SQLITE_OK);
+
+ /* Use binary quantizer, float[16], n_neighbors=8 for predictable blob sizes:
+ * validity: 8/8 = 1 byte
+ * neighbor_ids: 8 * 8 = 64 bytes
+ * qvecs: 8 * (16/8) = 16 bytes (binary: 2 bytes per qvec)
+ */
+ rc = sqlite3_exec(db,
+ "CREATE VIRTUAL TABLE v USING vec0("
+ "emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
+ NULL, NULL, NULL);
+ if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+ /* Insert 12 vectors (rowids 1..12) to create a valid graph; prepare/step results are not checked */
+ {
+ sqlite3_stmt *stmt;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
+ for (int i = 1; i <= 12; i++) {
+ float vec[16];
+ for (int j = 0; j < 16; j++) {
+ vec[j] = (float)i * 0.1f + (float)j * 0.01f;
+ }
+ sqlite3_reset(stmt);
+ sqlite3_bind_int64(stmt, 1, i);
+ sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
+ sqlite3_step(stmt);
+ }
+ sqlite3_finalize(stmt);
+ }
+
+ /* Now corrupt shadow table blobs using fuzz data */
+ const char *columns[] = {
+ "neighbors_validity",
+ "neighbor_ids",
+ "neighbor_quantized_vectors"
+ };
+
+ /* Expected sizes for n_neighbors=8, dims=16, binary quantizer (indexed like columns[]) */
+ int expected_sizes[] = {1, 64, 16};
+
+ while (size >= 4) { /* each round starts with 4 header bytes: row, column, mode, extra */
+ int target_row = (fuzz_byte(&data, &size, 0) % 12) + 1;
+ int col_idx = fuzz_byte(&data, &size, 0) % 3;
+ uint8_t corrupt_mode = fuzz_byte(&data, &size, 0) % 6;
+ uint8_t extra = fuzz_byte(&data, &size, 0);
+
+ char sqlbuf[256];
+ snprintf(sqlbuf, sizeof(sqlbuf),
+ "UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?",
+ columns[col_idx]);
+
+ sqlite3_stmt *writeStmt;
+ rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL);
+ if (rc != SQLITE_OK) continue;
+
+ int expected = expected_sizes[col_idx];
+ unsigned char *blob = NULL;
+ int blob_size = 0;
+
+ switch (corrupt_mode) {
+ case 0: {
+ /* Truncated blob: 0 to expected-1 bytes; bytes past the end of the fuzz input read as 0 */
+ blob_size = extra % expected;
+ if (blob_size == 0) blob_size = 0; /* no-op: zero-length is interesting, so it is kept */
+ blob = sqlite3_malloc(blob_size > 0 ? blob_size : 1);
+ if (!blob) { sqlite3_finalize(writeStmt); continue; }
+ for (int i = 0; i < blob_size; i++) {
+ blob[i] = fuzz_byte(&data, &size, 0);
+ }
+ break;
+ }
+ case 1: {
+ /* Oversized blob: expected + (extra % 64) bytes; bytes past the end of the fuzz input are 0xFF */
+ blob_size = expected + (extra % 64);
+ blob = sqlite3_malloc(blob_size);
+ if (!blob) { sqlite3_finalize(writeStmt); continue; }
+ for (int i = 0; i < blob_size; i++) {
+ blob[i] = fuzz_byte(&data, &size, 0xFF);
+ }
+ break;
+ }
+ case 2: {
+ /* Zero-length blob: bound, written and finalized here, bypassing the shared tail of the loop */
+ blob_size = 0;
+ blob = NULL;
+ sqlite3_bind_zeroblob(writeStmt, 1, 0);
+ sqlite3_bind_int64(writeStmt, 2, target_row);
+ sqlite3_step(writeStmt);
+ sqlite3_finalize(writeStmt);
+ continue;
+ }
+ case 3: {
+ /* Correct size, every byte 0xFF: all validity bits set when the target is
+ * neighbors_validity; every 8-byte id reads as -1 when it is neighbor_ids */
+ blob_size = expected;
+ blob = sqlite3_malloc(blob_size);
+ if (!blob) { sqlite3_finalize(writeStmt); continue; }
+ memset(blob, 0xFF, blob_size);
+ break;
+ }
+ case 4: {
+ /* Correct size, every byte 0x7F: very large positive rowids when the target is neighbor_ids */
+ blob_size = expected;
+ blob = sqlite3_malloc(blob_size);
+ if (!blob) { sqlite3_finalize(writeStmt); continue; }
+ memset(blob, 0x7F, blob_size); /* fills with large positive values */
+ break;
+ }
+ case 5: {
+ /* Correct size, 0x80 fill (negative when read as i64 ids), then XORed with fuzz bytes */
+ blob_size = expected;
+ blob = sqlite3_malloc(blob_size);
+ if (!blob) { sqlite3_finalize(writeStmt); continue; }
+ memset(blob, 0x80, blob_size); /* fills with large negative values */
+ /* Flip some bytes from fuzz data */
+ for (int i = 0; i < blob_size && size > 0; i++) {
+ blob[i] ^= fuzz_byte(&data, &size, 0);
+ }
+ break;
+ }
+ }
+
+ if (blob) {
+ sqlite3_bind_blob(writeStmt, 1, blob, blob_size, SQLITE_TRANSIENT);
+ } else {
+ sqlite3_bind_blob(writeStmt, 1, "", 0, SQLITE_STATIC); /* fallback; no visible path reaches here with blob == NULL */
+ }
+ sqlite3_bind_int64(writeStmt, 2, target_row);
+ sqlite3_step(writeStmt);
+ sqlite3_finalize(writeStmt);
+ sqlite3_free(blob);
+ }
+
+ /* Exercise the corrupted graph with various operations */
+
+ /* KNN query */
+ {
+ float qvec[16];
+ for (int j = 0; j < 16; j++) qvec[j] = (float)j * 0.1f;
+ sqlite3_stmt *knnStmt;
+ rc = sqlite3_prepare_v2(db,
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5",
+ -1, &knnStmt, NULL);
+ if (rc == SQLITE_OK) {
+ sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
+ while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
+ sqlite3_finalize(knnStmt);
+ }
+ }
+
+ /* Insert into corrupted graph (triggers add_reverse_edge on corrupted nodes) */
+ {
+ float vec[16];
+ for (int j = 0; j < 16; j++) vec[j] = 0.5f;
+ sqlite3_stmt *stmt;
+ sqlite3_prepare_v2(db,
+ "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
+ if (stmt) {
+ sqlite3_bind_int64(stmt, 1, 100);
+ sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
+ sqlite3_step(stmt);
+ sqlite3_finalize(stmt);
+ }
+ }
+
+ /* Delete from corrupted graph (triggers repair_reverse_edges) */
+ {
+ sqlite3_stmt *stmt;
+ sqlite3_prepare_v2(db,
+ "DELETE FROM v WHERE rowid = ?", -1, &stmt, NULL);
+ if (stmt) {
+ sqlite3_bind_int64(stmt, 1, 5);
+ sqlite3_step(stmt);
+ sqlite3_finalize(stmt);
+ }
+ }
+
+ /* Another KNN to traverse the post-mutation graph */
+ {
+ float qvec[16];
+ for (int j = 0; j < 16; j++) qvec[j] = -0.5f + (float)j * 0.07f;
+ sqlite3_stmt *knnStmt;
+ rc = sqlite3_prepare_v2(db,
+ "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 12",
+ -1, &knnStmt, NULL);
+ if (rc == SQLITE_OK) {
+ sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
+ while (sqlite3_step(knnStmt) == SQLITE_ROW) {}
+ sqlite3_finalize(knnStmt);
+ }
+ }
+
+ /* Full scan */
+ sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
+
+ sqlite3_close(db);
+ return 0;
+}
diff --git a/tests/fuzz/diskann-buffer-flush.c b/tests/fuzz/diskann-buffer-flush.c
new file mode 100644
index 0000000..f10e100
--- /dev/null
+++ b/tests/fuzz/diskann-buffer-flush.c
@@ -0,0 +1,164 @@
+/**
+ * Fuzz target for DiskANN buffered insert and flush paths.
+ *
+ * When buffer_threshold > 0, inserts go into a flat buffer table and
+ * are flushed into the graph in batch. This fuzzer exercises:
+ *
+ * - diskann_buffer_write / diskann_buffer_delete / diskann_buffer_exists
+ * - diskann_flush_buffer (batch graph insertion)
+ * - diskann_insert with buffer_threshold (batching logic)
+ * - Buffer-graph merge in vec0Filter_knn_diskann (unflushed vectors
+ * must be scanned during KNN and merged with graph results)
+ * - Delete of a buffered (not yet flushed) vector
+ * - Delete of a graph vector while buffer has pending inserts
+ * - Interaction: insert to buffer, query (triggers buffer scan), flush,
+ * query again (now from graph)
+ *
+ * The buffer merge path in vec0Filter_knn_diskann is particularly
+ * interesting because it does a brute-force scan of buffer vectors and
+ * merges with the top-k from graph search.
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/* Pop the next byte off the fuzz input, advancing the cursor and shrinking
+ * the remaining length. Yields `def` once the input is exhausted. */
+static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
+  if (*size > 0) {
+    const uint8_t next = *(*data)++;
+    *size -= 1;
+    return next;
+  }
+  return def;
+}
+
+/* Fill out[0..count) from the fuzz stream: each byte is read as a signed
+ * 8-bit value and divided by 10 (exhausted input reads as 0). */
+static void bf_fill_vec(float *out, int count, const uint8_t **data, size_t *size) {
+  for (int idx = 0; idx < count; idx++) {
+    out[idx] = (float)((int8_t)fuzz_byte(data, size, 0)) / 10.0f;
+  }
+}
+
+/* Rebind and run the INSERT statement once; the step result is ignored. */
+static void bf_insert(sqlite3_stmt *stmt, int64_t rowid, const float *v, int nBytes) {
+  sqlite3_reset(stmt);
+  sqlite3_bind_int64(stmt, 1, rowid);
+  sqlite3_bind_blob(stmt, 2, v, nBytes, SQLITE_TRANSIENT);
+  sqlite3_step(stmt);
+}
+
+/* Rebind and run the DELETE statement once; the step result is ignored. */
+static void bf_delete(sqlite3_stmt *stmt, int64_t rowid) {
+  sqlite3_reset(stmt);
+  sqlite3_bind_int64(stmt, 1, rowid);
+  sqlite3_step(stmt);
+}
+
+/* Rebind the KNN statement and drain every row it produces. */
+static void bf_knn(sqlite3_stmt *stmt, const float *q, int nBytes, int k) {
+  sqlite3_reset(stmt);
+  sqlite3_bind_blob(stmt, 1, q, nBytes, SQLITE_TRANSIENT);
+  sqlite3_bind_int(stmt, 2, k);
+  while (sqlite3_step(stmt) == SQLITE_ROW) {
+  }
+}
+
+/**
+ * Fuzz entry point for the buffered-insert configuration.
+ *
+ * The first input byte selects buffer_threshold (3..8). The rest is decoded
+ * as (opcode, param) pairs, opcode taken modulo 6:
+ *   0 - insert at a sequential rowid (1..64, wrapping)
+ *   1 - KNN query with k = param % 10 + 1
+ *   2 - delete rowid param % 64 + 1
+ *   3 - up to buffer_threshold + 1 inserts at fuzz-chosen rowids
+ *   4 - insert then immediately delete the same rowid
+ *   5 - KNN query on the zero vector with k = 0 or 1
+ * A fixed KNN query with k = 20 runs at the end. Inputs shorter than
+ * 16 bytes are ignored. Always returns 0.
+ */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  if (size < 16) return 0;
+
+  sqlite3 *db;
+  int rc = sqlite3_open(":memory:", &db);
+  assert(rc == SQLITE_OK);
+  rc = sqlite3_vec_init(db, NULL, NULL);
+  assert(rc == SQLITE_OK);
+
+  /* Small threshold (3..8) so the buffer reaches its limit often. */
+  const int flush_at = 3 + (fuzz_byte(&data, &size, 0) % 6);
+  const int dims = 8;
+
+  char ddl[512];
+  snprintf(ddl, sizeof(ddl),
+           "CREATE VIRTUAL TABLE v USING vec0("
+           "emb float[%d] INDEXED BY diskann("
+           "neighbor_quantizer=binary, n_neighbors=8, "
+           "search_list_size=16, buffer_threshold=%d"
+           "))", dims, flush_at);
+  if (sqlite3_exec(db, ddl, NULL, NULL, NULL) != SQLITE_OK) {
+    sqlite3_close(db);
+    return 0;
+  }
+
+  sqlite3_stmt *stmtInsert = NULL;
+  sqlite3_stmt *stmtDelete = NULL;
+  sqlite3_stmt *stmtKnn = NULL;
+  sqlite3_prepare_v2(db,
+      "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+  sqlite3_prepare_v2(db,
+      "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
+  sqlite3_prepare_v2(db,
+      "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
+      -1, &stmtKnn, NULL);
+
+  if (stmtInsert && stmtDelete && stmtKnn) {
+    float vec[8];
+    const int vecBytes = (int)sizeof(vec);
+    int next_rowid = 1;
+
+    while (size >= 2) {
+      const uint8_t op = fuzz_byte(&data, &size, 0) % 6;
+      const uint8_t param = fuzz_byte(&data, &size, 0);
+
+      if (op == 0) {
+        /* Sequential insert; rowids wrap after 64 so they get reused. */
+        const int64_t rowid = next_rowid;
+        next_rowid = (next_rowid + 1 > 64) ? 1 : next_rowid + 1;
+        bf_fill_vec(vec, dims, &data, &size);
+        bf_insert(stmtInsert, rowid, vec, vecBytes);
+      } else if (op == 1) {
+        /* KNN query issued between inserts. */
+        bf_fill_vec(vec, dims, &data, &size);
+        bf_knn(stmtKnn, vec, vecBytes, (param % 10) + 1);
+      } else if (op == 2) {
+        /* Delete by rowid derived from param. */
+        bf_delete(stmtDelete, (int64_t)(param % 64) + 1);
+      } else if (op == 3) {
+        /* Burst of threshold+1 inserts, cut short if input runs low. */
+        for (int n = 0; n < flush_at + 1 && size >= 2; n++) {
+          const int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 64) + 1;
+          bf_fill_vec(vec, dims, &data, &size);
+          bf_insert(stmtInsert, rowid, vec, vecBytes);
+        }
+      } else if (op == 4) {
+        /* Insert a constant-valued vector, then delete it right away. */
+        const int64_t rowid = (int64_t)(param % 64) + 1;
+        for (int j = 0; j < dims; j++) vec[j] = 0.1f * param;
+        bf_insert(stmtInsert, rowid, vec, vecBytes);
+        bf_delete(stmtDelete, rowid);
+      } else {
+        /* op == 5: boundary k (0 or 1) against the zero vector. */
+        for (int j = 0; j < dims; j++) vec[j] = 0.0f;
+        bf_knn(stmtKnn, vec, vecBytes, param % 2);
+      }
+    }
+
+    /* One last query against whatever state the operations left behind. */
+    float qvec[8] = {1.0f, -1.0f, 0.5f, -0.5f, 0.0f, 0.0f, 0.0f, 0.0f};
+    bf_knn(stmtKnn, qvec, (int)sizeof(qvec), 20);
+  }
+
+  sqlite3_finalize(stmtInsert);
+  sqlite3_finalize(stmtDelete);
+  sqlite3_finalize(stmtKnn);
+  sqlite3_close(db);
+  return 0;
+}
diff --git a/tests/fuzz/diskann-command-inject.c b/tests/fuzz/diskann-command-inject.c
new file mode 100644
index 0000000..ef62884
--- /dev/null
+++ b/tests/fuzz/diskann-command-inject.c
@@ -0,0 +1,158 @@
+/**
+ * Fuzz target for DiskANN runtime command dispatch (diskann_handle_command).
+ *
+ * The command handler parses strings like "search_list_size_search=42" and
+ * modifies live DiskANN config. This fuzzer exercises:
+ *
+ * - atoi on fuzz-controlled strings (integer overflow, negative, non-numeric)
+ * - strncmp boundary with fuzz data (near-matches to valid commands)
+ * - Changing search_list_size mid-operation (affects subsequent queries)
+ * - Setting search_list_size to 1 (minimum - single-candidate beam search)
+ * - Setting search_list_size very large (memory pressure)
+ * - Interleaving command changes with inserts and queries
+ *
+ * Commands are sent through the vtable by binding the command text to
+ * INSERT INTO v(rowid) VALUES (?).
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/* Read one byte from the fuzz stream and step past it; when no bytes
+ * remain, return `def` and leave the stream untouched. */
+static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
+  const uint8_t *cursor = *data;
+  if (*size == 0) return def;
+  *data = cursor + 1;
+  *size -= 1;
+  return cursor[0];
+}
+
+/**
+ * Fuzz entry point: seeds a small DiskANN-indexed table, then interprets the
+ * input as a stream of operations that interleave command strings with
+ * inserts and KNN queries.
+ *
+ * Each operation starts with one opcode byte (taken modulo 5):
+ *   0 - send up to 63 raw fuzz bytes as a command string
+ *   1 - send "<known prefix><signed 8-bit value>" as a command
+ *   2 - KNN query (first vector component and k come from the input)
+ *   3 - insert a fuzz-derived vector at rowid 1..32
+ *   4 - send one of a fixed set of search_list_size commands
+ *
+ * Results of sqlite3_step are ignored on purpose: rejected commands and
+ * failed statements are expected outcomes here. Inputs shorter than 20 bytes
+ * are ignored. Always returns 0.
+ */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  if (size < 20) return 0;
+
+  int rc;
+  sqlite3 *db;
+  rc = sqlite3_open(":memory:", &db);
+  assert(rc == SQLITE_OK);
+  rc = sqlite3_vec_init(db, NULL, NULL);
+  assert(rc == SQLITE_OK);
+
+  rc = sqlite3_exec(db,
+      "CREATE VIRTUAL TABLE v USING vec0("
+      "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
+      NULL, NULL, NULL);
+  if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+  /* Insert some vectors first so commands act on a non-empty table. */
+  {
+    sqlite3_stmt *stmt = NULL;
+    sqlite3_prepare_v2(db,
+        "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL);
+    /* Skip seeding if the statement could not be prepared rather than
+     * calling the bind/step API on a NULL handle. */
+    if (stmt) {
+      for (int i = 1; i <= 8; i++) {
+        float vec[8];
+        for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f;
+        sqlite3_reset(stmt);
+        sqlite3_bind_int64(stmt, 1, i);
+        sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
+        sqlite3_step(stmt);
+      }
+    }
+    sqlite3_finalize(stmt);
+  }
+
+  sqlite3_stmt *stmtCmd = NULL;
+  sqlite3_stmt *stmtInsert = NULL;
+  sqlite3_stmt *stmtKnn = NULL;
+
+  /* Commands are dispatched via INSERT INTO v(rowid) VALUES ('cmd_string') */
+  sqlite3_prepare_v2(db,
+      "INSERT INTO v(rowid) VALUES (?)", -1, &stmtCmd, NULL);
+  sqlite3_prepare_v2(db,
+      "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+  sqlite3_prepare_v2(db,
+      "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
+      -1, &stmtKnn, NULL);
+
+  if (!stmtCmd || !stmtInsert || !stmtKnn) goto cleanup;
+
+  /* Fuzz-driven command + operation interleaving */
+  while (size >= 2) {
+    uint8_t op = fuzz_byte(&data, &size, 0) % 5;
+
+    switch (op) {
+    case 0: { /* Send fuzz command string */
+      int cmd_len = fuzz_byte(&data, &size, 0) % 64;
+      char cmd[65];
+      int written = 0;
+      while (written < cmd_len && size > 0) {
+        cmd[written++] = (char)fuzz_byte(&data, &size, 0);
+      }
+      /* Terminate at the number of bytes actually copied. The input may run
+       * out before cmd_len bytes are read, and terminating at cmd_len would
+       * leave uninitialized bytes inside the string handed to SQLite. */
+      cmd[written] = '\0';
+      sqlite3_reset(stmtCmd);
+      /* Length -1: the text ends at the first NUL byte. */
+      sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT);
+      sqlite3_step(stmtCmd); /* May fail -- that's expected */
+      break;
+    }
+    case 1: { /* Send valid-looking command with fuzz value */
+      const char *prefixes[] = {
+        "search_list_size=",
+        "search_list_size_search=",
+        "search_list_size_insert=",
+      };
+      int prefix_idx = fuzz_byte(&data, &size, 0) % 3;
+      /* Signed reinterpretation so negative values are sent too. */
+      int val = (int)(int8_t)fuzz_byte(&data, &size, 0);
+
+      char cmd[128];
+      snprintf(cmd, sizeof(cmd), "%s%d", prefixes[prefix_idx], val);
+      sqlite3_reset(stmtCmd);
+      sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT);
+      sqlite3_step(stmtCmd);
+      break;
+    }
+    case 2: { /* KNN query (runs after whatever commands were sent so far) */
+      float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+      qvec[0] = (float)((int8_t)fuzz_byte(&data, &size, 127)) / 10.0f;
+      int k = fuzz_byte(&data, &size, 3) % 10 + 1;
+      sqlite3_reset(stmtKnn);
+      sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
+      sqlite3_bind_int(stmtKnn, 2, k);
+      while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
+      break;
+    }
+    case 3: { /* Insert (runs after whatever commands were sent so far) */
+      int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 32) + 1;
+      float vec[8];
+      for (int j = 0; j < 8; j++) {
+        vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
+      }
+      sqlite3_reset(stmtInsert);
+      sqlite3_bind_int64(stmtInsert, 1, rowid);
+      sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
+      sqlite3_step(stmtInsert);
+      break;
+    }
+    case 4: { /* Set search_list_size to extreme values */
+      const char *extreme_cmds[] = {
+        "search_list_size=1",
+        "search_list_size=2",
+        "search_list_size=1000",
+        "search_list_size_search=1",
+        "search_list_size_insert=1",
+      };
+      int idx = fuzz_byte(&data, &size, 0) % 5;
+      sqlite3_reset(stmtCmd);
+      /* String literals outlive the statement, so SQLITE_STATIC is safe. */
+      sqlite3_bind_text(stmtCmd, 1, extreme_cmds[idx], -1, SQLITE_STATIC);
+      sqlite3_step(stmtCmd);
+      break;
+    }
+    }
+  }
+
+cleanup:
+  sqlite3_finalize(stmtCmd);
+  sqlite3_finalize(stmtInsert);
+  sqlite3_finalize(stmtKnn);
+  sqlite3_close(db);
+  return 0;
+}
diff --git a/tests/fuzz/diskann-create.c b/tests/fuzz/diskann-create.c
new file mode 100644
index 0000000..1b40a84
--- /dev/null
+++ b/tests/fuzz/diskann-create.c
@@ -0,0 +1,44 @@
+/**
+ * Fuzz target for DiskANN CREATE TABLE config parsing.
+ * Feeds fuzz data as the INDEXED BY diskann(...) option string.
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/**
+ * Fuzz entry point: splices the raw input into the diskann(...) option list
+ * of a CREATE VIRTUAL TABLE statement, then prepares it and, if preparation
+ * succeeds, steps it once. Inputs larger than 4096 bytes are skipped.
+ * Always returns 0.
+ */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  if (size > 4096) return 0; /* Limit input size */
+
+  sqlite3 *db;
+  sqlite3_stmt *stmt;
+  int rc = sqlite3_open(":memory:", &db);
+  assert(rc == SQLITE_OK);
+  rc = sqlite3_vec_init(db, NULL, NULL);
+  assert(rc == SQLITE_OK);
+
+  /* Assemble: fixed prefix + fuzz bytes + closing parentheses. */
+  sqlite3_str *builder = sqlite3_str_new(NULL);
+  assert(builder);
+  sqlite3_str_appendall(builder,
+      "CREATE VIRTUAL TABLE v USING vec0("
+      "emb float[64] INDEXED BY diskann(");
+  /* The precision caps the read at `size` bytes, since the fuzz buffer is
+   * not NUL-terminated. */
+  sqlite3_str_appendf(builder, "%.*s", (int)size, data);
+  sqlite3_str_appendall(builder, "))");
+  char *createSql = sqlite3_str_finish(builder);
+  assert(createSql);
+
+  rc = sqlite3_prepare_v2(db, createSql, -1, &stmt, NULL);
+  sqlite3_free(createSql);
+  if (rc == SQLITE_OK) {
+    sqlite3_step(stmt);
+  }
+
+  sqlite3_finalize(stmt);
+  sqlite3_close(db);
+  return 0;
+}
diff --git a/tests/fuzz/diskann-deep-search.c b/tests/fuzz/diskann-deep-search.c
new file mode 100644
index 0000000..35d548c
--- /dev/null
+++ b/tests/fuzz/diskann-deep-search.c
@@ -0,0 +1,187 @@
+/**
+ * Fuzz target for DiskANN greedy beam search deep paths.
+ *
+ * Builds a graph with enough nodes to force multi-hop traversal, then
+ * uses fuzz data to control: query vector values, k, search_list_size
+ * overrides, and interleaved insert/delete/query sequences that stress
+ * the candidate list growth, visited set hash collisions, and the
+ * re-ranking logic.
+ *
+ * Key code paths targeted:
+ * - diskann_candidate_list_insert (sorted insert, dedup, eviction)
+ * - diskann_visited_set (hash collisions, capacity)
+ * - diskann_search (full beam search loop, re-ranking with exact dist)
+ * - diskann_distance_quantized_precomputed (both binary and int8)
+ * - Buffer merge in vec0Filter_knn_diskann
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/* Take the next fuzz byte and advance past it; fall back to `def` when the
+ * input has been fully consumed. */
+static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
+  uint8_t out = def;
+  if (*size != 0) {
+    out = (*data)[0];
+    *data += 1;
+    --*size;
+  }
+  return out;
+}
+
+/* Assemble a 16-bit value from two fuzz bytes, low byte first. Bytes past
+ * the end of the input read as 0. */
+static uint16_t fuzz_u16(const uint8_t **data, size_t *size) {
+  const uint16_t low = fuzz_byte(data, size, 0);
+  const uint16_t high = fuzz_byte(data, size, 0);
+  return (uint16_t)((high << 8) | low);
+}
+
+/* Turn one fuzz byte into a float: the byte is read as a signed 8-bit value
+ * and divided by 10. An exhausted input yields 0. */
+static float fuzz_float(const uint8_t **data, size_t *size) {
+  const int8_t tenths = (int8_t)fuzz_byte(data, size, 0);
+  return (float)tenths / 10.0f;
+}
+
+/**
+ * Fuzz entry point for the DiskANN search paths.
+ *
+ * The first five input bytes configure the run: quantizer (binary/int8),
+ * dimensions (8/16/32), n_neighbors (8/16), search_list_size (8..31), and one
+ * reserved byte. The table is then seeded with 2 * n_neighbors fuzz-derived
+ * vectors, and the rest of the input is decoded as (opcode, param) pairs,
+ * opcode taken modulo 5:
+ *   0 - insert at rowid param % 128 + 1
+ *   1 - delete rowid param % 128 + 1
+ *   2 - KNN query with k = param % 20 + 1
+ *   3 - KNN query with k = 1000
+ *   4 - insert a param-derived vector at rowid param % 32 + 1
+ *
+ * Results of sqlite3_step are ignored on purpose. Inputs shorter than
+ * 32 bytes are ignored. Always returns 0.
+ */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  if (size < 32) return 0;
+
+  /* Use first bytes to pick quantizer type and dimensions */
+  uint8_t quantizer_choice = fuzz_byte(&data, &size, 0) % 2;
+  const char *quantizer = quantizer_choice ? "int8" : "binary";
+
+  /* Dimensions must be divisible by 8. Pick from {8, 16, 32} */
+  int dim_choices[] = {8, 16, 32};
+  int dims = dim_choices[fuzz_byte(&data, &size, 0) % 3];
+
+  /* n_neighbors: 8 or 16 -- small to force full-neighbor scenarios quickly */
+  int n_neighbors = (fuzz_byte(&data, &size, 0) % 2) ? 16 : 8;
+
+  /* search_list_size: small so beam search terminates quickly but still exercises loops */
+  int search_list_size = 8 + (fuzz_byte(&data, &size, 0) % 24);
+
+  /* Reserved byte. It used to select an alpha value that was never passed to
+   * the table definition; it is still consumed so that existing corpus
+   * inputs decode exactly as before. */
+  (void)fuzz_byte(&data, &size, 0);
+
+  int rc;
+  sqlite3 *db;
+  rc = sqlite3_open(":memory:", &db);
+  assert(rc == SQLITE_OK);
+  rc = sqlite3_vec_init(db, NULL, NULL);
+  assert(rc == SQLITE_OK);
+
+  char sql[512];
+  snprintf(sql, sizeof(sql),
+      "CREATE VIRTUAL TABLE v USING vec0("
+      "emb float[%d] INDEXED BY diskann("
+      "neighbor_quantizer=%s, n_neighbors=%d, "
+      "search_list_size=%d"
+      "))", dims, quantizer, n_neighbors, search_list_size);
+
+  rc = sqlite3_exec(db, sql, NULL, NULL, NULL);
+  if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+  /* Scratch vector shared by seeding and the operation loop. It is declared
+   * before any `goto cleanup` so the free() at the label is always valid. */
+  float *vec = NULL;
+  const int vec_bytes = (int)(dims * sizeof(float));
+
+  sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL;
+  sqlite3_prepare_v2(db,
+      "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+  sqlite3_prepare_v2(db,
+      "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
+  sqlite3_prepare_v2(db,
+      "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
+      -1, &stmtKnn, NULL);
+
+  if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup;
+
+  vec = malloc(dims * sizeof(float));
+  if (!vec) goto cleanup;
+
+  /* Phase 1: Seed the graph with enough nodes to create multi-hop structure.
+   * Insert 2*n_neighbors nodes so the graph is dense enough for search
+   * to actually traverse multiple hops. */
+  int seed_count = n_neighbors * 2;
+  if (seed_count > 64) seed_count = 64; /* Bound for performance */
+  for (int i = 1; i <= seed_count; i++) {
+    for (int j = 0; j < dims; j++) {
+      vec[j] = fuzz_float(&data, &size);
+    }
+    sqlite3_reset(stmtInsert);
+    sqlite3_bind_int64(stmtInsert, 1, i);
+    sqlite3_bind_blob(stmtInsert, 2, vec, vec_bytes, SQLITE_TRANSIENT);
+    sqlite3_step(stmtInsert);
+  }
+
+  /* Phase 2: Fuzz-driven operations on the seeded graph */
+  while (size >= 2) {
+    uint8_t op = fuzz_byte(&data, &size, 0) % 5;
+    uint8_t param = fuzz_byte(&data, &size, 0);
+
+    switch (op) {
+    case 0: { /* INSERT with fuzz-controlled vector and rowid */
+      int64_t rowid = (int64_t)(param % 128) + 1;
+      for (int j = 0; j < dims; j++) {
+        vec[j] = fuzz_float(&data, &size);
+      }
+      sqlite3_reset(stmtInsert);
+      sqlite3_bind_int64(stmtInsert, 1, rowid);
+      sqlite3_bind_blob(stmtInsert, 2, vec, vec_bytes, SQLITE_TRANSIENT);
+      sqlite3_step(stmtInsert);
+      break;
+    }
+    case 1: { /* DELETE */
+      int64_t rowid = (int64_t)(param % 128) + 1;
+      sqlite3_reset(stmtDelete);
+      sqlite3_bind_int64(stmtDelete, 1, rowid);
+      sqlite3_step(stmtDelete);
+      break;
+    }
+    case 2: { /* KNN with fuzz query vector and variable k */
+      for (int j = 0; j < dims; j++) {
+        vec[j] = fuzz_float(&data, &size);
+      }
+      int k = (param % 20) + 1;
+      sqlite3_reset(stmtKnn);
+      sqlite3_bind_blob(stmtKnn, 1, vec, vec_bytes, SQLITE_TRANSIENT);
+      sqlite3_bind_int(stmtKnn, 2, k);
+      while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
+      break;
+    }
+    case 3: { /* KNN with k > number of nodes (boundary) */
+      for (int j = 0; j < dims; j++) {
+        vec[j] = fuzz_float(&data, &size);
+      }
+      sqlite3_reset(stmtKnn);
+      sqlite3_bind_blob(stmtKnn, 1, vec, vec_bytes, SQLITE_TRANSIENT);
+      sqlite3_bind_int(stmtKnn, 2, 1000); /* k >> graph size */
+      while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
+      break;
+    }
+    case 4: { /* INSERT at a low rowid (1..32), likely already in use */
+      int64_t rowid = (int64_t)(param % 32) + 1;
+      for (int j = 0; j < dims; j++) {
+        vec[j] = (float)(param + j) / 50.0f;
+      }
+      sqlite3_reset(stmtInsert);
+      sqlite3_bind_int64(stmtInsert, 1, rowid);
+      sqlite3_bind_blob(stmtInsert, 2, vec, vec_bytes, SQLITE_TRANSIENT);
+      sqlite3_step(stmtInsert);
+      break;
+    }
+    }
+  }
+
+cleanup:
+  free(vec); /* free(NULL) is a no-op on the early-exit paths */
+  sqlite3_finalize(stmtInsert);
+  sqlite3_finalize(stmtDelete);
+  sqlite3_finalize(stmtKnn);
+  sqlite3_close(db);
+  return 0;
+}
diff --git a/tests/fuzz/diskann-delete-stress.c b/tests/fuzz/diskann-delete-stress.c
new file mode 100644
index 0000000..d10a7ff
--- /dev/null
+++ b/tests/fuzz/diskann-delete-stress.c
@@ -0,0 +1,175 @@
+/**
+ * Fuzz target for DiskANN delete path and graph connectivity maintenance.
+ *
+ * The delete path is the most complex graph mutation:
+ * 1. Read deleted node's neighbor list
+ * 2. For each neighbor, remove deleted node from their list
+ * 3. Try to fill the gap with one of deleted node's other neighbors
+ * 4. Handle medoid deletion (pick new medoid)
+ *
+ * Edge cases this targets:
+ * - Delete the medoid (entry point) -- forces medoid reassignment
+ * - Delete all nodes except one -- graph degenerates
+ * - Delete nodes in a chain -- cascading dangling edges
+ * - Re-insert at deleted rowids -- stale graph edges to old data
+ * - Delete nonexistent rowids -- should be no-op
+ * - Insert-delete-insert same rowid rapidly
+ * - Delete when graph has exactly n_neighbors entries (full nodes)
+ *
+ * Key code paths:
+ * - diskann_delete -> diskann_repair_reverse_edges
+ * - diskann_medoid_handle_delete
+ * - diskann_node_clear_neighbor
+ * - Interaction between delete and concurrent search
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/* Consume a single byte of fuzz input. Returns `def` without touching the
+ * cursor once nothing is left. */
+static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
+  if (!*size) {
+    return def;
+  }
+  --(*size);
+  return *((*data)++);
+}
+
+/* Rebind and run the DELETE statement once; the step result is ignored. */
+static void ds_delete(sqlite3_stmt *stmt, int64_t rowid) {
+  sqlite3_reset(stmt);
+  sqlite3_bind_int64(stmt, 1, rowid);
+  sqlite3_step(stmt);
+}
+
+/* Rebind and run the INSERT statement once; the step result is ignored. */
+static void ds_insert(sqlite3_stmt *stmt, int64_t rowid, const float *v, int nBytes) {
+  sqlite3_reset(stmt);
+  sqlite3_bind_int64(stmt, 1, rowid);
+  sqlite3_bind_blob(stmt, 2, v, nBytes, SQLITE_TRANSIENT);
+  sqlite3_step(stmt);
+}
+
+/* Rebind the KNN statement and drain every row it produces. */
+static void ds_knn(sqlite3_stmt *stmt, const float *q, int nBytes, int k) {
+  sqlite3_reset(stmt);
+  sqlite3_bind_blob(stmt, 1, q, nBytes, SQLITE_TRANSIENT);
+  sqlite3_bind_int(stmt, 2, k);
+  while (sqlite3_step(stmt) == SQLITE_ROW) {
+  }
+}
+
+/**
+ * Fuzz entry point for a delete-heavy workload.
+ *
+ * The first input byte picks the neighbor quantizer (binary or int8). Ten
+ * vectors are inserted at rowids 1..10, then the rest of the input is
+ * decoded as (opcode, param) pairs, opcode taken modulo 6:
+ *   0, 1 - delete rowid param % 16 + 1
+ *   2    - delete rowid param % 10 + 1, then insert at the same rowid
+ *   3    - KNN query with k = param % 15 + 1
+ *   4    - insert at rowid param % 32 + 1
+ *   5    - delete rowids 1..32, then insert a fixed vector at rowid 1
+ * A fixed KNN query with k = 10 runs at the end. Inputs shorter than
+ * 20 bytes are ignored. Always returns 0.
+ */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  if (size < 20) return 0;
+
+  sqlite3 *db;
+  int rc = sqlite3_open(":memory:", &db);
+  assert(rc == SQLITE_OK);
+  rc = sqlite3_vec_init(db, NULL, NULL);
+  assert(rc == SQLITE_OK);
+
+  /* Odd first byte selects int8, even selects binary. */
+  const uint8_t use_int8 = fuzz_byte(&data, &size, 0) % 2;
+  const char *quantizer = use_int8 ? "int8" : "binary";
+
+  char ddl[256];
+  snprintf(ddl, sizeof(ddl),
+           "CREATE VIRTUAL TABLE v USING vec0("
+           "emb float[8] INDEXED BY diskann(neighbor_quantizer=%s, n_neighbors=8))",
+           quantizer);
+  if (sqlite3_exec(db, ddl, NULL, NULL, NULL) != SQLITE_OK) {
+    sqlite3_close(db);
+    return 0;
+  }
+
+  sqlite3_stmt *stmtInsert = NULL;
+  sqlite3_stmt *stmtDelete = NULL;
+  sqlite3_stmt *stmtKnn = NULL;
+  sqlite3_prepare_v2(db,
+      "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+  sqlite3_prepare_v2(db,
+      "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
+  sqlite3_prepare_v2(db,
+      "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
+      -1, &stmtKnn, NULL);
+
+  if (stmtInsert && stmtDelete && stmtKnn) {
+    float vec[8];
+    const int vecBytes = (int)sizeof(vec);
+
+    /* Phase 1: ten seed nodes (n_neighbors + 2). When the input runs dry,
+     * the default byte varies with position so the vectors still differ. */
+    for (int node = 1; node <= 10; node++) {
+      for (int j = 0; j < 8; j++) {
+        vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(node*13+j*7))) / 20.0f;
+      }
+      ds_insert(stmtInsert, node, vec, vecBytes);
+    }
+
+    /* Phase 2: operation stream, weighted toward deletes. */
+    while (size >= 2) {
+      const uint8_t action = fuzz_byte(&data, &size, 0) % 6;
+      const uint8_t param = fuzz_byte(&data, &size, 0);
+
+      if (action <= 1) {
+        /* Two of the six opcodes are plain deletes. */
+        ds_delete(stmtDelete, (int64_t)(param % 16) + 1);
+      } else if (action == 2) {
+        /* Delete, then re-insert at the same rowid. */
+        const int64_t rowid = (int64_t)(param % 10) + 1;
+        ds_delete(stmtDelete, rowid);
+        for (int j = 0; j < 8; j++) {
+          vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(rowid+j))) / 15.0f;
+        }
+        ds_insert(stmtInsert, rowid, vec, vecBytes);
+      } else if (action == 3) {
+        /* Query a graph that may by now be sparse or empty. */
+        for (int j = 0; j < 8; j++) {
+          vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
+        }
+        ds_knn(stmtKnn, vec, vecBytes, (param % 15) + 1);
+      } else if (action == 4) {
+        /* Plain insert at a fuzz-chosen rowid. */
+        const int64_t rowid = (int64_t)(param % 32) + 1;
+        for (int j = 0; j < 8; j++) {
+          vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f;
+        }
+        ds_insert(stmtInsert, rowid, vec, vecBytes);
+      } else {
+        /* action == 5: wipe rowids 1..32, then insert one fixed node. */
+        for (int victim = 1; victim <= 32; victim++) {
+          ds_delete(stmtDelete, victim);
+        }
+        float fresh[8] = {1.0f, 0, 0, 0, 0, 0, 0, 0};
+        ds_insert(stmtInsert, 1, fresh, (int)sizeof(fresh));
+      }
+    }
+
+    /* Final KNN on whatever state the graph is in. */
+    float qvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f};
+    ds_knn(stmtKnn, qvec, (int)sizeof(qvec), 10);
+  }
+
+  sqlite3_finalize(stmtInsert);
+  sqlite3_finalize(stmtDelete);
+  sqlite3_finalize(stmtKnn);
+  sqlite3_close(db);
+  return 0;
+}
diff --git a/tests/fuzz/diskann-graph-corrupt.c b/tests/fuzz/diskann-graph-corrupt.c
new file mode 100644
index 0000000..a8dbc19
--- /dev/null
+++ b/tests/fuzz/diskann-graph-corrupt.c
@@ -0,0 +1,123 @@
+/**
+ * Fuzz target for DiskANN shadow table corruption resilience.
+ * Creates and populates a DiskANN table, then corrupts shadow table blobs
+ * using fuzz data and runs queries.
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/* Read one blob column of one row from v_diskann_nodes00, XOR its leading
+ * bytes with `noise`, and write the result back. Does nothing when the row
+ * is missing, the blob is empty, or allocation fails. The SELECT statement
+ * stays open until after the UPDATE has run. */
+static void gc_corrupt_column(sqlite3 *db, const char *column, int row,
+                              const uint8_t *noise, size_t noiseLen) {
+  char sqlbuf[256];
+  sqlite3_stmt *readStmt;
+  sqlite3_stmt *writeStmt;
+  const void *blob;
+  int blobSize;
+  unsigned char *mutated;
+  size_t limit;
+
+  snprintf(sqlbuf, sizeof(sqlbuf),
+           "SELECT %s FROM v_diskann_nodes00 WHERE rowid = ?", column);
+  if (sqlite3_prepare_v2(db, sqlbuf, -1, &readStmt, NULL) != SQLITE_OK) return;
+
+  sqlite3_bind_int64(readStmt, 1, row);
+  if (sqlite3_step(readStmt) != SQLITE_ROW) goto done;
+
+  blob = sqlite3_column_blob(readStmt, 0);
+  blobSize = sqlite3_column_bytes(readStmt, 0);
+  if (!blob || blobSize <= 0) goto done;
+
+  mutated = sqlite3_malloc(blobSize);
+  if (!mutated) goto done;
+  memcpy(mutated, blob, blobSize);
+
+  /* XOR at most min(noiseLen, blobSize) bytes from the start of the blob. */
+  limit = noiseLen < (size_t)blobSize ? noiseLen : (size_t)blobSize;
+  for (size_t i = 0; i < limit; i++) {
+    mutated[i] ^= noise[i];
+  }
+
+  snprintf(sqlbuf, sizeof(sqlbuf),
+           "UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?", column);
+  if (sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL) == SQLITE_OK) {
+    sqlite3_bind_blob(writeStmt, 1, mutated, blobSize, SQLITE_TRANSIENT);
+    sqlite3_bind_int64(writeStmt, 2, row);
+    sqlite3_step(writeStmt);
+    sqlite3_finalize(writeStmt);
+  }
+  sqlite3_free(mutated);
+
+done:
+  sqlite3_finalize(readStmt);
+}
+
+/**
+ * Fuzz entry point: builds a ten-row DiskANN table, corrupts one shadow
+ * table blob with the input, then runs a KNN query and a full scan.
+ *
+ * Byte 0 picks the target row (1..10), byte 1 picks the column
+ * (neighbors_validity, neighbor_ids or neighbor_quantized_vectors), and the
+ * remaining bytes are XORed into the blob. Inputs shorter than 16 bytes are
+ * ignored. Always returns 0.
+ */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  static const char *const columns[] = {
+    "neighbors_validity",
+    "neighbor_ids",
+    "neighbor_quantized_vectors",
+  };
+
+  if (size < 16) return 0;
+
+  sqlite3 *db;
+  int rc = sqlite3_open(":memory:", &db);
+  assert(rc == SQLITE_OK);
+  rc = sqlite3_vec_init(db, NULL, NULL);
+  assert(rc == SQLITE_OK);
+
+  rc = sqlite3_exec(db,
+      "CREATE VIRTUAL TABLE v USING vec0("
+      "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
+      NULL, NULL, NULL);
+  if (rc != SQLITE_OK) {
+    sqlite3_close(db);
+    return 0;
+  }
+
+  /* Populate rowids 1..10 with fixed vectors so there is graph structure. */
+  sqlite3_stmt *seed;
+  sqlite3_prepare_v2(db,
+      "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &seed, NULL);
+  for (int row = 1; row <= 10; row++) {
+    float vec[8];
+    for (int j = 0; j < 8; j++) {
+      vec[j] = (float)row * 0.1f + (float)j * 0.01f;
+    }
+    sqlite3_reset(seed);
+    sqlite3_bind_int64(seed, 1, row);
+    sqlite3_bind_blob(seed, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
+    sqlite3_step(seed);
+  }
+  sqlite3_finalize(seed);
+
+  /* The first two bytes choose what to corrupt; the rest is the XOR mask. */
+  const int target_row = (data[0] % 10) + 1;
+  const char *column_name = columns[data[1] % 3];
+  gc_corrupt_column(db, column_name, target_row, data + 2, size - 2);
+
+  /* Run queries on corrupted graph -- should not crash */
+  float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+  sqlite3_stmt *knnStmt;
+  rc = sqlite3_prepare_v2(db,
+      "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5",
+      -1, &knnStmt, NULL);
+  if (rc == SQLITE_OK) {
+    sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC);
+    while (sqlite3_step(knnStmt) == SQLITE_ROW) {
+    }
+    sqlite3_finalize(knnStmt);
+  }
+
+  /* Full scan */
+  sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
+
+  sqlite3_close(db);
+  return 0;
+}
diff --git a/tests/fuzz/diskann-int8-quant.c b/tests/fuzz/diskann-int8-quant.c
new file mode 100644
index 0000000..f1bd31d
--- /dev/null
+++ b/tests/fuzz/diskann-int8-quant.c
@@ -0,0 +1,164 @@
+/**
+ * Fuzz target for DiskANN int8 quantizer edge cases.
+ *
+ * The binary quantizer is simple (sign bit), but the int8 quantizer has
+ * interesting arithmetic:
+ * i8_val = (i8)(((src - (-1.0f)) / step) - 128.0f)
+ * where step = 2.0f / 255.0f
+ *
+ * Edge cases in this formula:
+ * - src values outside [-1, 1] cause clamping issues (no explicit clamp!)
+ * - src = NaN, +Inf, -Inf (from corrupted vectors or div-by-zero)
+ * - src very close to boundaries (-1.0, 1.0) -- rounding
+ * - The cast to i8 can overflow for extreme src values
+ *
+ * Also exercises int8 distance functions:
+ * - distance_l2_sqr_int8: accumulates squared differences, possible overflow
+ * - distance_cosine_int8: dot product with normalization
+ * - distance_l1_int8: absolute differences
+ *
+ * This fuzzer also tests the cosine distance metric path which the
+ * other fuzzers (using L2 default) don't cover.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/* Fetch the next input byte and move the cursor forward by one. With no
+ * input left, `def` is returned and nothing is modified. */
+static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
+  uint8_t value;
+  if (*size == 0) {
+    return def;
+  }
+  value = (*data)[0];
+  *data = *data + 1;
+  *size = *size - 1;
+  return value;
+}
+
+/* Produce a float from two fuzz bytes: the first (mod 8) selects a value
+ * family, the second supplies the raw magnitude. Both bytes are always
+ * consumed, even for the families that return a constant. */
+static float fuzz_extreme_float(const uint8_t **data, size_t *size) {
+  const uint8_t selector = fuzz_byte(data, size, 0) % 8;
+  const uint8_t raw = fuzz_byte(data, size, 0);
+  const float as_signed = (float)((int8_t)raw);
+  const float as_unsigned = (float)raw;
+
+  if (selector == 0) return as_signed / 10.0f;     /* Normal range */
+  if (selector == 1) return as_signed * 100.0f;    /* Large values */
+  if (selector == 2) return as_signed / 1000.0f;   /* Tiny values near 0 */
+  if (selector == 3) return -1.0f;                 /* Exact boundary */
+  if (selector == 4) return 1.0f;                  /* Exact boundary */
+  if (selector == 5) return 0.0f;                  /* Zero */
+  if (selector == 6) return as_unsigned / 255.0f;  /* [0, 1] range */
+  return -as_unsigned / 255.0f;                    /* selector 7: [-1, 0] range */
+}
+
+/* Fill out[0..count) with values from fuzz_extreme_float. */
+static void iq_fill_vec(float *out, int count, const uint8_t **data, size_t *size) {
+  for (int idx = 0; idx < count; idx++) {
+    out[idx] = fuzz_extreme_float(data, size);
+  }
+}
+
+/* Rebind and run the INSERT statement once; the step result is ignored. */
+static void iq_insert(sqlite3_stmt *stmt, int64_t rowid, const float *v, int nBytes) {
+  sqlite3_reset(stmt);
+  sqlite3_bind_int64(stmt, 1, rowid);
+  sqlite3_bind_blob(stmt, 2, v, nBytes, SQLITE_TRANSIENT);
+  sqlite3_step(stmt);
+}
+
+/* Rebind the KNN statement and drain every row it produces. */
+static void iq_knn(sqlite3_stmt *stmt, const float *q, int nBytes, int k) {
+  sqlite3_reset(stmt);
+  sqlite3_bind_blob(stmt, 1, q, nBytes, SQLITE_TRANSIENT);
+  sqlite3_bind_int(stmt, 2, k);
+  while (sqlite3_step(stmt) == SQLITE_ROW) {
+  }
+}
+
+/**
+ * Fuzz entry point for the int8 neighbor quantizer.
+ *
+ * The first two input bytes choose the distance metric (L2 or cosine) and
+ * the dimension count (8, 16 or 24). Sixteen vectors built with
+ * fuzz_extreme_float are inserted at rowids 1..16, then the rest of the
+ * input is decoded as (opcode, param) pairs, opcode taken modulo 4:
+ *   0 - KNN query with k = param % 10 + 1
+ *   1 - insert at rowid param % 32 + 1
+ *   2 - delete rowid param % 32 + 1
+ *   3 - KNN query (k = 5) on a constant vector of 0, 1 or -1
+ * Inputs shorter than 40 bytes are ignored. Always returns 0.
+ */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  if (size < 40) return 0;
+
+  sqlite3 *db;
+  int rc = sqlite3_open(":memory:", &db);
+  assert(rc == SQLITE_OK);
+  rc = sqlite3_vec_init(db, NULL, NULL);
+  assert(rc == SQLITE_OK);
+
+  /* Odd byte selects cosine, even selects L2. */
+  const uint8_t want_cosine = fuzz_byte(&data, &size, 0) % 2;
+  const char *metric = want_cosine ? "cosine" : "L2";
+
+  const int dims = 8 + (fuzz_byte(&data, &size, 0) % 3) * 8; /* 8, 16, or 24 */
+
+  char ddl[512];
+  snprintf(ddl, sizeof(ddl),
+           "CREATE VIRTUAL TABLE v USING vec0("
+           "emb float[%d] distance_metric=%s "
+           "INDEXED BY diskann(neighbor_quantizer=int8, n_neighbors=8, search_list_size=16))",
+           dims, metric);
+  if (sqlite3_exec(db, ddl, NULL, NULL, NULL) != SQLITE_OK) {
+    sqlite3_close(db);
+    return 0;
+  }
+
+  /* Declared before any `goto cleanup` so free() at the label is valid. */
+  float *vec = NULL;
+  const int vecBytes = (int)(dims * sizeof(float));
+
+  sqlite3_stmt *stmtInsert = NULL;
+  sqlite3_stmt *stmtKnn = NULL;
+  sqlite3_stmt *stmtDelete = NULL;
+  sqlite3_prepare_v2(db,
+      "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+  sqlite3_prepare_v2(db,
+      "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?",
+      -1, &stmtKnn, NULL);
+  sqlite3_prepare_v2(db,
+      "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
+
+  if (!stmtInsert || !stmtKnn || !stmtDelete) goto cleanup;
+
+  vec = malloc(dims * sizeof(float));
+  if (!vec) goto cleanup;
+
+  /* Seed rowids 1..16 with vectors drawn from the extreme-value generator. */
+  for (int row = 1; row <= 16; row++) {
+    iq_fill_vec(vec, dims, &data, &size);
+    iq_insert(stmtInsert, row, vec, vecBytes);
+  }
+
+  /* Fuzz-driven operations */
+  while (size >= 2) {
+    const uint8_t op = fuzz_byte(&data, &size, 0) % 4;
+    const uint8_t param = fuzz_byte(&data, &size, 0);
+
+    if (op == 0) {
+      /* KNN with a generated query vector. */
+      iq_fill_vec(vec, dims, &data, &size);
+      iq_knn(stmtKnn, vec, vecBytes, (param % 10) + 1);
+    } else if (op == 1) {
+      /* Insert a generated vector at a fuzz-chosen rowid. */
+      const int64_t rowid = (int64_t)(param % 32) + 1;
+      iq_fill_vec(vec, dims, &data, &size);
+      iq_insert(stmtInsert, rowid, vec, vecBytes);
+    } else if (op == 2) {
+      /* Delete by rowid. */
+      sqlite3_reset(stmtDelete);
+      sqlite3_bind_int64(stmtDelete, 1, (int64_t)(param % 32) + 1);
+      sqlite3_step(stmtDelete);
+    } else {
+      /* op == 3: constant query vector, value chosen by param % 3. */
+      float fill;
+      if (param % 3 == 0) {
+        fill = 0.0f;
+      } else if (param % 3 == 1) {
+        fill = 1.0f;
+      } else {
+        fill = -1.0f;
+      }
+      for (int j = 0; j < dims; j++) vec[j] = fill;
+      iq_knn(stmtKnn, vec, vecBytes, 5);
+    }
+  }
+
+cleanup:
+  free(vec); /* free(NULL) is a no-op on the early-exit paths */
+  sqlite3_finalize(stmtInsert);
+  sqlite3_finalize(stmtKnn);
+  sqlite3_finalize(stmtDelete);
+  sqlite3_close(db);
+  return 0;
+}
diff --git a/tests/fuzz/diskann-operations.c b/tests/fuzz/diskann-operations.c
new file mode 100644
index 0000000..b36620b
--- /dev/null
+++ b/tests/fuzz/diskann-operations.c
@@ -0,0 +1,100 @@
+/**
+ * Fuzz target for DiskANN insert/delete/query operation sequences.
+ * Uses fuzz bytes to drive random operations on a DiskANN-indexed table.
+ */
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/* libFuzzer entry point: replays `data` as (op, rowid) byte pairs against an
+ * in-memory vec0 table with a DiskANN index. op = byte % 4 selects insert,
+ * delete, KNN query or full scan; rowid = (byte % 32) + 1. Always returns 0. */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  if (size < 6) return 0;
+  int rc;
+  sqlite3 *db;
+  sqlite3_stmt *stmtInsert = NULL;
+  sqlite3_stmt *stmtDelete = NULL;
+  sqlite3_stmt *stmtKnn = NULL;
+  sqlite3_stmt *stmtScan = NULL;
+  rc = sqlite3_open(":memory:", &db);
+  assert(rc == SQLITE_OK);
+  rc = sqlite3_vec_init(db, NULL, NULL);
+  assert(rc == SQLITE_OK);
+
+  rc = sqlite3_exec(db,
+    "CREATE VIRTUAL TABLE v USING vec0("
+    "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))",
+    NULL, NULL, NULL);
+  if (rc != SQLITE_OK) { sqlite3_close(db); return 0; }
+
+  sqlite3_prepare_v2(db,
+    "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL);
+  sqlite3_prepare_v2(db,
+    "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL);
+  sqlite3_prepare_v2(db,
+    "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 3",
+    -1, &stmtKnn, NULL);
+  sqlite3_prepare_v2(db,
+    "SELECT rowid FROM v", -1, &stmtScan, NULL);
+  /* A failed prepare leaves its handle NULL; give up if any is missing. */
+  if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup;
+
+  size_t i = 0;
+  while (i + 2 <= size) {
+    uint8_t op = data[i++] % 4;
+    uint8_t rowid_byte = data[i++];
+    int64_t rowid = (int64_t)(rowid_byte % 32) + 1;
+    switch (op) {
+      case 0: {
+        /* INSERT: up to 8 bytes, one int8 per component (/10); the rest stay 0 */
+        float vec[8] = {0};
+        for (int j = 0; j < 8 && i < size; j++, i++) {
+          vec[j] = (float)((int8_t)data[i]) / 10.0f;
+        }
+        sqlite3_reset(stmtInsert);
+        sqlite3_bind_int64(stmtInsert, 1, rowid);
+        sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT);
+        sqlite3_step(stmtInsert);
+        break;
+      }
+      case 1: {
+        /* DELETE the row selected by the rowid byte (it may not exist) */
+        sqlite3_reset(stmtDelete);
+        sqlite3_bind_int64(stmtDelete, 1, rowid);
+        sqlite3_step(stmtDelete);
+        break;
+      }
+      case 2: {
+        /* KNN query; TRANSIENT because qvec goes out of scope with this block */
+        float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
+        sqlite3_reset(stmtKnn);
+        sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT);
+        while (sqlite3_step(stmtKnn) == SQLITE_ROW) {}
+        break;
+      }
+      case 3: {
+        /* Full scan: step through every rowid */
+        sqlite3_reset(stmtScan);
+        while (sqlite3_step(stmtScan) == SQLITE_ROW) {}
+        break;
+      }
+    }
+  }
+
+  /* Final operations -- must not crash regardless of prior state */
+  sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL);
+
+cleanup:
+  sqlite3_finalize(stmtInsert);
+  sqlite3_finalize(stmtDelete);
+  sqlite3_finalize(stmtKnn);
+  sqlite3_finalize(stmtScan);
+  sqlite3_close(db);
+  return 0;
+}
diff --git a/tests/fuzz/diskann-prune-direct.c b/tests/fuzz/diskann-prune-direct.c
new file mode 100644
index 0000000..7a440ad
--- /dev/null
+++ b/tests/fuzz/diskann-prune-direct.c
@@ -0,0 +1,131 @@
+/**
+ * Fuzz target for DiskANN RobustPrune algorithm (diskann_prune_select).
+ *
+ * diskann_prune_select is exposed for testing and takes:
+ * - inter_distances: flattened NxN matrix of inter-candidate distances
+ * - p_distances: N distances from node p to each candidate
+ * - num_candidates, alpha, max_neighbors
+ *
+ * This is a pure function that doesn't need a database, so we can
+ * call it directly with fuzz-controlled inputs. This gives the fuzzer
+ * maximum speed (no SQLite overhead) to explore:
+ *
+ * - alpha boundary: alpha=0 (prunes nothing), alpha=very large (prunes all)
+ * - max_neighbors = 0, 1, num_candidates, > num_candidates
+ * - num_candidates = 0, 1, large
+ * - Distance matrices with: all zeros, all same (cells are bytes / 10, so never negative, NaN or Inf)
+ * - Non-symmetric distance matrices (should still work)
+ * - Memory: large num_candidates to stress malloc
+ *
+ * Key code paths:
+ * - diskann_prune_select alpha-pruning loop
+ * - Boundary: selectedCount reaches max_neighbors exactly
+ * - All candidates pruned before max_neighbors reached
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include "sqlite-vec.h"
+#include "sqlite3.h"
+#include
+
+/* Declare the test-exposed function.
+ * diskann_prune_select is not static -- it's a public symbol. */
+extern int diskann_prune_select(
+ const float *inter_distances, const float *p_distances,
+ int num_candidates, float alpha, int max_neighbors,
+ int *outSelected, int *outCount);
+
+/* Pop the next fuzz byte and advance the cursor; yields `def` once empty. */
+static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) {
+  if (*size == 0) return def;
+  *size -= 1;
+  *data += 1;
+  return (*data)[-1];
+}
+
+/* libFuzzer entry point: feeds fuzz-derived distances to diskann_prune_select. */
+int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  /* Inputs shorter than 8 bytes are skipped. */
+  if (size < 8) {
+    return 0;
+  }
+
+  /* Header bytes, in order: candidate count (0..32), neighbor cap (0..16),
+   * then an index into a fixed table of alpha values. */
+  static const float kAlphaChoices[] = {0.0f, 0.5f, 1.0f, 1.2f,
+                                        1.5f, 2.0f, 10.0f, 100.0f};
+  const int n = fuzz_byte(&data, &size, 0) % 33;
+  const int max_neighbors = fuzz_byte(&data, &size, 0) % 17;
+  const float alpha = kAlphaChoices[fuzz_byte(&data, &size, 0) % 8];
+
+  if (n == 0) {
+    /* No candidates: NULL buffers are passed and an empty result expected. */
+    int outCount = -1;
+    int rc = diskann_prune_select(NULL, NULL, 0, alpha, max_neighbors,
+                                  NULL, &outCount);
+    assert(rc == 0 /* SQLITE_OK */);
+    assert(outCount == 0);
+    return 0;
+  }
+
+  /* Buffers: n*n pairwise distances, n distances from p, n selection flags. */
+  float *pairwise = malloc(n * n * sizeof(float));
+  float *from_p = malloc(n * sizeof(float));
+  int *selected = malloc(n * sizeof(int));
+  if (pairwise == NULL || from_p == NULL || selected == NULL) {
+    free(pairwise);
+    free(from_p);
+    free(selected);
+    return 0;
+  }
+
+  /* One fuzz byte per p-distance, divided by 10; once the input is used up
+   * fuzz_byte returns the fallback (uint8_t)(idx * 10) instead. */
+  for (int idx = 0; idx < n; idx++) {
+    from_p[idx] = (float)fuzz_byte(&data, &size, (uint8_t)(idx * 10)) / 10.0f;
+  }
+  /* Sort ascending by adjacent swaps: the harness hands prune_select its
+   * p-distances in sorted order. */
+  for (int hi = 1; hi < n; hi++) {
+    for (int lo = hi; lo > 0 && from_p[lo - 1] > from_p[lo]; lo--) {
+      float held = from_p[lo];
+      from_p[lo] = from_p[lo - 1];
+      from_p[lo - 1] = held;
+    }
+  }
+
+  /* Row-major matrix fill, one fuzz byte per cell (fallback: cell % 256).
+   * Diagonal cells still consume their byte but are stored as 0. */
+  for (int row = 0; row < n; row++) {
+    for (int col = 0; col < n; col++) {
+      int cell = row * n + col;
+      uint8_t raw = fuzz_byte(&data, &size, (uint8_t)(cell % 256));
+      pairwise[cell] = (row == col) ? 0.0f : (float)raw / 10.0f;
+    }
+  }
+
+  int outCount = -1;
+  int rc = diskann_prune_select(pairwise, from_p, n, alpha, max_neighbors,
+                                selected, &outCount);
+  /* The call must succeed and report a count within bounds. */
+  assert(rc == 0);
+  assert(outCount >= 0);
+  assert(outCount <= max_neighbors || max_neighbors == 0);
+  assert(outCount <= n);
+
+  /* Non-zero entries of the selection array must add up to outCount. */
+  int flagCount = 0;
+  for (int idx = 0; idx < n; idx++) {
+    flagCount += (selected[idx] != 0);
+  }
+  assert(flagCount == outCount);
+
+  free(pairwise);
+  free(from_p);
+  free(selected);
+  return 0;
+}
diff --git a/tests/fuzz/diskann.dict b/tests/fuzz/diskann.dict
new file mode 100644
index 0000000..31d289d
--- /dev/null
+++ b/tests/fuzz/diskann.dict
@@ -0,0 +1,10 @@
+"neighbor_quantizer"
+"binary"
+"int8"
+"n_neighbors"
+"search_list_size"
+"search_list_size_search"
+"search_list_size_insert"
+"alpha"
+"="
+","
diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h
index 67f1370..313add4 100644
--- a/tests/sqlite-vec-internal.h
+++ b/tests/sqlite-vec-internal.h
@@ -73,6 +73,7 @@ enum Vec0IndexType {
VEC0_INDEX_TYPE_RESCORE = 2,
#endif
VEC0_INDEX_TYPE_IVF = 3,
+ VEC0_INDEX_TYPE_DISKANN = 4,
};
enum Vec0RescoreQuantizerType {
@@ -114,6 +115,20 @@ struct Vec0RescoreConfig {
};
#endif
+enum Vec0DiskannQuantizerType { // quantizer kind stored in Vec0DiskannConfig.quantizer_type
+ VEC0_DISKANN_QUANTIZER_BINARY = 1, // presumably what neighbor_quantizer=binary maps to -- TODO confirm in parser
+ VEC0_DISKANN_QUANTIZER_INT8 = 2, // presumably what neighbor_quantizer=int8 maps to -- TODO confirm in parser
+};
+
+struct Vec0DiskannConfig { // DiskANN settings; embedded as `diskann` in VectorColumnDefinition
+ enum Vec0DiskannQuantizerType quantizer_type; // one of the VEC0_DISKANN_QUANTIZER_* values
+ int n_neighbors; // presumably the n_neighbors= option -- TODO confirm
+ int search_list_size; // presumably the unified search_list_size= option -- TODO confirm
+ int search_list_size_search; // presumably the search_list_size_search= option -- TODO confirm
+ int search_list_size_insert; // presumably the search_list_size_insert= option -- TODO confirm
+ float alpha; // presumably the alpha= option; diskann_prune_select also takes an alpha -- TODO confirm
+ int buffer_threshold; // NOTE(review): meaning not visible in this header -- document where it is consumed
+};
struct VectorColumnDefinition {
char *name;
@@ -126,6 +141,7 @@ struct VectorColumnDefinition {
struct Vec0RescoreConfig rescore;
#endif
struct Vec0IvfConfig ivf;
+ struct Vec0DiskannConfig diskann;
};
int vec0_parse_vector_column(const char *source, int source_length,
@@ -136,6 +152,48 @@ int vec0_parse_partition_key_definition(const char *source, int source_length,
int *out_column_name_length,
int *out_column_type);
+size_t diskann_quantized_vector_byte_size(
+ enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
+
+int diskann_validity_byte_size(int n_neighbors);
+size_t diskann_neighbor_ids_byte_size(int n_neighbors);
+size_t diskann_neighbor_qvecs_byte_size(
+ int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
+ size_t dimensions);
+int diskann_node_init(
+ int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type,
+ size_t dimensions,
+ unsigned char **outValidity, int *outValiditySize,
+ unsigned char **outNeighborIds, int *outNeighborIdsSize,
+ unsigned char **outNeighborQvecs, int *outNeighborQvecsSize);
+int diskann_validity_get(const unsigned char *validity, int i);
+void diskann_validity_set(unsigned char *validity, int i, int value);
+int diskann_validity_count(const unsigned char *validity, int n_neighbors);
+long long diskann_neighbor_id_get(const unsigned char *neighbor_ids, int i);
+void diskann_neighbor_id_set(unsigned char *neighbor_ids, int i, long long rowid);
+const unsigned char *diskann_neighbor_qvec_get(
+ const unsigned char *qvecs, int i,
+ enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
+void diskann_neighbor_qvec_set(
+ unsigned char *qvecs, int i, const unsigned char *src_qvec,
+ enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
+void diskann_node_set_neighbor(
+ unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i,
+ long long neighbor_rowid, const unsigned char *neighbor_qvec,
+ enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
+void diskann_node_clear_neighbor(
+ unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i,
+ enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions);
+int diskann_quantize_vector(
+ const float *src, size_t dimensions,
+ enum Vec0DiskannQuantizerType quantizer_type,
+ unsigned char *out);
+
+int diskann_prune_select(
+ const float *inter_distances, const float *p_distances,
+ int num_candidates, float alpha, int max_neighbors,
+ int *outSelected, int *outCount);
+
#ifdef SQLITE_VEC_TEST
float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims);
float _test_distance_cosine_float(const float *a, const float *b, size_t dims);
@@ -151,6 +209,33 @@ size_t _test_rescore_quantized_byte_size_int8(size_t dimensions);
void ivf_quantize_int8(const float *src, int8_t *dst, int D);
void ivf_quantize_binary(const float *src, uint8_t *dst, int D);
#endif
+// DiskANN candidate list. Treat as opaque: use the _test_diskann_candidate_list_* accessors.
+struct DiskannCandidateList {
+ void *items; // opaque storage; element layout is not declared in this header
+ int count; // presumably the number of items held (see _test_diskann_candidate_list_count) -- TODO confirm
+ int capacity; // presumably the capacity given to _test_diskann_candidate_list_init -- TODO confirm
+};
+
+int _test_diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity);
+void _test_diskann_candidate_list_free(struct DiskannCandidateList *list);
+int _test_diskann_candidate_list_insert(struct DiskannCandidateList *list, long long rowid, float distance);
+int _test_diskann_candidate_list_next_unvisited(const struct DiskannCandidateList *list);
+int _test_diskann_candidate_list_count(const struct DiskannCandidateList *list);
+long long _test_diskann_candidate_list_rowid(const struct DiskannCandidateList *list, int i);
+float _test_diskann_candidate_list_distance(const struct DiskannCandidateList *list, int i);
+void _test_diskann_candidate_list_set_visited(struct DiskannCandidateList *list, int i);
+
+// DiskANN visited set. Treat as opaque: use the _test_diskann_visited_set_* accessors.
+struct DiskannVisitedSet {
+ void *slots; // opaque storage; slot layout is not declared in this header
+ int capacity; // presumably the capacity given to _test_diskann_visited_set_init -- TODO confirm
+ int count; // presumably the number of rowids inserted so far -- TODO confirm
+};
+
+int _test_diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity);
+void _test_diskann_visited_set_free(struct DiskannVisitedSet *set);
+int _test_diskann_visited_set_contains(const struct DiskannVisitedSet *set, long long rowid);
+int _test_diskann_visited_set_insert(struct DiskannVisitedSet *set, long long rowid);
#endif
#endif /* SQLITE_VEC_INTERNAL_H */
diff --git a/tests/test-diskann.py b/tests/test-diskann.py
new file mode 100644
index 0000000..4c049ce
--- /dev/null
+++ b/tests/test-diskann.py
@@ -0,0 +1,1160 @@
+import sqlite3
+import struct
+import pytest
+from helpers import _f32, exec
+
+
+def test_diskann_create_basic(db):
+ """Basic DiskANN table creation with binary quantizer should succeed."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ # Table should exist
+ tables = [
+ row[0]
+ for row in db.execute(
+ "select name from sqlite_master where name like 't%' order by 1"
+ ).fetchall()
+ ]
+ assert "t" in tables
+
+
+def test_diskann_create_int8_quantizer(db):
+ """DiskANN with int8 quantizer should succeed."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[64] INDEXED BY diskann(neighbor_quantizer=int8)
+ )
+ """)
+ tables = [
+ row[0]
+ for row in db.execute(
+ "select name from sqlite_master where name like 't%' order by 1"
+ ).fetchall()
+ ]
+ assert "t" in tables
+
+
+def test_diskann_create_with_options(db):
+ """DiskANN with custom n_neighbors and search_list_size."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann(
+ neighbor_quantizer=binary,
+ n_neighbors=48,
+ search_list_size=256
+ )
+ )
+ """)
+ tables = [
+ row[0]
+ for row in db.execute(
+ "select name from sqlite_master where name like 't%' order by 1"
+ ).fetchall()
+ ]
+ assert "t" in tables
+
+
+def test_diskann_create_with_distance_metric(db):
+ """DiskANN combined with distance_metric should work."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] distance_metric=cosine INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ tables = [
+ row[0]
+ for row in db.execute(
+ "select name from sqlite_master where name like 't%' order by 1"
+ ).fetchall()
+ ]
+ assert "t" in tables
+
+
+def test_diskann_create_error_missing_quantizer(db):
+ """Error when neighbor_quantizer is not specified."""
+ result = exec(db, """
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann(n_neighbors=72)
+ )
+ """)
+ assert "error" in result
+
+
+def test_diskann_create_error_empty_parens(db):
+ """Error on empty parens."""
+ result = exec(db, """
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann()
+ )
+ """)
+ assert "error" in result
+
+
+def test_diskann_create_error_unknown_quantizer(db):
+ """Error on unknown quantizer type."""
+ result = exec(db, """
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann(neighbor_quantizer=unknown)
+ )
+ """)
+ assert "error" in result
+
+
+def test_diskann_create_error_bit_column(db):
+ """Error: DiskANN not supported on bit vector columns."""
+ result = exec(db, """
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb bit[128] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ assert "error" in result
+ assert "bit" in result["message"].lower() or "DiskANN" in result["message"]
+
+
+def test_diskann_create_error_binary_quantizer_odd_dims(db):
+ """Error: binary quantizer requires dimensions divisible by 8."""
+ result = exec(db, """
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[13] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ assert "error" in result
+ assert "divisible" in result["message"].lower()
+
+
+def test_diskann_create_error_bad_n_neighbors(db):
+ """Error: n_neighbors must be divisible by 8."""
+ result = exec(db, """
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=13)
+ )
+ """)
+ assert "error" in result
+
+
+def test_diskann_shadow_tables_created(db):
+ """DiskANN table should create _vectors00 and _diskann_nodes00 shadow tables."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ tables = sorted([
+ row[0]
+ for row in db.execute(
+ "select name from sqlite_master where type='table' and name like 't_%' order by 1"
+ ).fetchall()
+ ])
+ assert "t_vectors00" in tables
+ assert "t_diskann_nodes00" in tables
+ # DiskANN columns should NOT have _vector_chunks00
+ assert "t_vector_chunks00" not in tables
+
+
+def test_diskann_medoid_in_info(db):
+ """_info table should contain diskann_medoid_00 key with NULL value."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ row = db.execute(
+ "SELECT key, value FROM t_info WHERE key = 'diskann_medoid_00'"
+ ).fetchone()
+ assert row is not None
+ assert row[0] == "diskann_medoid_00"
+ assert row[1] is None
+
+
+def test_non_diskann_no_extra_tables(db):
+ """Non-DiskANN table must NOT create _vectors or _diskann_nodes tables."""
+ db.execute("CREATE VIRTUAL TABLE t USING vec0(emb float[64])")
+ tables = [
+ row[0]
+ for row in db.execute(
+ "select name from sqlite_master where type='table' and name like 't_%' order by 1"
+ ).fetchall()
+ ]
+ assert "t_vectors00" not in tables
+ assert "t_diskann_nodes00" not in tables
+ assert "t_vector_chunks00" in tables
+
+
+def test_diskann_medoid_initial_null(db):
+ """Medoid should be NULL initially (empty graph)."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ row = db.execute(
+ "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'"
+ ).fetchone()
+ assert row[0] is None
+
+
+def test_diskann_medoid_set_via_info(db):
+ """Setting medoid via _info table should be retrievable."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ # Manually set medoid to simulate first insert
+ db.execute("UPDATE t_info SET value = 42 WHERE key = 'diskann_medoid_00'")
+ row = db.execute(
+ "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'"
+ ).fetchone()
+ assert row[0] == 42
+
+ # Reset to NULL (empty graph)
+ db.execute("UPDATE t_info SET value = NULL WHERE key = 'diskann_medoid_00'")
+ row = db.execute(
+ "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'"
+ ).fetchone()
+ assert row[0] is None
+
+
+def test_diskann_single_insert(db):
+ """Insert 1 vector. Verify _vectors00, _diskann_nodes00, and medoid."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ db.execute(
+ "INSERT INTO t(rowid, emb) VALUES (1, ?)",
+ [_f32([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
+ )
+ # Verify _vectors00 has 1 row
+ count = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0]
+ assert count == 1
+
+ # Verify _diskann_nodes00 has 1 row
+ count = db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0]
+ assert count == 1
+
+ # Verify medoid is set
+ medoid = db.execute(
+ "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'"
+ ).fetchone()[0]
+ assert medoid == 1
+
+
+def test_diskann_multiple_inserts(db):
+ """Insert multiple vectors. Verify counts and that nodes have neighbors."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ import random
+ random.seed(42)
+ for i in range(1, 21):
+ vec = [random.gauss(0, 1) for _ in range(8)]
+ db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)])
+
+ # Verify counts
+ assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 20
+ assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 20
+
+ # Every node after the first should have at least 1 neighbor
+ rows = db.execute(
+ "SELECT rowid, neighbors_validity FROM t_diskann_nodes00"
+ ).fetchall()
+ nodes_with_neighbors = 0
+ for row in rows:
+ validity = row[1]
+ has_neighbor = any(b != 0 for b in validity)
+ if has_neighbor:
+ nodes_with_neighbors += 1
+ # At minimum, nodes 2-20 should have neighbors (node 1 gets neighbors via reverse edges)
+ assert nodes_with_neighbors >= 19
+
+
+def test_diskann_bidirectional_edges(db):
+ """Insert A then B. B should be in A's neighbors and A in B's."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ db.execute(
+ "INSERT INTO t(rowid, emb) VALUES (1, ?)",
+ [_f32([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
+ )
+ db.execute(
+ "INSERT INTO t(rowid, emb) VALUES (2, ?)",
+ [_f32([0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])],
+ )
+
+ # Check B(2) is in A(1)'s neighbor list
+ row_a = db.execute(
+ "SELECT neighbor_ids FROM t_diskann_nodes00 WHERE rowid = 1"
+ ).fetchone()
+ neighbor_ids_a = struct.unpack(f"{len(row_a[0])//8}q", row_a[0])
+ assert 2 in neighbor_ids_a
+
+ # Check A(1) is in B(2)'s neighbor list
+ row_b = db.execute(
+ "SELECT neighbor_ids FROM t_diskann_nodes00 WHERE rowid = 2"
+ ).fetchone()
+ neighbor_ids_b = struct.unpack(f"{len(row_b[0])//8}q", row_b[0])
+ assert 1 in neighbor_ids_b
+
+
+def test_diskann_delete_single(db):
+ """Insert 3, delete 1. Verify counts."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ for i in range(1, 4):
+ db.execute(
+ "INSERT INTO t(rowid, emb) VALUES (?, ?)",
+ [i, _f32([float(i)] * 8)],
+ )
+ db.execute("DELETE FROM t WHERE rowid = 2")
+
+ assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 2
+ assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 2
+
+
+def test_diskann_delete_no_stale_references(db):
+ """After delete, no node should reference the deleted rowid."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ import random
+ random.seed(123)
+ for i in range(1, 11):
+ vec = [random.gauss(0, 1) for _ in range(8)]
+ db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)])
+
+ db.execute("DELETE FROM t WHERE rowid = 5")
+
+ # Scan all remaining nodes and verify rowid 5 is not in any neighbor list
+ rows = db.execute(
+ "SELECT rowid, neighbors_validity, neighbor_ids FROM t_diskann_nodes00"
+ ).fetchall()
+ for row in rows:
+ validity = row[1]
+ neighbor_ids_blob = row[2]
+ n_neighbors = len(validity) * 8
+ ids = struct.unpack(f"{n_neighbors}q", neighbor_ids_blob)
+ for i in range(n_neighbors):
+ byte_idx = i // 8
+ bit_idx = i % 8
+ if validity[byte_idx] & (1 << bit_idx):
+ assert ids[i] != 5, f"Node {row[0]} still references deleted rowid 5"
+
+
+def test_diskann_delete_medoid(db):
+ """Delete the medoid. Verify a new non-NULL medoid is selected."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ for i in range(1, 4):
+ db.execute(
+ "INSERT INTO t(rowid, emb) VALUES (?, ?)",
+ [i, _f32([float(i)] * 8)],
+ )
+
+ medoid_before = db.execute(
+ "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'"
+ ).fetchone()[0]
+ assert medoid_before == 1
+
+ db.execute("DELETE FROM t WHERE rowid = 1")
+
+ medoid_after = db.execute(
+ "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'"
+ ).fetchone()[0]
+ assert medoid_after is not None
+ assert medoid_after != 1
+
+
+def test_diskann_delete_all(db):
+ """Delete all vectors. Medoid should be NULL."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ for i in range(1, 4):
+ db.execute(
+ "INSERT INTO t(rowid, emb) VALUES (?, ?)",
+ [i, _f32([float(i)] * 8)],
+ )
+ for i in range(1, 4):
+ db.execute("DELETE FROM t WHERE rowid = ?", [i])
+
+ assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 0
+ assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 0
+
+ medoid = db.execute(
+ "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'"
+ ).fetchone()[0]
+ assert medoid is None
+
+
+def test_diskann_insert_delete_insert_cycle(db):
+ """Insert, delete, insert again. No crashes."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1.0] * 8)])
+ db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([2.0] * 8)])
+ db.execute("DELETE FROM t WHERE rowid = 1")
+ db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([3.0] * 8)])
+
+ assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 2
+ assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 2
+
+
+def test_diskann_knn_basic(db):
+ """Basic KNN query should return results."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 0, 0, 0, 0, 0, 0, 0])])
+ db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])])
+ db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0.9, 0.1, 0, 0, 0, 0, 0, 0])])
+
+ rows = db.execute(
+ "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=2",
+ [_f32([1, 0, 0, 0, 0, 0, 0, 0])],
+ ).fetchall()
+ assert len(rows) == 2
+ # Closest should be rowid 1 (exact match)
+ assert rows[0][0] == 1
+ assert rows[0][1] < 0.01 # ~0 distance
+
+
+def test_diskann_knn_distances_sorted(db):
+ """Returned distances should be in ascending order."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16)
+ )
+ """)
+ import random
+ random.seed(42)
+ for i in range(1, 51):
+ vec = [random.gauss(0, 1) for _ in range(8)]
+ db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)])
+
+ rows = db.execute(
+ "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=10",
+ [_f32([0.0] * 8)],
+ ).fetchall()
+ assert len(rows) == 10
+ distances = [r[1] for r in rows]
+ for i in range(len(distances) - 1):
+ assert distances[i] <= distances[i + 1], f"Distances not sorted at index {i}"
+
+
+def test_diskann_knn_empty_table(db):
+ """KNN on empty table should return 0 results."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ rows = db.execute(
+ "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=5",
+ [_f32([1, 0, 0, 0, 0, 0, 0, 0])],
+ ).fetchall()
+ assert len(rows) == 0
+
+
+def test_diskann_knn_after_delete(db):
+ """KNN after delete should not return deleted rows."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+ )
+ """)
+ db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 0, 0, 0, 0, 0, 0, 0])])
+ db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])])
+ db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0.5, 0.5, 0, 0, 0, 0, 0, 0])])
+ db.execute("DELETE FROM t WHERE rowid = 1")
+
+ rows = db.execute(
+ "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=3",
+ [_f32([1, 0, 0, 0, 0, 0, 0, 0])],
+ ).fetchall()
+ rowids = [r[0] for r in rows]
+ assert 1 not in rowids
+ assert len(rows) == 2
+
+
+def test_diskann_no_index_still_works(db):
+ """Tables without INDEXED BY should still work identically."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[4]
+ )
+ """)
+ db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 2, 3, 4])])
+ rows = db.execute(
+ "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=1",
+ [_f32([1, 2, 3, 4])],
+ ).fetchall()
+ assert len(rows) == 1
+ assert rows[0][0] == 1
+
+
+def test_diskann_drop_table(db):
+ """DROP TABLE should clean up all shadow tables."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ db.execute("DROP TABLE t")
+ tables = [
+ row[0]
+ for row in db.execute(
+ "select name from sqlite_master where name like 't%'"
+ ).fetchall()
+ ]
+ assert len(tables) == 0
+
+
+def test_diskann_create_split_search_list_size(db):
+ """DiskANN with separate search_list_size_search and search_list_size_insert."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann(
+ neighbor_quantizer=binary,
+ search_list_size_search=256,
+ search_list_size_insert=64
+ )
+ )
+ """)
+ tables = [
+ row[0]
+ for row in db.execute(
+ "select name from sqlite_master where name like 't%' order by 1"
+ ).fetchall()
+ ]
+ assert "t" in tables
+
+
+def test_diskann_create_error_mixed_search_list_size(db):
+ """Error when mixing search_list_size with search_list_size_search."""
+ result = exec(db, """
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[128] INDEXED BY diskann(
+ neighbor_quantizer=binary,
+ search_list_size=128,
+ search_list_size_search=256
+ )
+ )
+ """)
+ assert "error" in result
+
+
+def test_diskann_command_search_list_size(db):
+ """Runtime search_list_size override via command insert."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ import struct, random
+ random.seed(42)
+ for i in range(20):
+ vec = struct.pack("64f", *[random.random() for _ in range(64)])
+ db.execute("INSERT INTO t(emb) VALUES (?)", [vec])
+
+ # Query with default search_list_size
+ query = struct.pack("64f", *[random.random() for _ in range(64)])
+ results_before = db.execute(
+ "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query]
+ ).fetchall()
+ assert len(results_before) == 5
+
+ # Override search_list_size_search at runtime
+ db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_search=256')")
+
+ # Query should still work
+ results_after = db.execute(
+ "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query]
+ ).fetchall()
+ assert len(results_after) == 5
+
+ # Override search_list_size_insert at runtime
+ db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_insert=32')")
+
+ # Inserts should still work
+ vec = struct.pack("64f", *[random.random() for _ in range(64)])
+ db.execute("INSERT INTO t(emb) VALUES (?)", [vec])
+
+ # Override unified search_list_size
+ db.execute("INSERT INTO t(rowid) VALUES ('search_list_size=64')")
+
+ results_final = db.execute(
+ "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query]
+ ).fetchall()
+ assert len(results_final) == 5
+
+
+def test_diskann_command_search_list_size_error(db):
+ """Error on invalid search_list_size command value."""
+ db.execute("""
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[64] INDEXED BY diskann(neighbor_quantizer=binary)
+ )
+ """)
+ result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=0')")
+ assert "error" in result
+ result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=-1')")
+ assert "error" in result
+
+
+# ======================================================================
+# Error cases: DiskANN + auxiliary/metadata/partition columns
+# ======================================================================
+
+def test_diskann_create_error_with_auxiliary_column(db):
+ """DiskANN tables should not support auxiliary columns."""
+ result = exec(db, """
+ CREATE VIRTUAL TABLE t USING vec0(
+ emb float[64] INDEXED BY diskann(neighbor_quantizer=binary),
+ +extra text
+ )
+ """)
+ assert "error" in result
+ assert "auxiliary" in result["message"].lower() or "Auxiliary" in result["message"]
+
+
+def test_diskann_create_error_with_metadata_column(db):
+    """CREATE with a metadata column next to a DiskANN index must fail; the message is matched case-insensitively."""
+    result = exec(db, """
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[64] INDEXED BY diskann(neighbor_quantizer=binary),
+            metadata_col integer metadata
+        )
+    """)
+    assert "error" in result
+    assert "metadata" in result["message"].lower()
+
+
+def test_diskann_create_error_with_partition_key(db):
+    """CREATE with a partition key column next to a DiskANN index must fail; the message is matched case-insensitively."""
+    result = exec(db, """
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[64] INDEXED BY diskann(neighbor_quantizer=binary),
+            user_id text partition key
+        )
+    """)
+    assert "error" in result
+    assert "partition" in result["message"].lower()
+
+
+# ======================================================================
+# Insert edge cases
+# ======================================================================
+
+def test_diskann_insert_no_rowid(db):
+    """Two INSERTs that omit rowid must each add a row to t_vectors00 and t_diskann_nodes00."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary)
+        )
+    """)
+    for fill in (1.0, 2.0):
+        db.execute("INSERT INTO t(emb) VALUES (?)", [_f32([fill] * 8)])
+    for table in ("t_vectors00", "t_diskann_nodes00"):
+        assert db.execute(f"SELECT count(*) FROM {table}").fetchone()[0] == 2
+
+
+def test_diskann_insert_large_batch(db):
+    """Insert 500 random 16-d vectors, then check both table counts and a sorted k=10 KNN result."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16)
+        )
+    """)
+    import random
+    random.seed(99)
+    total = 500
+    for rowid in range(1, total + 1):
+        components = [random.gauss(0, 1) for _ in range(16)]
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    # Both tables must hold one entry per inserted row.
+    for table in ("t_vectors00", "t_diskann_nodes00"):
+        assert db.execute(f"SELECT count(*) FROM {table}").fetchone()[0] == total
+
+    # The probe is drawn after the inserts, from the same seeded stream.
+    probe = _f32([random.gauss(0, 1) for _ in range(16)])
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=10",
+        [probe],
+    ).fetchall()
+    assert len(hits) == 10
+    # Results must arrive in non-decreasing distance order.
+    dists = [hit[1] for hit in hits]
+    assert all(near <= far for near, far in zip(dists, dists[1:]))
+
+
+def test_diskann_insert_zero_vector(db):
+    """An all-zero vector can be stored and is returned first when queried with itself."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary)
+        )
+    """)
+    zeros = _f32([0.0] * 8)
+    db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [zeros])
+    db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([1.0] * 8)])
+    assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 2
+
+    # Searching with the zero vector must rank rowid 1 (the zero vector) first.
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=2",
+        [zeros],
+    ).fetchall()
+    assert len(hits) == 2
+    assert hits[0][0] == 1
+
+
+def test_diskann_insert_large_values(db):
+    """Vectors filled with a large positive value, its negation, and zero can all be inserted."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary)
+        )
+    """)
+    import sys
+    magnitude = sys.float_info.max / 1e300  # roughly 1.8e8: large, but nowhere near overflow
+    for rowid, fill in enumerate((magnitude, -magnitude, 0.0), start=1):
+        db.execute(f"INSERT INTO t(rowid, emb) VALUES ({rowid}, ?)", [_f32([fill] * 8)])
+    stored = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0]
+    assert stored == 3
+
+
+def test_diskann_insert_int8_quantizer_knn(db):
+    """Insert 30 random vectors into an int8-quantized index and run a k=5 KNN query."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[16] INDEXED BY diskann(neighbor_quantizer=int8, n_neighbors=8)
+        )
+    """)
+    import random
+    random.seed(77)
+    for rowid in range(1, 31):
+        components = [random.gauss(0, 1) for _ in range(16)]
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 30
+
+    # Query with a fresh random probe; expect five hits, nearest first.
+    probe = _f32([random.gauss(0, 1) for _ in range(16)])
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=5",
+        [probe],
+    ).fetchall()
+    assert len(hits) == 5
+    dists = [hit[1] for hit in hits]
+    ordered = all(near <= far for near, far in zip(dists, dists[1:]))
+    assert ordered
+
+
+# ======================================================================
+# Delete edge cases
+# ======================================================================
+
+def test_diskann_delete_nonexistent(db):
+    """DELETE of a nonexistent rowid may be a no-op or an error; either way the existing row survives."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1.0] * 8)])
+    # The outcome is deliberately ignored: the DELETE may succeed or report an error.
+    exec(db, "DELETE FROM t WHERE rowid = 999")
+    # In both cases the one row inserted above must still be stored.
+    assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 1
+
+
+def test_diskann_delete_then_reinsert_same_rowid(db):
+    """A deleted rowid (5) can be reused by a later INSERT carrying a different vector."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    for rowid in range(1, 6):
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32([float(rowid)] * 8)])
+
+    count_vectors = "SELECT count(*) FROM t_vectors00"
+    db.execute("DELETE FROM t WHERE rowid = 5")
+    assert db.execute(count_vectors).fetchone()[0] == 4
+    # Bring rowid 5 back with a new vector; both tables must be at five rows again.
+    db.execute("INSERT INTO t(rowid, emb) VALUES (5, ?)", [_f32([99.0] * 8)])
+    assert db.execute(count_vectors).fetchone()[0] == 5
+    assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 5
+
+
+def test_diskann_delete_all_then_insert(db):
+    """Emptying the table clears the stored medoid; later inserts set one again and KNN works."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    medoid_sql = "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'"
+    for rowid in range(1, 6):
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32([float(rowid)] * 8)])
+
+    # Remove every row, one statement per rowid.
+    for rowid in range(1, 6):
+        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])
+    assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 0
+    # With no rows left, the medoid value in t_info is expected to be NULL.
+    assert db.execute(medoid_sql).fetchone()[0] is None
+
+    # Populate again using rowids 10..14, none of which were used before.
+    for rowid in range(10, 15):
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32([float(rowid)] * 8)])
+
+    for table in ("t_vectors00", "t_diskann_nodes00"):
+        assert db.execute(f"SELECT count(*) FROM {table}").fetchone()[0] == 5
+
+    # A medoid value must be recorded again.
+    assert db.execute(medoid_sql).fetchone()[0] is not None
+
+    # KNN against the repopulated table returns the three requested rows.
+    hits = db.execute(
+        "SELECT rowid FROM t WHERE emb MATCH ? AND k=3",
+        [_f32([12.0] * 8)],
+    ).fetchall()
+    assert len(hits) == 3
+
+
+def test_diskann_delete_preserves_graph_connectivity(db):
+    """After deleting 5 of 20 rows, at least 80% of the survivors still show up in some KNN result."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    import random
+    random.seed(456)
+    for rowid in range(1, 21):
+        components = [random.gauss(0, 1) for _ in range(8)]
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    # Remove five rows spread across the rowid range.
+    for victim in (3, 7, 11, 15, 19):
+        db.execute("DELETE FROM t WHERE rowid = ?", [victim])
+
+    left = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0]
+    assert left == 15
+
+    # Probe with every surviving vector and pool the rowids of all k=5 results.
+    survivors = [row[0] for row in db.execute("SELECT rowid FROM t_vectors00").fetchall()]
+    seen = set()
+    for rowid in survivors:
+        stored = db.execute("SELECT vector FROM t_vectors00 WHERE rowid = ?", [rowid]).fetchone()[0]
+        hits = db.execute(
+            "SELECT rowid FROM t WHERE emb MATCH ? AND k=5",
+            [stored],
+        ).fetchall()
+        assert len(hits) >= 1  # every probe must return something
+        seen.update(hit[0] for hit in hits)
+
+    # The threshold is 80% coverage of the survivors, not all of them.
+    assert len(seen) >= len(survivors) * 0.8, \
+        f"Only {len(seen)}/{len(survivors)} nodes reachable"
+
+
+# ======================================================================
+# Update scenarios
+# ======================================================================
+
+def test_diskann_update_vector(db):
+    """UPDATE of a vector either reports an error or succeeds with KNN still ranking the exact match first."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    basis = ([1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0])
+    for rowid, axis in enumerate(basis, start=1):
+        db.execute(f"INSERT INTO t(rowid, emb) VALUES ({rowid}, ?)", [_f32(axis)])
+
+    # UPDATE may not be supported for DiskANN yet; an error outcome is accepted as-is.
+    outcome = exec(db, "UPDATE t SET emb = ? WHERE rowid = 1", [_f32([0, 0.9, 0.1, 0, 0, 0, 0, 0])])
+    if "error" in outcome:
+        return
+    # The UPDATE went through: all three rows must still be returned by KNN.
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=3",
+        [_f32([0, 1, 0, 0, 0, 0, 0, 0])],
+    ).fetchall()
+    assert len(hits) == 3
+    assert hits[0][0] == 2  # rowid 2 equals the probe exactly, so it stays closest
+
+
+# ======================================================================
+# KNN correctness after mutations
+# ======================================================================
+
+def test_diskann_knn_recall_after_inserts(db):
+    """Self-query top-1 recall over 50 random vectors must reach at least 80%."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16)
+        )
+    """)
+    import random
+    random.seed(200)
+    vectors = {}
+    for rowid in range(1, 51):
+        components = [random.gauss(0, 1) for _ in range(8)]
+        vectors[rowid] = components
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    # Query with each stored vector and count how often it comes back as its own top hit.
+    correct = 0
+    for rowid, components in vectors.items():
+        top = db.execute(
+            "SELECT rowid FROM t WHERE emb MATCH ? AND k=1",
+            [_f32(components)],
+        ).fetchall()
+        correct += bool(top and top[0][0] == rowid)
+
+    # The binary quantizer makes the search approximate, so an exact self-match
+    # is not guaranteed for every vector; 80% is the accepted floor.
+    required = len(vectors) * 0.8
+    assert correct >= required, f"Top-1 recall too low: {correct}/{len(vectors)}"
+
+
+def test_diskann_knn_k_larger_than_table(db):
+    """Asking for k=100 from a 5-row table yields exactly the 5 rows that exist."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    for rowid in range(1, 6):
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32([float(rowid)] * 8)])
+
+    # k exceeds the row count, so the result size is bounded by the table.
+    probe = _f32([3.0] * 8)
+    hits = db.execute("SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=100", [probe]).fetchall()
+    returned = len(hits)
+    assert returned == 5
+
+
+def test_diskann_knn_cosine_metric(db):
+    """KNN on a distance_metric=cosine column ranks the identical direction first."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] distance_metric=cosine INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    # Two axis-aligned vectors plus one on the diagonal between them.
+    x_axis = [1, 0, 0, 0, 0, 0, 0, 0]
+    db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32(x_axis)])
+    db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])])
+    db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0.7, 0.7, 0, 0, 0, 0, 0, 0])])
+
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=3",
+        [_f32(x_axis)],
+    ).fetchall()
+    assert len(hits) == 3
+    # The probe equals row 1, so row 1 must come first.
+    assert hits[0][0] == 1
+    # The remaining rows must follow in non-decreasing distance order.
+    dists = [hit[1] for hit in hits]
+    assert all(near <= far for near, far in zip(dists, dists[1:]))
+
+
+def test_diskann_knn_after_heavy_churn(db):
+    """Insert 50 rows, delete the even rowids, insert 25 more, then run a k=10 query."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16)
+        )
+    """)
+    import random
+    random.seed(321)
+
+    # Phase 1: fill rowids 1..50.
+    for rowid in range(1, 51):
+        components = [random.gauss(0, 1) for _ in range(8)]
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    # Phase 2: drop every even rowid (25 rows).
+    for rowid in range(2, 51, 2):
+        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])
+
+    # Phase 3: add 25 fresh rows at rowids 51..75.
+    for rowid in range(51, 76):
+        components = [random.gauss(0, 1) for _ in range(8)]
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    # 25 surviving odd rowids plus the 25 new ones.
+    assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 50
+
+    # A k=10 query must still be answered in full.
+    probe = _f32([random.gauss(0, 1) for _ in range(8)])
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=10",
+        [probe],
+    ).fetchall()
+    assert len(hits) == 10
+    # Results must be ordered nearest first.
+    dists = [hit[1] for hit in hits]
+    in_order = all(near <= far for near, far in zip(dists, dists[1:]))
+    assert in_order
+
+
+def test_diskann_knn_batch_recall(db):
+    """At least 3 of the index's top-5 must match the brute-force top-5 over 150 random vectors."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16)
+        )
+    """)
+    import random
+    random.seed(42)
+    total = 150
+    vectors = {}
+    for rowid in range(1, total + 1):
+        components = [random.gauss(0, 1) for _ in range(16)]
+        vectors[rowid] = components
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    # The probe is drawn after all inserts, from the same seeded stream.
+    probe = [random.gauss(0, 1) for _ in range(16)]
+
+    # Ground truth: squared L2 distance to every stored vector, ties broken by rowid.
+    def squared_l2(components):
+        return sum((a - b) ** 2 for a, b in zip(probe, components))
+
+    ranked = sorted((squared_l2(components), rowid) for rowid, components in vectors.items())
+    expected = {rowid for _, rowid in ranked[:5]}
+
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=5",
+        [_f32(probe)],
+    ).fetchall()
+    assert len(hits) == 5
+    actual = {hit[0] for hit in hits}
+
+    # The search is approximate, so a full match is not required;
+    # an overlap of 3 out of 5 is the accepted floor.
+    overlap = len(expected & actual)
+    assert overlap >= 3, f"Recall too low: only {overlap}/5 overlap"
+
+
+# ======================================================================
+# Additional edge cases
+# ======================================================================
+
+def test_diskann_insert_wrong_dimensions(db):
+    """Inserting a 4-element vector into a float[8] column must report an error."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    too_short = _f32([1.0] * 4)
+    assert "error" in exec(db, "INSERT INTO t(rowid, emb) VALUES (1, ?)", [too_short])
+
+
+def test_diskann_knn_wrong_query_dimensions(db):
+    """A KNN MATCH whose query has 4 elements against a float[8] column must report an error."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8)
+        )
+    """)
+    db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1.0] * 8)])
+
+    mismatched = _f32([1.0] * 4)  # the column is declared with 8 dimensions
+    assert "error" in exec(db, "SELECT rowid FROM t WHERE emb MATCH ? AND k=1", [mismatched])
+
+
+def test_diskann_graph_connectivity_after_many_deletes(db):
+    """After deleting 30 of 40 rows, a KNN query still returns at least one row, nearest first."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16)
+        )
+    """)
+    import random
+    random.seed(789)
+    total = 40
+    for rowid in range(1, total + 1):
+        components = [random.gauss(0, 1) for _ in range(8)]
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    # Drop rowids 1..30, leaving only the last ten rows.
+    doomed = range(1, 31)
+    for rowid in doomed:
+        db.execute("DELETE FROM t WHERE rowid = ?", [rowid])
+
+    left = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0]
+    assert left == 10
+
+    # k equals the number of survivors, but the graph may be fragmented after
+    # this much deletion, so only "at least one result" is required.
+    probe = _f32([random.gauss(0, 1) for _ in range(8)])
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=10",
+        [probe],
+    ).fetchall()
+    assert len(hits) >= 1
+
+    # Whatever comes back must be ordered nearest first.
+    dists = [hit[1] for hit in hits]
+    assert all(near <= far for near, far in zip(dists, dists[1:]))
+
+
+def test_diskann_large_batch_insert_500(db):
+    """Insert 500 random 8-d vectors, then check the stored count and a sorted k=20 result."""
+    db.execute("""
+        CREATE VIRTUAL TABLE t USING vec0(
+            emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16)
+        )
+    """)
+    import random
+    random.seed(555)
+    total = 500
+    for rowid in range(1, total + 1):
+        components = [random.gauss(0, 1) for _ in range(8)]
+        db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [rowid, _f32(components)])
+
+    assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == total
+
+    probe = _f32([random.gauss(0, 1) for _ in range(8)])
+    hits = db.execute(
+        "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=20",
+        [probe],
+    ).fetchall()
+    assert len(hits) == 20
+    # Results must come back nearest first.
+    dists = [hit[1] for hit in hits]
+    assert all(near <= far for near, far in zip(dists, dists[1:]))
diff --git a/tests/test-unit.c b/tests/test-unit.c
index 27a469d..83cedd5 100644
--- a/tests/test-unit.c
+++ b/tests/test-unit.c
@@ -1187,6 +1187,7 @@ void test_ivf_quantize_binary() {
}
void test_ivf_config_parsing() {
+void test_vec0_parse_vector_column_diskann() {
printf("Starting %s...\n", __func__);
struct VectorColumnDefinition col;
int rc;
@@ -1199,6 +1200,34 @@ void test_ivf_config_parsing() {
assert(col.index_type == VEC0_INDEX_TYPE_RESCORE);
assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT);
assert(col.rescore.oversample == 8); // default
+ // Existing syntax (no INDEXED BY) must not select the DiskANN index type
+ {
+ const char *input = "emb float[128]";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
+ sqlite3_free(col.name);
+ }
+
+ // With distance_metric but no INDEXED BY
+ {
+ const char *input = "emb float[128] distance_metric=cosine";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type != VEC0_INDEX_TYPE_DISKANN);
+ assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
+ sqlite3_free(col.name);
+ }
+
+ // Basic binary quantizer
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
+ assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY);
+ assert(col.diskann.n_neighbors == 72); // default
+ assert(col.diskann.search_list_size == 128); // default
assert(col.dimensions == 128);
sqlite3_free(col.name);
}
@@ -1370,6 +1399,681 @@ void test_ivf_config_parsing() {
printf(" All ivf_config_parsing tests passed.\n");
}
#endif /* SQLITE_VEC_ENABLE_IVF */
+ // INT8 quantizer
+ {
+ const char *input = "v float[64] INDEXED BY diskann(neighbor_quantizer=int8)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
+ assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_INT8);
+ sqlite3_free(col.name);
+ }
+
+ // Custom n_neighbors
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=48)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
+ assert(col.diskann.n_neighbors == 48);
+ sqlite3_free(col.name);
+ }
+
+ // Custom search_list_size
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=256)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.diskann.search_list_size == 256);
+ sqlite3_free(col.name);
+ }
+
+ // Combined with distance_metric (distance_metric first)
+ {
+ const char *input = "emb float[128] distance_metric=cosine INDEXED BY diskann(neighbor_quantizer=int8)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE);
+ assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
+ assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_INT8);
+ sqlite3_free(col.name);
+ }
+
+ // Error: missing neighbor_quantizer (required)
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(n_neighbors=72)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Error: empty parens
+ {
+ const char *input = "emb float[128] INDEXED BY diskann()";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Error: unknown quantizer
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=unknown)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Error: bad n_neighbors (not divisible by 8)
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=13)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Error: n_neighbors too large
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=512)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Error: missing BY
+ {
+ const char *input = "emb float[128] INDEXED diskann(neighbor_quantizer=binary)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Error: unknown algorithm
+ {
+ const char *input = "emb float[128] INDEXED BY hnsw(neighbor_quantizer=binary)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Error: unknown option key
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, foobar=baz)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Case insensitivity for keywords
+ {
+ const char *input = "emb float[128] indexed by DISKANN(NEIGHBOR_QUANTIZER=BINARY)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.index_type == VEC0_INDEX_TYPE_DISKANN);
+ assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY);
+ sqlite3_free(col.name);
+ }
+
+ // Split search_list_size: search and insert
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size_search=256, search_list_size_insert=64)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.diskann.search_list_size == 128); // default (unified)
+ assert(col.diskann.search_list_size_search == 256);
+ assert(col.diskann.search_list_size_insert == 64);
+ sqlite3_free(col.name);
+ }
+
+ // Split search_list_size: only search
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size_search=200)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_OK);
+ assert(col.diskann.search_list_size_search == 200);
+ assert(col.diskann.search_list_size_insert == 0);
+ sqlite3_free(col.name);
+ }
+
+ // Error: cannot mix search_list_size with search_list_size_search
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=128, search_list_size_search=256)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ // Error: cannot mix search_list_size with search_list_size_insert
+ {
+ const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=128, search_list_size_insert=64)";
+ rc = vec0_parse_vector_column(input, (int)strlen(input), &col);
+ assert(rc == SQLITE_ERROR);
+ }
+
+ printf(" All vec0_parse_vector_column_diskann tests passed.\n");
+}
+
+void test_diskann_validity_bitmap() {
+  printf("Starting %s...\n", __func__);
+
+  // Three bytes hold the 24 slots exercised below; start fully cleared.
+  unsigned char bitmap[3] = {0};
+
+  // A cleared bitmap reports every slot as invalid.
+  for (int slot = 0; slot < 24; ++slot) {
+    assert(diskann_validity_get(bitmap, slot) == 0);
+  }
+  assert(diskann_validity_count(bitmap, 24) == 0);
+
+  // Slot 0: the very first position.
+  diskann_validity_set(bitmap, 0, 1);
+  assert(diskann_validity_get(bitmap, 0) == 1);
+  assert(diskann_validity_count(bitmap, 24) == 1);
+
+  // Slot 7: final position of byte 0.
+  diskann_validity_set(bitmap, 7, 1);
+  assert(diskann_validity_get(bitmap, 7) == 1);
+  assert(diskann_validity_count(bitmap, 24) == 2);
+
+  // Slot 8: first position of byte 1, i.e. across the byte boundary.
+  diskann_validity_set(bitmap, 8, 1);
+  assert(diskann_validity_get(bitmap, 8) == 1);
+  assert(diskann_validity_count(bitmap, 24) == 3);
+
+  // Slot 23: the final position overall.
+  diskann_validity_set(bitmap, 23, 1);
+  assert(diskann_validity_get(bitmap, 23) == 1);
+  assert(diskann_validity_count(bitmap, 24) == 4);
+
+  // Clearing slot 0 again lowers the count by one.
+  diskann_validity_set(bitmap, 0, 0);
+  assert(diskann_validity_get(bitmap, 0) == 0);
+  assert(diskann_validity_count(bitmap, 24) == 3);
+
+  // Neighbouring slots keep their state.
+  assert(diskann_validity_get(bitmap, 7) == 1);
+  assert(diskann_validity_get(bitmap, 8) == 1);
+
+  printf(" All diskann_validity_bitmap tests passed.\n");
+}
+
+void test_diskann_neighbor_ids() {
+  printf("Starting %s...\n", __func__);
+
+  // Room for 8 slots of 8 bytes each, zero-filled.
+  unsigned char id_slots[8 * 8] = {0};
+
+  // First slot round-trips.
+  diskann_neighbor_id_set(id_slots, 0, 42);
+  assert(diskann_neighbor_id_get(id_slots, 0) == 42);
+
+  // A slot in the middle round-trips.
+  diskann_neighbor_id_set(id_slots, 3, 12345);
+  assert(diskann_neighbor_id_get(id_slots, 3) == 12345);
+
+  // Final slot round-trips.
+  diskann_neighbor_id_set(id_slots, 7, 99999);
+  assert(diskann_neighbor_id_get(id_slots, 7) == 99999);
+
+  // Writes to other slots left slot 0 untouched.
+  assert(diskann_neighbor_id_get(id_slots, 0) == 42);
+
+  // The largest signed 64-bit value survives a round trip.
+  diskann_neighbor_id_set(id_slots, 1, INT64_MAX);
+  assert(diskann_neighbor_id_get(id_slots, 1) == INT64_MAX);
+
+  printf(" All diskann_neighbor_ids tests passed.\n");
+}
+
+void test_diskann_quantize_binary() {
+  printf("Starting %s...\n", __func__);
+
+  // Eight components; the expected bit is 1 only for strictly positive values.
+  float input[8] = {1.0f, -1.0f, 0.5f, 0.0f, -0.5f, 0.1f, -0.1f, 100.0f};
+  unsigned char packed[1]; // 8 dimensions -> a single byte
+
+  int status = diskann_quantize_vector(input, 8, VEC0_DISKANN_QUANTIZER_BINARY, packed);
+  assert(status == 0);
+
+  // Expected layout, component i stored in bit i (least significant bit first):
+  //   bit 0:   1.0 -> 1
+  //   bit 1:  -1.0 -> 0
+  //   bit 2:   0.5 -> 1
+  //   bit 3:   0.0 -> 0 (zero is not strictly positive)
+  //   bit 4:  -0.5 -> 0
+  //   bit 5:   0.1 -> 1
+  //   bit 6:  -0.1 -> 0
+  //   bit 7: 100.0 -> 1
+  // Bits 0, 2, 5 and 7 set: 1 + 4 + 32 + 128 = 0b10100101 = 0xA5
+  assert(packed[0] == 0xA5);
+
+  printf(" All diskann_quantize_binary tests passed.\n");
+}
+
+void test_diskann_node_init_sizes() {
+  printf("Starting %s...\n", __func__);
+
+  unsigned char *valid_bits, *neighbor_ids, *neighbor_qvecs;
+  int valid_bytes, id_bytes, qvec_bytes;
+
+  // Case 1: 72 neighbors, binary quantizer, 1024 dimensions.
+  int status = diskann_node_init(72, VEC0_DISKANN_QUANTIZER_BINARY, 1024,
+      &valid_bits, &valid_bytes, &neighbor_ids, &id_bytes, &neighbor_qvecs, &qvec_bytes);
+  assert(status == 0);
+  assert(valid_bytes == 9);    // 72 / 8
+  assert(id_bytes == 576);     // 72 * 8
+  assert(qvec_bytes == 9216);  // 72 * (1024 / 8)
+
+  // A freshly initialised node has no valid neighbor slots.
+  assert(diskann_validity_count(valid_bits, 72) == 0);
+
+  sqlite3_free(valid_bits);
+  sqlite3_free(neighbor_ids);
+  sqlite3_free(neighbor_qvecs);
+
+  // Case 2: 8 neighbors, int8 quantizer, 32 dimensions.
+  status = diskann_node_init(8, VEC0_DISKANN_QUANTIZER_INT8, 32,
+      &valid_bits, &valid_bytes, &neighbor_ids, &id_bytes, &neighbor_qvecs, &qvec_bytes);
+  assert(status == 0);
+  assert(valid_bytes == 1);    // 8 / 8
+  assert(id_bytes == 64);      // 8 * 8
+  assert(qvec_bytes == 256);   // 8 * 32
+
+  sqlite3_free(valid_bits);
+  sqlite3_free(neighbor_ids);
+  sqlite3_free(neighbor_qvecs);
+
+  printf(" All diskann_node_init_sizes tests passed.\n");
+}
+
+void test_diskann_node_set_clear_neighbor() {
+  printf("Starting %s...\n", __func__);
+
+  unsigned char *valid_bits, *neighbor_ids, *neighbor_qvecs;
+  int valid_bytes, id_bytes, qvec_bytes;
+
+  // Node with 8 neighbor slots, binary quantizer, 16 dims (2 bytes per quantized vector).
+  int status = diskann_node_init(8, VEC0_DISKANN_QUANTIZER_BINARY, 16,
+      &valid_bits, &valid_bytes, &neighbor_ids, &id_bytes, &neighbor_qvecs, &qvec_bytes);
+  assert(status == 0);
+
+  // Two-byte payload standing in for a quantized vector.
+  unsigned char payload[2] = {0xAB, 0xCD};
+
+  // Store rowid 42 together with the payload in slot 3.
+  diskann_node_set_neighbor(valid_bits, neighbor_ids, neighbor_qvecs, 3,
+      42, payload, VEC0_DISKANN_QUANTIZER_BINARY, 16);
+
+  // Slot 3 is now the only valid slot.
+  assert(diskann_validity_get(valid_bits, 3) == 1);
+  assert(diskann_validity_count(valid_bits, 8) == 1);
+
+  // The rowid reads back.
+  assert(diskann_neighbor_id_get(neighbor_ids, 3) == 42);
+
+  // Both payload bytes read back.
+  const unsigned char *stored = diskann_neighbor_qvec_get(
+      neighbor_qvecs, 3, VEC0_DISKANN_QUANTIZER_BINARY, 16);
+  assert(stored[0] == 0xAB);
+  assert(stored[1] == 0xCD);
+
+  // Clearing slot 3 resets its validity bit and its rowid.
+  diskann_node_clear_neighbor(valid_bits, neighbor_ids, neighbor_qvecs, 3,
+      VEC0_DISKANN_QUANTIZER_BINARY, 16);
+  assert(diskann_validity_get(valid_bits, 3) == 0);
+  assert(diskann_neighbor_id_get(neighbor_ids, 3) == 0);
+  assert(diskann_validity_count(valid_bits, 8) == 0);
+
+  sqlite3_free(valid_bits);
+  sqlite3_free(neighbor_ids);
+  sqlite3_free(neighbor_qvecs);
+
+  printf(" All diskann_node_set_clear_neighbor tests passed.\n");
+}
+
+void test_diskann_prune_select() {
+  printf("Starting %s...\n", __func__);
+
+  // Fixture: five candidates A..E (indices 0..4), listed nearest-first
+  // relative to the node p whose neighbor list is being pruned.
+  // Distance from p: A=1.0, B=2.0, C=3.0, D=4.0, E=5.0
+  //
+  // Symmetric candidate-to-candidate distances, row-major 5x5:
+  //        A    B    C    D    E
+  //   A   0.0  1.5  3.0  4.0  5.0
+  //   B   1.5  0.0  1.5  3.0  4.0
+  //   C   3.0  1.5  0.0  1.5  3.0
+  //   D   4.0  3.0  1.5  0.0  1.5
+  //   E   5.0  4.0  3.0  1.5  0.0
+
+  float dist_to_p[5] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+  float pairwise[25] = {
+      0.0f, 1.5f, 3.0f, 4.0f, 5.0f,
+      1.5f, 0.0f, 1.5f, 3.0f, 4.0f,
+      3.0f, 1.5f, 0.0f, 1.5f, 3.0f,
+      4.0f, 3.0f, 1.5f, 0.0f, 1.5f,
+      5.0f, 4.0f, 3.0f, 1.5f, 0.0f,
+  };
+  int picked[5];
+  int n_picked;
+
+  // Case 1 - alpha=1.0, R=3. Candidate X is dropped once a pick S satisfies
+  // alpha * dist(S, X) <= dist(p, X). After picking A (the closest):
+  //   B: 1.0*1.5 <= 2.0 -> dropped
+  //   C: 1.0*3.0 <= 3.0 -> dropped
+  //   D: 1.0*4.0 <= 4.0 -> dropped
+  //   E: 1.0*5.0 <= 5.0 -> dropped
+  // so A is the only pick.
+  {
+    int status = diskann_prune_select(pairwise, dist_to_p, 5, 1.0f, 3, picked, &n_picked);
+    assert(status == 0);
+    assert(n_picked == 1);
+    assert(picked[0] == 1); // A
+  }
+
+  // Case 2 - alpha=1.5, R=3. The larger alpha makes dropping harder.
+  // After picking A:
+  //   B: 2.25 <= 2.0 no;  C: 4.5 <= 3.0 no
+  //   D: 6.0  <= 4.0 no;  E: 7.5 <= 5.0 no   (all kept)
+  // After picking B:
+  //   C: 1.5*1.5 = 2.25 <= 3.0 -> dropped
+  //   D: 1.5*3.0 = 4.5  <= 4.0 no -> kept
+  //   E: 1.5*4.0 = 6.0  <= 5.0 no -> kept
+  // D is picked third, which reaches R=3 and ends the selection,
+  // leaving E unpicked.
+  // Expected flags: A=1, B=1, C=0, D=1, E=0.
+  {
+    int status = diskann_prune_select(pairwise, dist_to_p, 5, 1.5f, 3, picked, &n_picked);
+    assert(status == 0);
+    assert(n_picked == 3);
+    assert(picked[0] == 1); // A
+    assert(picked[1] == 1); // B
+    assert(picked[3] == 1); // D
+    assert(picked[2] == 0); // C was dropped
+    assert(picked[4] == 0); // E never reached
+  }
+
+  // Case 3 - R exceeds the candidate count and alpha is huge: everything is picked.
+  {
+    int status = diskann_prune_select(pairwise, dist_to_p, 5, 100.0f, 10, picked, &n_picked);
+    assert(status == 0);
+    assert(n_picked == 5);
+  }
+
+  // Case 4 - no candidates at all: succeeds with nothing picked.
+  {
+    int status = diskann_prune_select(NULL, NULL, 0, 1.2f, 3, picked, &n_picked);
+    assert(status == 0);
+    assert(n_picked == 0);
+  }
+
+  printf(" All diskann_prune_select tests passed.\n");
+}
+
+// Pins diskann_quantized_vector_byte_size() for both quantizer types.
+void test_diskann_quantized_vector_byte_size() {
+  printf("Starting %s...\n", __func__);
+  // Rows are {dimensions, expected bytes}. BINARY spends one bit per
+  // dimension (128 dims -> 16 bytes); INT8 spends one byte per dimension.
+  static const int binary_cases[3][2] = {{128, 16}, {8, 1}, {1024, 128}};
+  static const int int8_cases[3][2] = {{128, 128}, {1, 1}, {768, 768}};
+
+  for (int i = 0; i < 3; i++)
+    assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_BINARY, binary_cases[i][0]) == binary_cases[i][1]);
+  for (int i = 0; i < 3; i++)
+    assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_INT8, int8_cases[i][0]) == int8_cases[i][1]);
+
+  printf(" All diskann_quantized_vector_byte_size tests passed.\n");
+}
+
+// Confirms DiskANN is off by default, both for a zero-filled column
+// definition and for a column parsed from a declaration with no index options.
+void test_diskann_config_defaults() {
+  printf("Starting %s...\n", __func__);
+
+  // Zero-filled definition: index type is not DiskANN, tuning fields are 0.
+  struct VectorColumnDefinition column;
+  memset(&column, 0, sizeof(column));
+  assert(column.index_type != VEC0_INDEX_TYPE_DISKANN);
+  assert(column.diskann.n_neighbors == 0);
+  assert(column.diskann.search_list_size == 0);
+
+  // A plain vector column must still parse and must not select DiskANN.
+  const char *decl = "embedding float[768]";
+  int status = vec0_parse_vector_column(decl, (int)strlen(decl), &column);
+  assert(status == 0 /* SQLITE_OK */);
+  assert(column.index_type != VEC0_INDEX_TYPE_DISKANN);
+  sqlite3_free(column.name);
+
+  printf(" All diskann_config_defaults tests passed.\n");
+}
+
+// ======================================================================
+// Additional DiskANN unit tests
+// ======================================================================
+
+// Spot-checks INT8 quantization of four inputs spanning [-1, 1].
+void test_diskann_quantize_int8() {
+  printf("Starting %s...\n", __func__);
+
+  // Formula the expectations were derived from (fixed range [-1, 1]):
+  //   step = 2.0 / 255.0
+  //   q[i] = (i8)((x[i] + 1.0) / step - 128.0)
+  float input[4] = {-1.0f, 0.0f, 0.5f, 1.0f};
+  unsigned char raw[4];
+  static const int bounds[4][2] = { // {lowest, highest} accepted output
+      {-128, -128}, // x = -1.0: 0/step - 128 = -128, must match exactly
+      {-2, 2},      // x =  0.0: 127.5 - 128 = -0.5, truncates to 0
+      {60, 66},     // x =  0.5: 191.25 - 128 = 63.25, truncates to 63
+      {126, 127},   // x =  1.0: close to 127, float precision permitting
+  };
+
+  int status = diskann_quantize_vector(input, 4, VEC0_DISKANN_QUANTIZER_INT8, raw);
+  assert(status == 0);
+  const int8_t *q = (const int8_t *)raw; // read the bytes back as signed
+  for (int i = 0; i < 4; i++)
+    assert(q[i] >= bounds[i][0] && q[i] <= bounds[i][1]);
+
+  printf(" All diskann_quantize_int8 tests passed.\n");
+}
+
+// Binary-quantizes a 16-dimensional vector and checks both output bytes.
+void test_diskann_quantize_binary_16d() {
+  printf("Starting %s...\n", __func__);
+
+  // Expected bits: 1 for positive inputs, 0 for negative inputs and for 0.0.
+  float input[16] = {
+      1.0f,  -1.0f, 0.5f, -0.5f,  // byte 0, bits 0..3: 1 0 1 0
+      0.1f,  -0.1f, 0.0f, 100.0f, // byte 0, bits 4..7: 1 0 0 1
+      -1.0f, 1.0f,  1.0f, 1.0f,   // byte 1, bits 0..3: 0 1 1 1
+      -1.0f, -1.0f, 1.0f, -1.0f,  // byte 1, bits 4..7: 0 0 1 0
+  };
+  unsigned char packed[2];
+  // byte 0 = bits {0,2,4,7} = 0b10010101; byte 1 = bits {1,2,3,6} = 0b01001110
+  static const unsigned char want[2] = {0x95, 0x4E};
+
+  int status = diskann_quantize_vector(input, 16, VEC0_DISKANN_QUANTIZER_BINARY, packed);
+  assert(status == 0);
+  for (int b = 0; b < 2; b++)
+    assert(packed[b] == want[b]);
+
+  printf(" All diskann_quantize_binary_16d tests passed.\n");
+}
+
+// Eight strictly positive inputs must binary-quantize to one fully set byte.
+void test_diskann_quantize_binary_all_positive() {
+  printf("Starting %s...\n", __func__);
+
+  float positives[8] = {1.0f, 2.0f, 0.1f, 0.001f, 100.0f, 42.0f, 0.5f, 3.14f};
+  unsigned char packed; // 8 dimensions -> a single output byte
+  int status = diskann_quantize_vector(positives, 8, VEC0_DISKANN_QUANTIZER_BINARY, &packed);
+  assert(status == 0);
+  assert(packed == 0xFF); // every bit set
+
+  printf(" All diskann_quantize_binary_all_positive tests passed.\n");
+}
+
+// Non-positive inputs (seven negatives and a 0.0) must quantize to 0x00.
+void test_diskann_quantize_binary_all_negative() {
+  printf("Starting %s...\n", __func__);
+
+  float non_positives[8] = {-1.0f, -2.0f, -0.1f, -0.001f, -100.0f, -42.0f, -0.5f, 0.0f};
+  unsigned char packed; // 8 dimensions -> a single output byte
+  int status = diskann_quantize_vector(non_positives, 8, VEC0_DISKANN_QUANTIZER_BINARY, &packed);
+  assert(status == 0);
+  assert(packed == 0x00); // no bit set: every input is <= 0
+
+  printf(" All diskann_quantize_binary_all_negative tests passed.\n");
+}
+
+// Walks a capacity-5 DiskannCandidateList through ordered insertion,
+// same-rowid updates, visited tracking, and inserts into a full list.
+void test_diskann_candidate_list_operations() {
+  printf("Starting %s...\n", __func__);
+
+  struct DiskannCandidateList candidates;
+  int status = _test_diskann_candidate_list_init(&candidates, 5);
+  assert(status == 0);
+
+  // Three inserts, deliberately out of distance order.
+  _test_diskann_candidate_list_insert(&candidates, 10, 3.0f);
+  _test_diskann_candidate_list_insert(&candidates, 20, 1.0f);
+  _test_diskann_candidate_list_insert(&candidates, 30, 2.0f);
+  assert(_test_diskann_candidate_list_count(&candidates) == 3);
+
+  // Entries must read back in ascending distance order.
+  static const int want_rowid[3] = {20, 30, 10};
+  static const float want_dist[3] = {1.0f, 2.0f, 3.0f};
+  for (int i = 0; i < 3; i++)
+    assert(_test_diskann_candidate_list_rowid(&candidates, i) == want_rowid[i]);
+  for (int i = 0; i < 3; i++)
+    assert(_test_diskann_candidate_list_distance(&candidates, i) == want_dist[i]);
+
+  // Re-inserting rowid 10 with a smaller distance updates it; it moves first.
+  _test_diskann_candidate_list_insert(&candidates, 10, 0.5f);
+  assert(_test_diskann_candidate_list_count(&candidates) == 3);
+  assert(_test_diskann_candidate_list_rowid(&candidates, 0) == 10);
+  assert(_test_diskann_candidate_list_distance(&candidates, 0) == 0.5f);
+
+  // Nothing is visited yet, so the next unvisited entry is index 0; once
+  // index 0 is marked visited, the lookup skips ahead to index 1.
+  int next = _test_diskann_candidate_list_next_unvisited(&candidates);
+  assert(next == 0);
+  _test_diskann_candidate_list_set_visited(&candidates, 0);
+  next = _test_diskann_candidate_list_next_unvisited(&candidates);
+  assert(next == 1);
+
+  // Two more entries bring the list to its capacity of 5.
+  _test_diskann_candidate_list_insert(&candidates, 40, 4.0f);
+  _test_diskann_candidate_list_insert(&candidates, 50, 5.0f);
+  assert(_test_diskann_candidate_list_count(&candidates) == 5);
+
+  // Full list + candidate worse than the current worst (5.0): discarded.
+  int accepted = _test_diskann_candidate_list_insert(&candidates, 60, 10.0f);
+  assert(accepted == 0);
+  assert(_test_diskann_candidate_list_count(&candidates) == 5);
+
+  // Full list + candidate better than the current worst: it replaces the worst.
+  accepted = _test_diskann_candidate_list_insert(&candidates, 60, 3.5f);
+  assert(accepted == 1);
+  assert(_test_diskann_candidate_list_count(&candidates) == 5);
+
+  _test_diskann_candidate_list_free(&candidates);
+
+  printf(" All diskann_candidate_list_operations tests passed.\n");
+}
+
+// Exercises DiskannVisitedSet: membership queries, duplicate inserts, and
+// the handling of rowid 0.
+void test_diskann_visited_set_operations() {
+  printf("Starting %s...\n", __func__);
+
+  struct DiskannVisitedSet visited;
+  int status = _test_diskann_visited_set_init(&visited, 32);
+  assert(status == 0);
+
+  // A freshly initialized set has no members.
+  assert(_test_diskann_visited_set_contains(&visited, 1) == 0);
+  assert(_test_diskann_visited_set_contains(&visited, 100) == 0);
+
+  // The first insert of 42 reports 1 and makes 42 (and only 42) a member.
+  int added = _test_diskann_visited_set_insert(&visited, 42);
+  assert(added == 1);
+  assert(_test_diskann_visited_set_contains(&visited, 42) == 1);
+  assert(_test_diskann_visited_set_contains(&visited, 43) == 0);
+
+  // Inserting 42 again reports 0.
+  added = _test_diskann_visited_set_insert(&visited, 42);
+  assert(added == 0);
+
+  // Insert several more rowids, then confirm every one of them is a member
+  // while a rowid that was never inserted (3) is not.
+  static const int extra[4] = {1, 2, 100, 999};
+  for (int i = 0; i < 4; i++)
+    _test_diskann_visited_set_insert(&visited, extra[i]);
+  for (int i = 0; i < 4; i++)
+    assert(_test_diskann_visited_set_contains(&visited, extra[i]) == 1);
+  assert(_test_diskann_visited_set_contains(&visited, 3) == 0);
+
+  // Rowid 0 is the sentinel: never a member, and inserting it reports 0.
+  assert(_test_diskann_visited_set_contains(&visited, 0) == 0);
+  added = _test_diskann_visited_set_insert(&visited, 0);
+  assert(added == 0);
+
+  _test_diskann_visited_set_free(&visited);
+
+  printf(" All diskann_visited_set_operations tests passed.\n");
+}
+
+// A lone candidate must be selected even though the budget (R=3) is larger.
+void test_diskann_prune_select_single_candidate() {
+  printf("Starting %s...\n", __func__);
+
+  // One candidate: a single distance to p and a 1x1 inter-candidate matrix.
+  float dist_to_p = 5.0f, self_dist = 0.0f;
+  int picked, n_picked;
+
+  int status = diskann_prune_select(&self_dist, &dist_to_p, 1, 1.0f, 3, &picked, &n_picked);
+  assert(status == 0);
+  assert(n_picked == 1);
+  assert(picked == 1);
+
+  printf(" All diskann_prune_select_single_candidate tests passed.\n");
+}
+
+// Four candidates, all 2.0 from p and 1.0 from one another: with alpha=1.0
+// the first pick prunes the other three, so exactly one may be selected.
+void test_diskann_prune_select_all_identical_distances() {
+  printf("Starting %s...\n", __func__);
+
+  float p_distances[4] = {2.0f, 2.0f, 2.0f, 2.0f};
+  float inter[16] = {
+      0.0f, 1.0f, 1.0f, 1.0f,
+      1.0f, 0.0f, 1.0f, 1.0f,
+      1.0f, 1.0f, 0.0f, 1.0f,
+      1.0f, 1.0f, 1.0f, 0.0f,
+  };
+  int selected[4] = {0, 0, 0, 0};
+  int count = -1; // poisoned so a call that never writes it is caught
+  // Whichever candidate wins the tie, 1.0 * 1.0 <= 2.0 prunes the rest.
+  int rc = diskann_prune_select(inter, p_distances, 4, 1.0f, 4, selected, &count);
+  assert(rc == 0);
+  assert(count == 1); // was ">= 1", which also passes if pruning does nothing
+  assert(selected[0] + selected[1] + selected[2] + selected[3] == 1);
+
+  printf(" All diskann_prune_select_all_identical_distances tests passed.\n");
+}
+
+// R=1 must cap the result at one pick. Assuming the alpha * inter <= p_dist
+// pruning rule, nothing would be pruned here (1.0 * 5.0 exceeds 1.0..3.0).
+void test_diskann_prune_select_max_neighbors_1() {
+  printf("Starting %s...\n", __func__);
+
+  float dist_to_p[3] = {1.0f, 2.0f, 3.0f};
+  float pairwise[3 * 3] = {
+      0.0f, 5.0f, 5.0f,
+      5.0f, 0.0f, 5.0f,
+      5.0f, 5.0f, 0.0f,
+  };
+  int picked[3], n_picked;
+
+  int status = diskann_prune_select(pairwise, dist_to_p, 3, 1.0f, 1, picked, &n_picked);
+  assert(status == 0);
+  assert(n_picked == 1);
+  assert(picked[0] == 1); // the first (closest) candidate
+
+  printf(" All diskann_prune_select_max_neighbors_1 tests passed.\n");
+}
int main() {
printf("Starting unit tests...\n");
@@ -1402,5 +2106,23 @@ int main() {
test_ivf_quantize_binary();
test_ivf_config_parsing();
#endif
+ test_vec0_parse_vector_column_diskann();
+ test_diskann_validity_bitmap();
+ test_diskann_neighbor_ids();
+ test_diskann_quantize_binary();
+ test_diskann_node_init_sizes();
+ test_diskann_node_set_clear_neighbor();
+ test_diskann_prune_select();
+ test_diskann_quantized_vector_byte_size();
+ test_diskann_config_defaults();
+ test_diskann_quantize_int8();
+ test_diskann_quantize_binary_16d();
+ test_diskann_quantize_binary_all_positive();
+ test_diskann_quantize_binary_all_negative();
+ test_diskann_candidate_list_operations();
+ test_diskann_visited_set_operations();
+ test_diskann_prune_select_single_candidate();
+ test_diskann_prune_select_all_identical_distances();
+ test_diskann_prune_select_max_neighbors_1();
printf("All unit tests passed.\n");
}