diff --git a/Makefile b/Makefile index 2758ee5..89907fa 100644 --- a/Makefile +++ b/Makefile @@ -204,7 +204,7 @@ test-loadable-watch: watchexec --exts c,py,Makefile --clear -- make test-loadable test-unit: - $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor -o $(prefix)/test-unit && $(prefix)/test-unit + $(CC) -DSQLITE_CORE -DSQLITE_VEC_TEST -DSQLITE_VEC_ENABLE_RESCORE -DSQLITE_VEC_ENABLE_DISKANN=1 tests/test-unit.c sqlite-vec.c vendor/sqlite3.c -I./ -Ivendor $(CFLAGS) -o $(prefix)/test-unit && $(prefix)/test-unit # Standalone sqlite3 CLI with vec0 compiled in. Useful for benchmarking, # profiling (has debug symbols), and scripting without .load_extension. diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile index 6081457..ddceb65 100644 --- a/benchmarks-ann/Makefile +++ b/benchmarks-ann/Makefile @@ -19,9 +19,16 @@ RESCORE_CONFIGS = \ "rescore-bit-os16:type=rescore,quantizer=bit,oversample=16" \ "rescore-int8-os8:type=rescore,quantizer=int8,oversample=8" -ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) +# --- DiskANN configs --- +DISKANN_CONFIGS = \ + "diskann-R48-binary:type=diskann,R=48,L=128,quantizer=binary" \ + "diskann-R72-binary:type=diskann,R=72,L=128,quantizer=binary" \ + "diskann-R72-int8:type=diskann,R=72,L=128,quantizer=int8" \ + "diskann-R72-L256:type=diskann,R=72,L=256,quantizer=binary" -.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-10k bench-50k bench-100k bench-all \ +ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) $(DISKANN_CONFIGS) + +.PHONY: seed ground-truth bench-smoke bench-rescore bench-ivf bench-diskann bench-10k bench-50k bench-100k bench-all \ report clean # --- Data preparation --- @@ -37,7 +44,8 @@ ground-truth: seed bench-smoke: seed $(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \ "brute-float:type=baseline,variant=float" \ - "ivf-quick:type=ivf,nlist=16,nprobe=4" + "ivf-quick:type=ivf,nlist=16,nprobe=4" \ + "diskann-quick:type=diskann,R=48,L=64,quantizer=binary" bench-rescore: seed $(BENCH) --subset-size 10000 -k 10 -o runs/rescore \ @@ -62,6 +70,12 @@ bench-ivf: seed $(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS) $(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS) +# --- DiskANN across sizes --- +bench-diskann: seed + $(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS) + $(BENCH) --subset-size 50000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS) + $(BENCH) --subset-size 100000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS) + # --- Report --- report: @echo "Use: sqlite3 runs//results.db 'SELECT * FROM bench_results ORDER BY recall DESC'" diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py index c640628..520db77 100644 --- a/benchmarks-ann/bench.py +++ b/benchmarks-ann/bench.py @@ -6,18 +6,16 @@ across different vec0 configurations. Config format: name:type=,key=val,key=val - Baseline (brute-force) keys: - type=baseline, variant=float|int8|bit, oversample=8 - - Index-specific types can be registered via INDEX_REGISTRY (see below). + Available types: none, vec0-flat, rescore, ivf, diskann Usage: python bench.py --subset-size 10000 \ - "brute-float:type=baseline,variant=float" \ - "brute-int8:type=baseline,variant=int8" \ - "brute-bit:type=baseline,variant=bit" + "raw:type=none" \ + "flat:type=vec0-flat,variant=float" \ + "flat-int8:type=vec0-flat,variant=int8" """ import argparse +from datetime import datetime, timezone import os import sqlite3 import statistics @@ -56,11 +54,118 @@ INDEX_REGISTRY = {} # ============================================================================ -# Baseline implementation +# "none" — regular table, no vec0, manual KNN via vec_distance_cosine() # ============================================================================ -def _baseline_create_table_sql(params): +def _none_create_table_sql(params): + variant = params["variant"] + if variant == "int8": + return ( + "CREATE TABLE vec_items (" + " id INTEGER PRIMARY KEY," + " embedding BLOB NOT NULL," + " embedding_int8 BLOB NOT NULL)" + ) + elif variant == "bit": + return ( + "CREATE TABLE vec_items (" + " id INTEGER PRIMARY KEY," + " embedding BLOB NOT NULL," + " embedding_bq BLOB NOT NULL)" + ) + return ( + "CREATE TABLE vec_items (" + " id INTEGER PRIMARY KEY," + " embedding BLOB NOT NULL)" + ) + + +def _none_insert_sql(params): + variant = params["variant"] + if variant == "int8": + return ( + "INSERT INTO vec_items(id, embedding, embedding_int8) " + "SELECT id, vector, vec_quantize_int8(vector, 'unit') " + "FROM base.train WHERE id >= :lo AND id < :hi" + ) + elif variant == "bit": + return ( + "INSERT INTO vec_items(id, embedding, embedding_bq) " + "SELECT id, vector, vec_quantize_binary(vector) " + "FROM base.train WHERE id >= :lo AND id < :hi" + ) + return ( + "INSERT INTO vec_items(id, embedding) " + "SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi" + ) + + +def _none_run_query(conn, params, query, k): + variant = params["variant"] + oversample = params.get("oversample", 8) + + if variant == "int8": + q_int8 = conn.execute( + "SELECT vec_quantize_int8(:query, 'unit')", {"query": query} + ).fetchone()[0] + return conn.execute( + "WITH coarse AS (" + " SELECT id, embedding FROM (" + " SELECT id, embedding, vec_distance_cosine(vec_int8(:q_int8), vec_int8(embedding_int8)) as dist " + " FROM vec_items ORDER BY dist LIMIT :oversample_k" + " )" + ") " + "SELECT id, vec_distance_cosine(:query, embedding) as distance " + "FROM coarse ORDER BY 2 LIMIT :k", + {"q_int8": q_int8, "query": query, "k": k, "oversample_k": k * oversample}, + ).fetchall() + elif variant == "bit": + q_bit = conn.execute( + "SELECT vec_quantize_binary(:query)", {"query": query} + ).fetchone()[0] + return conn.execute( + "WITH coarse AS (" + " SELECT id, embedding FROM (" + " SELECT id, embedding, vec_distance_hamming(vec_bit(:q_bit), vec_bit(embedding_bq)) as dist " + " FROM vec_items ORDER BY dist LIMIT :oversample_k" + " )" + ") " + "SELECT id, vec_distance_cosine(:query, embedding) as distance " + "FROM coarse ORDER BY 2 LIMIT :k", + {"q_bit": q_bit, "query": query, "k": k, "oversample_k": k * oversample}, + ).fetchall() + + return conn.execute( + "SELECT id, vec_distance_cosine(:query, embedding) as distance " + "FROM vec_items ORDER BY 2 LIMIT :k", + {"query": query, "k": k}, + ).fetchall() + + +def _none_describe(params): + v = params["variant"] + if v in ("int8", "bit"): + return f"none {v} (os={params['oversample']})" + return f"none float" + + +INDEX_REGISTRY["none"] = { + "defaults": {"variant": "float", "oversample": 8}, + "create_table_sql": _none_create_table_sql, + "insert_sql": _none_insert_sql, + "post_insert_hook": None, + "run_query": _none_run_query, + "describe": _none_describe, +} + + +# ============================================================================ +# vec0-flat — vec0 virtual table with brute-force MATCH +# ============================================================================ + + +def _vec0flat_create_table_sql(params): variant = params["variant"] extra = "" if variant == "int8": @@ -76,7 +181,7 @@ def _baseline_create_table_sql(params): ) -def _baseline_insert_sql(params): +def _vec0flat_insert_sql(params): variant = params["variant"] if variant == "int8": return ( @@ -93,7 +198,7 @@ def _baseline_insert_sql(params): return None # use default -def _baseline_run_query(conn, params, query, k): +def _vec0flat_run_query(conn, params, query, k): variant = params["variant"] oversample = params.get("oversample", 8) @@ -123,20 +228,20 @@ def _baseline_run_query(conn, params, query, k): return None # use default MATCH -def _baseline_describe(params): +def _vec0flat_describe(params): v = params["variant"] if v in ("int8", "bit"): - return f"baseline {v} (os={params['oversample']})" - return f"baseline {v}" + return f"vec0-flat {v} (os={params['oversample']})" + return f"vec0-flat {v}" -INDEX_REGISTRY["baseline"] = { +INDEX_REGISTRY["vec0-flat"] = { "defaults": {"variant": "float", "oversample": 8}, - "create_table_sql": _baseline_create_table_sql, - "insert_sql": _baseline_insert_sql, + "create_table_sql": _vec0flat_create_table_sql, + "insert_sql": _vec0flat_insert_sql, "post_insert_hook": None, - "run_query": _baseline_run_query, - "describe": _baseline_describe, + "run_query": _vec0flat_run_query, + "describe": _vec0flat_describe, } @@ -215,12 +320,64 @@ INDEX_REGISTRY["ivf"] = { } +# ============================================================================ +# DiskANN implementation +# ============================================================================ + + +def _diskann_create_table_sql(params): + bt = params["buffer_threshold"] + extra = f", buffer_threshold={bt}" if bt > 0 else "" + return ( + f"CREATE VIRTUAL TABLE vec_items USING vec0(" + f" id integer primary key," + f" embedding float[768] distance_metric=cosine" + f" INDEXED BY diskann(" + f" neighbor_quantizer={params['quantizer']}," + f" n_neighbors={params['R']}," + f" search_list_size={params['L']}" + f" {extra}" + f" )" + f")" + ) + + +def _diskann_pre_query_hook(conn, params): + L_search = params.get("L_search") + if L_search: + conn.execute( + "INSERT INTO vec_items(id) VALUES (?)", + (f"search_list_size_search={L_search}",), + ) + conn.commit() + print(f" Set search_list_size_search={L_search}") + + +def _diskann_describe(params): + desc = f"diskann q={params['quantizer']:<6} R={params['R']:<3} L={params['L']}" + L_search = params.get("L_search") + if L_search: + desc += f" L_search={L_search}" + return desc + + +INDEX_REGISTRY["diskann"] = { + "defaults": {"R": 72, "L": 128, "quantizer": "binary", "buffer_threshold": 0}, + "create_table_sql": _diskann_create_table_sql, + "insert_sql": None, + "post_insert_hook": None, + "pre_query_hook": _diskann_pre_query_hook, + "run_query": None, + "describe": _diskann_describe, +} + + # ============================================================================ # Config parsing # ============================================================================ INT_KEYS = { - "R", "L", "buffer_threshold", "nlist", "nprobe", "oversample", + "R", "L", "L_search", "buffer_threshold", "nlist", "nprobe", "oversample", "n_trees", "search_k", } @@ -238,7 +395,7 @@ def parse_config(spec): k, v = kv.split("=", 1) raw[k.strip()] = v.strip() - index_type = raw.pop("type", "baseline") + index_type = raw.pop("type", "vec0-flat") if index_type not in INDEX_REGISTRY: raise ValueError( f"Unknown index type: {index_type}. " @@ -289,7 +446,7 @@ def insert_loop(conn, sql, subset_size, label=""): return time.perf_counter() - t0 -def open_bench_db(db_path, ext_path, base_db): +def create_bench_db(db_path, ext_path, base_db): if os.path.exists(db_path): os.remove(db_path) conn = sqlite3.connect(db_path) @@ -300,6 +457,19 @@ def open_bench_db(db_path, ext_path, base_db): return conn +def open_existing_bench_db(db_path, ext_path, base_db): + if not os.path.exists(db_path): + raise FileNotFoundError( + f"Index DB not found: {db_path}\n" + f"Build it first with: --phase build" + ) + conn = sqlite3.connect(db_path) + conn.enable_load_extension(True) + conn.load_extension(ext_path) + conn.execute(f"ATTACH DATABASE '{base_db}' AS base") + return conn + + DEFAULT_INSERT_SQL = ( "INSERT INTO vec_items(id, embedding) " "SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi" @@ -313,7 +483,7 @@ DEFAULT_INSERT_SQL = ( def build_index(base_db, ext_path, name, params, subset_size, out_dir): db_path = os.path.join(out_dir, f"{name}.{subset_size}.db") - conn = open_bench_db(db_path, ext_path, base_db) + conn = create_bench_db(db_path, ext_path, base_db) reg = INDEX_REGISTRY[params["index_type"]] @@ -364,12 +534,16 @@ def _default_match_query(conn, query, k): ).fetchall() -def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50): +def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50, + pre_query_hook=None): conn = sqlite3.connect(db_path) conn.enable_load_extension(True) conn.load_extension(ext_path) conn.execute(f"ATTACH DATABASE '{base_db}' AS base") + if pre_query_hook: + pre_query_hook(conn, params) + query_vectors = load_query_vectors(base_db, n) reg = INDEX_REGISTRY[params["index_type"]] @@ -431,6 +605,34 @@ def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50): # ============================================================================ +def open_results_db(results_path): + db = sqlite3.connect(results_path) + db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read()) + # Migrate existing DBs that predate the runs table + cols = {r[1] for r in db.execute("PRAGMA table_info(runs)").fetchall()} + if "phase" not in cols: + db.execute("ALTER TABLE runs ADD COLUMN phase TEXT NOT NULL DEFAULT 'both'") + db.commit() + return db + + +def create_run(db, config_name, index_type, subset_size, phase, k=None, n=None): + cur = db.execute( + "INSERT INTO runs (config_name, index_type, subset_size, phase, status, k, n) " + "VALUES (?, ?, ?, ?, 'pending', ?, ?)", + (config_name, index_type, subset_size, phase, k, n), + ) + db.commit() + return cur.lastrowid + + +def update_run(db, run_id, **kwargs): + sets = ", ".join(f"{k} = ?" for k in kwargs) + vals = list(kwargs.values()) + [run_id] + db.execute(f"UPDATE runs SET {sets} WHERE run_id = ?", vals) + db.commit() + + def save_results(results_path, rows): db = sqlite3.connect(results_path) db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read()) @@ -500,6 +702,8 @@ def main(): parser.add_argument("--subset-size", type=int, required=True) parser.add_argument("-k", type=int, default=10, help="KNN k (default 10)") parser.add_argument("-n", type=int, default=50, help="number of queries (default 50)") + parser.add_argument("--phase", choices=["build", "query", "both"], default="both", + help="build=build only, query=query existing index, both=default") parser.add_argument("--base-db", default=BASE_DB) parser.add_argument("--ext", default=EXT_PATH) parser.add_argument("-o", "--out-dir", default="runs") @@ -508,55 +712,164 @@ def main(): args = parser.parse_args() os.makedirs(args.out_dir, exist_ok=True) - results_db = args.results_db or os.path.join(args.out_dir, "results.db") + results_db_path = args.results_db or os.path.join(args.out_dir, "results.db") configs = [parse_config(c) for c in args.configs] + results_db = open_results_db(results_db_path) all_results = [] for i, (name, params) in enumerate(configs, 1): reg = INDEX_REGISTRY[params["index_type"]] desc = reg["describe"](params) - print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()})") + print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()}) [phase={args.phase}]") - build = build_index( - args.base_db, args.ext, name, params, args.subset_size, args.out_dir - ) - train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else "" - print( - f" Build: {build['insert_time_s']}s insert{train_str} " - f"{build['file_size_mb']} MB" - ) + db_path = os.path.join(args.out_dir, f"{name}.{args.subset_size}.db") - print(f" Measuring KNN (k={args.k}, n={args.n})...") - knn = measure_knn( - build["db_path"], args.ext, args.base_db, - params, args.subset_size, k=args.k, n=args.n, - ) - print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}") + if args.phase == "build": + run_id = create_run(results_db, name, params["index_type"], + args.subset_size, "build") + update_run(results_db, run_id, status="inserting") - all_results.append({ - "name": name, - "n_vectors": args.subset_size, - "index_type": params["index_type"], - "config_desc": desc, - "db_path": build["db_path"], - "insert_time_s": build["insert_time_s"], - "train_time_s": build["train_time_s"], - "total_time_s": build["total_time_s"], - "insert_per_vec_ms": build["insert_per_vec_ms"], - "rows": build["rows"], - "file_size_mb": build["file_size_mb"], - "k": args.k, - "n_queries": args.n, - "mean_ms": knn["mean_ms"], - "median_ms": knn["median_ms"], - "p99_ms": knn["p99_ms"], - "total_ms": knn["total_ms"], - "recall": knn["recall"], - }) + build = build_index( + args.base_db, args.ext, name, params, args.subset_size, args.out_dir + ) + train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else "" + print( + f" Build: {build['insert_time_s']}s insert{train_str} " + f"{build['file_size_mb']} MB" + ) + update_run(results_db, run_id, + status="built", + db_path=build["db_path"], + insert_time_s=build["insert_time_s"], + train_time_s=build["train_time_s"], + total_build_time_s=build["total_time_s"], + rows=build["rows"], + file_size_mb=build["file_size_mb"], + finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")) + print(f" Index DB: {build['db_path']}") - print_report(all_results) - save_results(results_db, all_results) - print(f"\nResults saved to {results_db}") + elif args.phase == "query": + if not os.path.exists(db_path): + raise FileNotFoundError( + f"Index DB not found: {db_path}\n" + f"Build it first with: --phase build" + ) + + run_id = create_run(results_db, name, params["index_type"], + args.subset_size, "query", k=args.k, n=args.n) + update_run(results_db, run_id, status="querying") + + pre_hook = reg.get("pre_query_hook") + print(f" Measuring KNN (k={args.k}, n={args.n})...") + knn = measure_knn( + db_path, args.ext, args.base_db, + params, args.subset_size, k=args.k, n=args.n, + pre_query_hook=pre_hook, + ) + print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}") + + qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0 + update_run(results_db, run_id, + status="done", + db_path=db_path, + mean_ms=knn["mean_ms"], + median_ms=knn["median_ms"], + p99_ms=knn["p99_ms"], + total_query_ms=knn["total_ms"], + qps=qps, + recall=knn["recall"], + finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")) + + file_size_mb = os.path.getsize(db_path) / (1024 * 1024) + all_results.append({ + "name": name, + "n_vectors": args.subset_size, + "index_type": params["index_type"], + "config_desc": desc, + "db_path": db_path, + "insert_time_s": 0, + "train_time_s": 0, + "total_time_s": 0, + "insert_per_vec_ms": 0, + "rows": 0, + "file_size_mb": file_size_mb, + "k": args.k, + "n_queries": args.n, + "mean_ms": knn["mean_ms"], + "median_ms": knn["median_ms"], + "p99_ms": knn["p99_ms"], + "total_ms": knn["total_ms"], + "recall": knn["recall"], + }) + + else: # both + run_id = create_run(results_db, name, params["index_type"], + args.subset_size, "both", k=args.k, n=args.n) + update_run(results_db, run_id, status="inserting") + + build = build_index( + args.base_db, args.ext, name, params, args.subset_size, args.out_dir + ) + train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else "" + print( + f" Build: {build['insert_time_s']}s insert{train_str} " + f"{build['file_size_mb']} MB" + ) + update_run(results_db, run_id, status="querying", + db_path=build["db_path"], + insert_time_s=build["insert_time_s"], + train_time_s=build["train_time_s"], + total_build_time_s=build["total_time_s"], + rows=build["rows"], + file_size_mb=build["file_size_mb"]) + + print(f" Measuring KNN (k={args.k}, n={args.n})...") + knn = measure_knn( + build["db_path"], args.ext, args.base_db, + params, args.subset_size, k=args.k, n=args.n, + ) + print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}") + + qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0 + update_run(results_db, run_id, + status="done", + mean_ms=knn["mean_ms"], + median_ms=knn["median_ms"], + p99_ms=knn["p99_ms"], + total_query_ms=knn["total_ms"], + qps=qps, + recall=knn["recall"], + finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")) + + all_results.append({ + "name": name, + "n_vectors": args.subset_size, + "index_type": params["index_type"], + "config_desc": desc, + "db_path": build["db_path"], + "insert_time_s": build["insert_time_s"], + "train_time_s": build["train_time_s"], + "total_time_s": build["total_time_s"], + "insert_per_vec_ms": build["insert_per_vec_ms"], + "rows": build["rows"], + "file_size_mb": build["file_size_mb"], + "k": args.k, + "n_queries": args.n, + "mean_ms": knn["mean_ms"], + "median_ms": knn["median_ms"], + "p99_ms": knn["p99_ms"], + "total_ms": knn["total_ms"], + "recall": knn["recall"], + }) + + results_db.close() + + if all_results: + print_report(all_results) + save_results(results_db_path, all_results) + print(f"\nResults saved to {results_db_path}") + elif args.phase == "build": + print(f"\nBuild complete. Results tracked in {results_db_path}") if __name__ == "__main__": diff --git a/benchmarks-ann/schema.sql b/benchmarks-ann/schema.sql index 681df4e..ae8acf3 100644 --- a/benchmarks-ann/schema.sql +++ b/benchmarks-ann/schema.sql @@ -3,6 +3,31 @@ -- "baseline"; index-specific branches add their own types (registered -- via INDEX_REGISTRY in bench.py). +CREATE TABLE IF NOT EXISTS runs ( + run_id INTEGER PRIMARY KEY AUTOINCREMENT, + config_name TEXT NOT NULL, + index_type TEXT NOT NULL, + subset_size INTEGER NOT NULL, + phase TEXT NOT NULL DEFAULT 'both', -- 'build', 'query', or 'both' + status TEXT NOT NULL DEFAULT 'pending', + k INTEGER, + n INTEGER, + db_path TEXT, + insert_time_s REAL, + train_time_s REAL, + total_build_time_s REAL, + rows INTEGER, + file_size_mb REAL, + mean_ms REAL, + median_ms REAL, + p99_ms REAL, + total_query_ms REAL, + qps REAL, + recall REAL, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + finished_at TEXT +); + CREATE TABLE IF NOT EXISTS build_results ( config_name TEXT NOT NULL, index_type TEXT NOT NULL, diff --git a/sqlite-vec-diskann.c b/sqlite-vec-diskann.c new file mode 100644 index 0000000..1a5fd2b --- /dev/null +++ b/sqlite-vec-diskann.c @@ -0,0 +1,1768 @@ +// DiskANN algorithm implementation +// This file is #include'd into sqlite-vec.c — not compiled separately. + +// ============================================================ +// DiskANN node blob encode/decode functions +// ============================================================ + +/** Compute size of validity bitmap in bytes. */ +int diskann_validity_byte_size(int n_neighbors) { + return n_neighbors / CHAR_BIT; +} + +/** Compute size of neighbor IDs blob in bytes. */ +size_t diskann_neighbor_ids_byte_size(int n_neighbors) { + return (size_t)n_neighbors * sizeof(i64); +} + +/** Compute size of quantized vectors blob in bytes. */ +size_t diskann_neighbor_qvecs_byte_size( + int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type, + size_t dimensions) { + return (size_t)n_neighbors * + diskann_quantized_vector_byte_size(quantizer_type, dimensions); +} + +/** + * Create empty blobs for a new DiskANN node (all neighbors invalid). + * Caller must free the returned pointers with sqlite3_free(). + */ +int diskann_node_init( + int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type, + size_t dimensions, + u8 **outValidity, int *outValiditySize, + u8 **outNeighborIds, int *outNeighborIdsSize, + u8 **outNeighborQvecs, int *outNeighborQvecsSize) { + + int validitySize = diskann_validity_byte_size(n_neighbors); + size_t idsSize = diskann_neighbor_ids_byte_size(n_neighbors); + size_t qvecsSize = diskann_neighbor_qvecs_byte_size( + n_neighbors, quantizer_type, dimensions); + + u8 *validity = sqlite3_malloc(validitySize); + u8 *ids = sqlite3_malloc(idsSize); + u8 *qvecs = sqlite3_malloc(qvecsSize); + + if (!validity || !ids || !qvecs) { + sqlite3_free(validity); + sqlite3_free(ids); + sqlite3_free(qvecs); + return SQLITE_NOMEM; + } + + memset(validity, 0, validitySize); + memset(ids, 0, idsSize); + memset(qvecs, 0, qvecsSize); + + *outValidity = validity; *outValiditySize = validitySize; + *outNeighborIds = ids; *outNeighborIdsSize = (int)idsSize; + *outNeighborQvecs = qvecs; *outNeighborQvecsSize = (int)qvecsSize; + return SQLITE_OK; +} + +/** Check if neighbor slot i is valid. */ +int diskann_validity_get(const u8 *validity, int i) { + return (validity[i / CHAR_BIT] >> (i % CHAR_BIT)) & 1; +} + +/** Set neighbor slot i as valid (1) or invalid (0). */ +void diskann_validity_set(u8 *validity, int i, int value) { + if (value) { + validity[i / CHAR_BIT] |= (1 << (i % CHAR_BIT)); + } else { + validity[i / CHAR_BIT] &= ~(1 << (i % CHAR_BIT)); + } +} + +/** Count the number of valid neighbors. */ +int diskann_validity_count(const u8 *validity, int n_neighbors) { + int count = 0; + for (int i = 0; i < n_neighbors; i++) { + count += diskann_validity_get(validity, i); + } + return count; +} + +/** Get the rowid of the neighbor in slot i. */ +i64 diskann_neighbor_id_get(const u8 *neighbor_ids, int i) { + i64 result; + memcpy(&result, neighbor_ids + i * sizeof(i64), sizeof(i64)); + return result; +} + +/** Set the rowid of the neighbor in slot i. */ +void diskann_neighbor_id_set(u8 *neighbor_ids, int i, i64 rowid) { + memcpy(neighbor_ids + i * sizeof(i64), &rowid, sizeof(i64)); +} + +/** Get a pointer to the quantized vector in slot i (read-only). */ +const u8 *diskann_neighbor_qvec_get( + const u8 *qvecs, int i, + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) { + size_t qvec_size = diskann_quantized_vector_byte_size(quantizer_type, dimensions); + return qvecs + (size_t)i * qvec_size; +} + +/** Copy a quantized vector into slot i. */ +void diskann_neighbor_qvec_set( + u8 *qvecs, int i, const u8 *src_qvec, + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) { + size_t qvec_size = diskann_quantized_vector_byte_size(quantizer_type, dimensions); + memcpy(qvecs + (size_t)i * qvec_size, src_qvec, qvec_size); +} + +/** + * Set neighbor slot i with a rowid and quantized vector, and mark as valid. + */ +void diskann_node_set_neighbor( + u8 *validity, u8 *neighbor_ids, u8 *qvecs, int i, + i64 neighbor_rowid, const u8 *neighbor_qvec, + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) { + diskann_validity_set(validity, i, 1); + diskann_neighbor_id_set(neighbor_ids, i, neighbor_rowid); + diskann_neighbor_qvec_set(qvecs, i, neighbor_qvec, quantizer_type, dimensions); +} + +/** + * Clear neighbor slot i (mark invalid, zero out data). + */ +void diskann_node_clear_neighbor( + u8 *validity, u8 *neighbor_ids, u8 *qvecs, int i, + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) { + diskann_validity_set(validity, i, 0); + diskann_neighbor_id_set(neighbor_ids, i, 0); + size_t qvec_size = diskann_quantized_vector_byte_size(quantizer_type, dimensions); + memset(qvecs + (size_t)i * qvec_size, 0, qvec_size); +} + +/** + * Quantize a full-precision float32 vector into the target quantizer format. + * Output buffer must be pre-allocated with diskann_quantized_vector_byte_size() bytes. + */ +int diskann_quantize_vector( + const f32 *src, size_t dimensions, + enum Vec0DiskannQuantizerType quantizer_type, + u8 *out) { + + switch (quantizer_type) { + case VEC0_DISKANN_QUANTIZER_BINARY: { + memset(out, 0, dimensions / CHAR_BIT); + for (size_t i = 0; i < dimensions; i++) { + if (src[i] > 0.0f) { + out[i / CHAR_BIT] |= (1 << (i % CHAR_BIT)); + } + } + return SQLITE_OK; + } + case VEC0_DISKANN_QUANTIZER_INT8: { + f32 step = (1.0f - (-1.0f)) / 255.0f; + for (size_t i = 0; i < dimensions; i++) { + ((i8 *)out)[i] = (i8)(((src[i] - (-1.0f)) / step) - 128.0f); + } + return SQLITE_OK; + } + } + return SQLITE_ERROR; +} + +/** + * Compute approximate distance between a full-precision query vector and a + * quantized neighbor vector. Used during graph traversal. + */ +/** + * Compute distance between a pre-quantized query and a quantized neighbor. + * The caller is responsible for quantizing the query vector once and passing + * the result here for each neighbor comparison. + */ +static f32 diskann_distance_quantized_precomputed( + const u8 *query_quantized, const u8 *quantized_neighbor, + size_t dimensions, + enum Vec0DiskannQuantizerType quantizer_type, + enum Vec0DistanceMetrics distance_metric) { + + switch (quantizer_type) { + case VEC0_DISKANN_QUANTIZER_BINARY: + return distance_hamming(query_quantized, quantized_neighbor, &dimensions); + case VEC0_DISKANN_QUANTIZER_INT8: { + switch (distance_metric) { + case VEC0_DISTANCE_METRIC_L2: + return distance_l2_sqr_int8(query_quantized, quantized_neighbor, &dimensions); + case VEC0_DISTANCE_METRIC_COSINE: + return distance_cosine_int8(query_quantized, quantized_neighbor, &dimensions); + case VEC0_DISTANCE_METRIC_L1: + return (f32)distance_l1_int8(query_quantized, quantized_neighbor, &dimensions); + } + break; + } + } + return FLT_MAX; +} + +/** + * Quantize a float query vector. Returns allocated buffer (caller must free). + */ +static u8 *diskann_quantize_query( + const f32 *query_vector, size_t dimensions, + enum Vec0DiskannQuantizerType quantizer_type) { + size_t qsize = diskann_quantized_vector_byte_size(quantizer_type, dimensions); + u8 *buf = sqlite3_malloc(qsize); + if (!buf) return NULL; + diskann_quantize_vector(query_vector, dimensions, quantizer_type, buf); + return buf; +} + +/** + * Legacy wrapper: quantizes on-the-fly (used by callers that don't pre-quantize). + */ +f32 diskann_distance_quantized( + const void *query_vector, const u8 *quantized_neighbor, + size_t dimensions, + enum Vec0DiskannQuantizerType quantizer_type, + enum Vec0DistanceMetrics distance_metric) { + + u8 *query_q = diskann_quantize_query((const f32 *)query_vector, dimensions, quantizer_type); + if (!query_q) return FLT_MAX; + f32 dist = diskann_distance_quantized_precomputed( + query_q, quantized_neighbor, dimensions, quantizer_type, distance_metric); + sqlite3_free(query_q); + return dist; +} + +// ============================================================ +// DiskANN medoid / entry point management +// ============================================================ + +/** + * Get the current medoid rowid for the given vector column's DiskANN index. + * Returns SQLITE_OK with *outMedoid set to the medoid rowid. + * If the graph is empty, returns SQLITE_OK with *outIsEmpty = 1. + */ +static int diskann_medoid_get(vec0_vtab *p, int vec_col_idx, + i64 *outMedoid, int *outIsEmpty) { + int rc; + sqlite3_stmt *stmt = NULL; + char *key = sqlite3_mprintf("diskann_medoid_%02d", vec_col_idx); + char *zSql = sqlite3_mprintf( + "SELECT value FROM " VEC0_SHADOW_INFO_NAME " WHERE key = ?", + p->schemaName, p->tableName); + if (!key || !zSql) { + sqlite3_free(key); + sqlite3_free(zSql); + return SQLITE_NOMEM; + } + + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + sqlite3_free(key); + return rc; + } + + sqlite3_bind_text(stmt, 1, key, -1, sqlite3_free); + rc = sqlite3_step(stmt); + if (rc == SQLITE_ROW) { + if (sqlite3_column_type(stmt, 0) == SQLITE_NULL) { + *outIsEmpty = 1; + } else { + *outIsEmpty = 0; + *outMedoid = sqlite3_column_int64(stmt, 0); + } + rc = SQLITE_OK; + } else { + rc = SQLITE_ERROR; + } + sqlite3_finalize(stmt); + return rc; +} + +/** + * Set the medoid rowid for the given vector column's DiskANN index. + * Pass isEmpty = 1 to mark the graph as empty (NULL medoid). + */ +static int diskann_medoid_set(vec0_vtab *p, int vec_col_idx, + i64 medoidRowid, int isEmpty) { + int rc; + sqlite3_stmt *stmt = NULL; + char *key = sqlite3_mprintf("diskann_medoid_%02d", vec_col_idx); + char *zSql = sqlite3_mprintf( + "UPDATE " VEC0_SHADOW_INFO_NAME " SET value = ?2 WHERE key = ?1", + p->schemaName, p->tableName); + if (!key || !zSql) { + sqlite3_free(key); + sqlite3_free(zSql); + return SQLITE_NOMEM; + } + + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) { + sqlite3_free(key); + return rc; + } + + sqlite3_bind_text(stmt, 1, key, -1, sqlite3_free); + if (isEmpty) { + sqlite3_bind_null(stmt, 2); + } else { + sqlite3_bind_int64(stmt, 2, medoidRowid); + } + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR; +} + + +/** + * Called when deleting a vector. If the deleted vector was the medoid, + * pick a new one (the first available vector, or set to empty if none remain). + */ +static int diskann_medoid_handle_delete(vec0_vtab *p, int vec_col_idx, + i64 deletedRowid) { + i64 currentMedoid; + int isEmpty; + int rc = diskann_medoid_get(p, vec_col_idx, ¤tMedoid, &isEmpty); + if (rc != SQLITE_OK) return rc; + + if (!isEmpty && currentMedoid == deletedRowid) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "SELECT rowid FROM " VEC0_SHADOW_VECTORS_N_NAME " WHERE rowid != ?1 LIMIT 1", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + + sqlite3_bind_int64(stmt, 1, deletedRowid); + rc = sqlite3_step(stmt); + if (rc == SQLITE_ROW) { + i64 newMedoid = sqlite3_column_int64(stmt, 0); + sqlite3_finalize(stmt); + return diskann_medoid_set(p, vec_col_idx, newMedoid, 0); + } else { + sqlite3_finalize(stmt); + return diskann_medoid_set(p, vec_col_idx, -1, 1); + } + } + return SQLITE_OK; +} + +// ============================================================ +// DiskANN database I/O helpers +// ============================================================ + +/** + * Read a node's full data from _diskann_nodes. + * Returns blobs that must be freed by the caller with sqlite3_free(). + */ +static int diskann_node_read(vec0_vtab *p, int vec_col_idx, i64 rowid, + u8 **outValidity, int *outValiditySize, + u8 **outNeighborIds, int *outNeighborIdsSize, + u8 **outQvecs, int *outQvecsSize) { + int rc; + if (!p->stmtDiskannNodeRead[vec_col_idx]) { + char *zSql = sqlite3_mprintf( + "SELECT neighbors_validity, neighbor_ids, neighbor_quantized_vectors " + "FROM " VEC0_SHADOW_DISKANN_NODES_N_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, + &p->stmtDiskannNodeRead[vec_col_idx], NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + } + + sqlite3_stmt *stmt = p->stmtDiskannNodeRead[vec_col_idx]; + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, rowid); + + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + return SQLITE_ERROR; + } + + int vs = sqlite3_column_bytes(stmt, 0); + int is = sqlite3_column_bytes(stmt, 1); + int qs = sqlite3_column_bytes(stmt, 2); + + // Validate blob sizes against config expectations to detect truncated / + // corrupt data before any caller iterates using cfg->n_neighbors. + { + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + struct Vec0DiskannConfig *cfg = &col->diskann; + int expectedVs = diskann_validity_byte_size(cfg->n_neighbors); + int expectedIs = (int)diskann_neighbor_ids_byte_size(cfg->n_neighbors); + int expectedQs = (int)diskann_neighbor_qvecs_byte_size( + cfg->n_neighbors, cfg->quantizer_type, col->dimensions); + if (vs != expectedVs || is != expectedIs || qs != expectedQs) { + return SQLITE_CORRUPT; + } + } + + u8 *v = sqlite3_malloc(vs); + u8 *ids = sqlite3_malloc(is); + u8 *qv = sqlite3_malloc(qs); + if (!v || !ids || !qv) { + sqlite3_free(v); + sqlite3_free(ids); + sqlite3_free(qv); + return SQLITE_NOMEM; + } + + memcpy(v, sqlite3_column_blob(stmt, 0), vs); + memcpy(ids, sqlite3_column_blob(stmt, 1), is); + memcpy(qv, sqlite3_column_blob(stmt, 2), qs); + + *outValidity = v; *outValiditySize = vs; + *outNeighborIds = ids; *outNeighborIdsSize = is; + *outQvecs = qv; *outQvecsSize = qs; + return SQLITE_OK; +} + +/** + * Write (INSERT OR REPLACE) a node's data to _diskann_nodes. + */ +static int diskann_node_write(vec0_vtab *p, int vec_col_idx, i64 rowid, + const u8 *validity, int validitySize, + const u8 *neighborIds, int neighborIdsSize, + const u8 *qvecs, int qvecsSize) { + int rc; + if (!p->stmtDiskannNodeWrite[vec_col_idx]) { + char *zSql = sqlite3_mprintf( + "INSERT OR REPLACE INTO " VEC0_SHADOW_DISKANN_NODES_N_NAME + " (rowid, neighbors_validity, neighbor_ids, neighbor_quantized_vectors) " + "VALUES (?, ?, ?, ?)", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, + &p->stmtDiskannNodeWrite[vec_col_idx], NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + } + + sqlite3_stmt *stmt = p->stmtDiskannNodeWrite[vec_col_idx]; + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, rowid); + sqlite3_bind_blob(stmt, 2, validity, validitySize, SQLITE_TRANSIENT); + sqlite3_bind_blob(stmt, 3, neighborIds, neighborIdsSize, SQLITE_TRANSIENT); + sqlite3_bind_blob(stmt, 4, qvecs, qvecsSize, SQLITE_TRANSIENT); + + rc = sqlite3_step(stmt); + return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR; +} + +/** + * Read the full-precision vector for a given rowid from _vectors. + * Caller must free *outVector with sqlite3_free(). + */ +static int diskann_vector_read(vec0_vtab *p, int vec_col_idx, i64 rowid, + void **outVector, int *outVectorSize) { + int rc; + if (!p->stmtVectorsRead[vec_col_idx]) { + char *zSql = sqlite3_mprintf( + "SELECT vector FROM " VEC0_SHADOW_VECTORS_N_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, + &p->stmtVectorsRead[vec_col_idx], NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + } + + sqlite3_stmt *stmt = p->stmtVectorsRead[vec_col_idx]; + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, rowid); + + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + return SQLITE_ERROR; + } + + int sz = sqlite3_column_bytes(stmt, 0); + void *vec = sqlite3_malloc(sz); + if (!vec) return SQLITE_NOMEM; + memcpy(vec, sqlite3_column_blob(stmt, 0), sz); + + *outVector = vec; + *outVectorSize = sz; + return SQLITE_OK; +} + +/** + * Write a full-precision vector to _vectors. + */ +static int diskann_vector_write(vec0_vtab *p, int vec_col_idx, i64 rowid, + const void *vector, int vectorSize) { + int rc; + if (!p->stmtVectorsInsert[vec_col_idx]) { + char *zSql = sqlite3_mprintf( + "INSERT OR REPLACE INTO " VEC0_SHADOW_VECTORS_N_NAME + " (rowid, vector) VALUES (?, ?)", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, + &p->stmtVectorsInsert[vec_col_idx], NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + } + + sqlite3_stmt *stmt = p->stmtVectorsInsert[vec_col_idx]; + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, rowid); + sqlite3_bind_blob(stmt, 2, vector, vectorSize, SQLITE_TRANSIENT); + + rc = sqlite3_step(stmt); + return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR; +} + +// ============================================================ +// DiskANN search data structures +// ============================================================ + +/** + * A sorted candidate list for greedy beam search. + */ +struct DiskannCandidateList { + struct Vec0DiskannCandidate *items; + int count; + int capacity; +}; + +static int diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity) { + list->items = sqlite3_malloc(capacity * sizeof(struct Vec0DiskannCandidate)); + if (!list->items) return SQLITE_NOMEM; + list->count = 0; + list->capacity = capacity; + return SQLITE_OK; +} + +static void diskann_candidate_list_free(struct DiskannCandidateList *list) { + sqlite3_free(list->items); + list->items = NULL; + list->count = 0; + list->capacity = 0; +} + +/** + * Insert a candidate into the sorted list, maintaining sort order by distance. + * Deduplicates by rowid. If at capacity and new candidate is worse, discards it. + * Returns 1 if inserted, 0 if discarded. + */ +static int diskann_candidate_list_insert( + struct DiskannCandidateList *list, i64 rowid, f32 distance) { + + // Check for duplicate + for (int i = 0; i < list->count; i++) { + if (list->items[i].rowid == rowid) { + // Update distance if better + if (distance < list->items[i].distance) { + list->items[i].distance = distance; + // Re-sort this item into position + struct Vec0DiskannCandidate tmp = list->items[i]; + int j = i - 1; + while (j >= 0 && list->items[j].distance > tmp.distance) { + list->items[j + 1] = list->items[j]; + j--; + } + list->items[j + 1] = tmp; + } + return 1; + } + } + + // If at capacity, check if new candidate is better than worst + if (list->count >= list->capacity) { + if (distance >= list->items[list->count - 1].distance) { + return 0; // Discard + } + list->count--; // Make room by dropping the worst + } + + // Binary search for insertion point + int lo = 0, hi = list->count; + while (lo < hi) { + int mid = (lo + hi) / 2; + if (list->items[mid].distance < distance) { + lo = mid + 1; + } else { + hi = mid; + } + } + + // Shift elements to make room + memmove(&list->items[lo + 1], &list->items[lo], + (list->count - lo) * sizeof(struct Vec0DiskannCandidate)); + + list->items[lo].rowid = rowid; + list->items[lo].distance = distance; + list->items[lo].visited = 0; + list->count++; + return 1; +} + +/** + * Find the closest unvisited candidate. Returns its index, or -1 if none. + */ +static int diskann_candidate_list_next_unvisited( + const struct DiskannCandidateList *list) { + for (int i = 0; i < list->count; i++) { + if (!list->items[i].visited) return i; + } + return -1; +} + + + +/** + * Simple hash set for tracking visited rowids during search. + * Uses open addressing with linear probing. + */ +struct DiskannVisitedSet { + i64 *slots; + int capacity; + int count; +}; + +static int diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity) { + // Round up to power of 2 + int cap = 16; + while (cap < capacity) cap *= 2; + set->slots = sqlite3_malloc(cap * sizeof(i64)); + if (!set->slots) return SQLITE_NOMEM; + memset(set->slots, 0, cap * sizeof(i64)); + set->capacity = cap; + set->count = 0; + return SQLITE_OK; +} + +static void diskann_visited_set_free(struct DiskannVisitedSet *set) { + sqlite3_free(set->slots); + set->slots = NULL; + set->capacity = 0; + set->count = 0; +} + +static int diskann_visited_set_contains(const struct DiskannVisitedSet *set, i64 rowid) { + if (rowid == 0) return 0; // 0 is our sentinel for empty + int mask = set->capacity - 1; + int idx = (int)(((u64)rowid * 0x9E3779B97F4A7C15ULL) >> 32) & mask; + for (int i = 0; i < set->capacity; i++) { + int slot = (idx + i) & mask; + if (set->slots[slot] == 0) return 0; + if (set->slots[slot] == rowid) return 1; + } + return 0; +} + +static int diskann_visited_set_insert(struct DiskannVisitedSet *set, i64 rowid) { + if (rowid == 0) return 0; + int mask = set->capacity - 1; + int idx = (int)(((u64)rowid * 0x9E3779B97F4A7C15ULL) >> 32) & mask; + for (int i = 0; i < set->capacity; i++) { + int slot = (idx + i) & mask; + if (set->slots[slot] == 0) { + set->slots[slot] = rowid; + set->count++; + return 1; + } + if (set->slots[slot] == rowid) return 0; // Already there + } + return 0; // Full (shouldn't happen with proper sizing) +} + +// ============================================================ +// DiskANN greedy beam search (LM-Search) +// ============================================================ + +/** + * Perform LM-Search: greedy beam search over the DiskANN graph. + * Follows Algorithm 1 from the LM-DiskANN paper. + */ +static int diskann_search( + vec0_vtab *p, int vec_col_idx, + const void *queryVector, size_t dimensions, + enum VectorElementType elementType, + int k, int searchListSize, + i64 *outRowids, f32 *outDistances, int *outCount) { + + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + struct Vec0DiskannConfig *cfg = &col->diskann; + int rc; + + if (searchListSize <= 0) { + searchListSize = cfg->search_list_size_search > 0 ? cfg->search_list_size_search : cfg->search_list_size; + } + if (searchListSize < k) { + searchListSize = k; + } + + // 1. Get the medoid (entry point) + i64 medoid; + int isEmpty; + rc = diskann_medoid_get(p, vec_col_idx, &medoid, &isEmpty); + if (rc != SQLITE_OK) return rc; + if (isEmpty) { + *outCount = 0; + return SQLITE_OK; + } + + // 2. Compute distance from query to medoid using full-precision vector + void *medoidVector = NULL; + int medoidVectorSize; + rc = diskann_vector_read(p, vec_col_idx, medoid, &medoidVector, &medoidVectorSize); + if (rc != SQLITE_OK) return rc; + + f32 medoidDist = vec0_distance_full(queryVector, medoidVector, + dimensions, elementType, + col->distance_metric); + sqlite3_free(medoidVector); + + // 3. Initialize candidate list and visited set + struct DiskannCandidateList candidates; + rc = diskann_candidate_list_init(&candidates, searchListSize); + if (rc != SQLITE_OK) return rc; + + struct DiskannVisitedSet visited; + rc = diskann_visited_set_init(&visited, searchListSize * 4); + if (rc != SQLITE_OK) { + diskann_candidate_list_free(&candidates); + return rc; + } + + // Seed with medoid + diskann_candidate_list_insert(&candidates, medoid, medoidDist); + + // Pre-quantize query vector once for all quantized distance comparisons + u8 *queryQuantized = NULL; + if (elementType == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + queryQuantized = diskann_quantize_query( + (const f32 *)queryVector, dimensions, cfg->quantizer_type); + } + + // 4. Greedy beam search loop (Algorithm 1 from LM-DiskANN paper) + while (1) { + int nextIdx = diskann_candidate_list_next_unvisited(&candidates); + if (nextIdx < 0) break; + + struct Vec0DiskannCandidate *current = &candidates.items[nextIdx]; + current->visited = 1; + i64 currentRowid = current->rowid; + + // Read the node's neighbor data + u8 *validity = NULL, *neighborIds = NULL, *qvecs = NULL; + int validitySize, neighborIdsSize, qvecsSize; + rc = diskann_node_read(p, vec_col_idx, currentRowid, + &validity, &validitySize, + &neighborIds, &neighborIdsSize, + &qvecs, &qvecsSize); + if (rc != SQLITE_OK) { + continue; // Skip if node doesn't exist + } + + // Insert all valid neighbors with approximate (quantized) distances + for (int i = 0; i < cfg->n_neighbors; i++) { + if (!diskann_validity_get(validity, i)) continue; + + i64 neighborRowid = diskann_neighbor_id_get(neighborIds, i); + + if (diskann_visited_set_contains(&visited, neighborRowid)) continue; + + const u8 *neighborQvec = diskann_neighbor_qvec_get( + qvecs, i, cfg->quantizer_type, dimensions); + + f32 approxDist; + if (queryQuantized) { + approxDist = diskann_distance_quantized_precomputed( + queryQuantized, neighborQvec, dimensions, + cfg->quantizer_type, col->distance_metric); + } else { + approxDist = diskann_distance_quantized( + queryVector, neighborQvec, dimensions, + cfg->quantizer_type, col->distance_metric); + } + + diskann_candidate_list_insert(&candidates, neighborRowid, approxDist); + } + + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + + // Add to visited set + diskann_visited_set_insert(&visited, currentRowid); + + // Paper line 13: Re-rank p* using full-precision distance + // We already have exact distance for medoid; for others, update now + void *fullVec = NULL; + int fullVecSize; + rc = diskann_vector_read(p, vec_col_idx, currentRowid, &fullVec, &fullVecSize); + if (rc == SQLITE_OK) { + f32 exactDist = vec0_distance_full(queryVector, fullVec, + dimensions, elementType, + col->distance_metric); + sqlite3_free(fullVec); + // Update distance in candidate list and re-sort + diskann_candidate_list_insert(&candidates, currentRowid, exactDist); + } + } + + // 5. Output results (candidates are already sorted by distance) + int resultCount = (candidates.count < k) ? candidates.count : k; + *outCount = resultCount; + for (int i = 0; i < resultCount; i++) { + outRowids[i] = candidates.items[i].rowid; + outDistances[i] = candidates.items[i].distance; + } + + sqlite3_free(queryQuantized); + diskann_candidate_list_free(&candidates); + diskann_visited_set_free(&visited); + return SQLITE_OK; +} + +// ============================================================ +// DiskANN RobustPrune (Algorithm 4 from LM-DiskANN paper) +// ============================================================ + +/** + * RobustPrune: Select up to max_neighbors neighbors for node p from a + * candidate set, using alpha-pruning for diversity. + * + * Following Algorithm 4 (LM-Prune): + * C = C union N_out(p) \ {p} + * N_out(p) = empty + * while C not empty: + * p* = argmin d(p, c) for c in C + * N_out(p).insert(p*) + * if |N_out(p)| >= R: break + * for each p' in C: + * if alpha * d(p*, p') <= d(p, p'): remove p' from C + */ +/** + * Pure function: given pre-sorted candidates and a distance matrix, select + * up to max_neighbors using alpha-pruning. inter_distances is a flattened + * num_candidates x num_candidates matrix where inter_distances[i*num_candidates+j] + * = d(candidate_i, candidate_j). p_distances[i] = d(p, candidate_i), already sorted. + * outSelected[i] = 1 if selected. Returns count of selected. + */ +int diskann_prune_select( + const f32 *inter_distances, const f32 *p_distances, + int num_candidates, f32 alpha, int max_neighbors, + int *outSelected, int *outCount) { + + if (num_candidates == 0) { + *outCount = 0; + return SQLITE_OK; + } + + u8 *active = sqlite3_malloc(num_candidates); + if (!active) return SQLITE_NOMEM; + memset(active, 1, num_candidates); + memset(outSelected, 0, num_candidates * sizeof(int)); + + int selectedCount = 0; + + for (int round = 0; round < num_candidates && selectedCount < max_neighbors; round++) { + int bestIdx = -1; + for (int i = 0; i < num_candidates; i++) { + if (active[i]) { bestIdx = i; break; } + } + if (bestIdx < 0) break; + + outSelected[bestIdx] = 1; + selectedCount++; + active[bestIdx] = 0; + + for (int i = 0; i < num_candidates; i++) { + if (!active[i]) continue; + f32 dist_best_to_i = inter_distances[bestIdx * num_candidates + i]; + if (alpha * dist_best_to_i <= p_distances[i]) { + active[i] = 0; + } + } + } + + *outCount = selectedCount; + sqlite3_free(active); + return SQLITE_OK; +} + +static int diskann_robust_prune( + vec0_vtab *p, int vec_col_idx, + i64 p_rowid, const void *p_vector, + i64 *candidates, f32 *candidate_distances, int num_candidates, + f32 alpha, int max_neighbors, + i64 *outNeighborRowids, int *outNeighborCount) { + + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + int rc; + + // Remove p itself from candidates + for (int i = 0; i < num_candidates; i++) { + if (candidates[i] == p_rowid) { + candidates[i] = candidates[num_candidates - 1]; + candidate_distances[i] = candidate_distances[num_candidates - 1]; + num_candidates--; + break; + } + } + + if (num_candidates == 0) { + *outNeighborCount = 0; + return SQLITE_OK; + } + + // Sort candidates by distance to p (ascending) - insertion sort + for (int i = 1; i < num_candidates; i++) { + f32 tmpDist = candidate_distances[i]; + i64 tmpRowid = candidates[i]; + int j = i - 1; + while (j >= 0 && candidate_distances[j] > tmpDist) { + candidate_distances[j + 1] = candidate_distances[j]; + candidates[j + 1] = candidates[j]; + j--; + } + candidate_distances[j + 1] = tmpDist; + candidates[j + 1] = tmpRowid; + } + + // Active flags + u8 *active = sqlite3_malloc(num_candidates); + if (!active) return SQLITE_NOMEM; + memset(active, 1, num_candidates); + + // Cache full-precision vectors for inter-candidate distance + void **candidateVectors = sqlite3_malloc(num_candidates * sizeof(void *)); + if (!candidateVectors) { + sqlite3_free(active); + return SQLITE_NOMEM; + } + memset(candidateVectors, 0, num_candidates * sizeof(void *)); + + int selectedCount = 0; + + for (int round = 0; round < num_candidates && selectedCount < max_neighbors; round++) { + // Find closest active candidate + int bestIdx = -1; + for (int i = 0; i < num_candidates; i++) { + if (active[i]) { bestIdx = i; break; } + } + if (bestIdx < 0) break; + + // Select this candidate + outNeighborRowids[selectedCount] = candidates[bestIdx]; + selectedCount++; + active[bestIdx] = 0; + + // Load selected candidate's vector + if (!candidateVectors[bestIdx]) { + int vecSize; + rc = diskann_vector_read(p, vec_col_idx, candidates[bestIdx], + &candidateVectors[bestIdx], &vecSize); + if (rc != SQLITE_OK) continue; + } + + // Alpha-prune: remove candidates covered by the selected neighbor + for (int i = 0; i < num_candidates; i++) { + if (!active[i]) continue; + + if (!candidateVectors[i]) { + int vecSize; + rc = diskann_vector_read(p, vec_col_idx, candidates[i], + &candidateVectors[i], &vecSize); + if (rc != SQLITE_OK) continue; + } + + f32 dist_selected_to_i = vec0_distance_full( + candidateVectors[bestIdx], candidateVectors[i], + col->dimensions, col->element_type, col->distance_metric); + + if (alpha * dist_selected_to_i <= candidate_distances[i]) { + active[i] = 0; + } + } + } + + *outNeighborCount = selectedCount; + + for (int i = 0; i < num_candidates; i++) { + sqlite3_free(candidateVectors[i]); + } + sqlite3_free(candidateVectors); + sqlite3_free(active); + + return SQLITE_OK; +} + +/** + * After RobustPrune selects neighbors, build the node blobs and write to DB. + * Quantizes each neighbor's vector and packs into the node format. + */ +static int diskann_write_pruned_neighbors( + vec0_vtab *p, int vec_col_idx, i64 nodeRowid, + const i64 *neighborRowids, int neighborCount) { + + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + struct Vec0DiskannConfig *cfg = &col->diskann; + int rc; + + u8 *validity, *neighborIds, *qvecs; + int validitySize, neighborIdsSize, qvecsSize; + rc = diskann_node_init(cfg->n_neighbors, cfg->quantizer_type, + col->dimensions, + &validity, &validitySize, + &neighborIds, &neighborIdsSize, + &qvecs, &qvecsSize); + if (rc != SQLITE_OK) return rc; + + size_t qvecSize = diskann_quantized_vector_byte_size( + cfg->quantizer_type, col->dimensions); + u8 *qvec = sqlite3_malloc(qvecSize); + if (!qvec) { + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + return SQLITE_NOMEM; + } + + for (int i = 0; i < neighborCount && i < cfg->n_neighbors; i++) { + void *neighborVec = NULL; + int neighborVecSize; + rc = diskann_vector_read(p, vec_col_idx, neighborRowids[i], + &neighborVec, &neighborVecSize); + if (rc != SQLITE_OK) continue; + + if (col->element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + diskann_quantize_vector((const f32 *)neighborVec, col->dimensions, + cfg->quantizer_type, qvec); + } else { + memcpy(qvec, neighborVec, + qvecSize < (size_t)neighborVecSize ? qvecSize : (size_t)neighborVecSize); + } + + diskann_node_set_neighbor(validity, neighborIds, qvecs, i, + neighborRowids[i], qvec, + cfg->quantizer_type, col->dimensions); + + sqlite3_free(neighborVec); + } + sqlite3_free(qvec); + + rc = diskann_node_write(p, vec_col_idx, nodeRowid, + validity, validitySize, + neighborIds, neighborIdsSize, + qvecs, qvecsSize); + + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + return rc; +} + +// ============================================================ +// DiskANN insert (Algorithm 2 from LM-DiskANN paper) +// ============================================================ + +/** + * Add a reverse edge: make target_rowid a neighbor of node_rowid. + * If node is full, run RobustPrune. + */ +static int diskann_add_reverse_edge( + vec0_vtab *p, int vec_col_idx, + i64 node_rowid, i64 target_rowid, const void *target_vector) { + + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + struct Vec0DiskannConfig *cfg = &col->diskann; + int rc; + + u8 *validity = NULL, *neighborIds = NULL, *qvecs = NULL; + int validitySize, neighborIdsSize, qvecsSize; + rc = diskann_node_read(p, vec_col_idx, node_rowid, + &validity, &validitySize, + &neighborIds, &neighborIdsSize, + &qvecs, &qvecsSize); + if (rc != SQLITE_OK) return rc; + + int currentCount = diskann_validity_count(validity, cfg->n_neighbors); + + // Check if target is already a neighbor + for (int i = 0; i < cfg->n_neighbors; i++) { + if (diskann_validity_get(validity, i) && + diskann_neighbor_id_get(neighborIds, i) == target_rowid) { + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + return SQLITE_OK; + } + } + + if (currentCount < cfg->n_neighbors) { + // Room available: find first empty slot + for (int i = 0; i < cfg->n_neighbors; i++) { + if (!diskann_validity_get(validity, i)) { + size_t qvecSize = diskann_quantized_vector_byte_size( + cfg->quantizer_type, col->dimensions); + u8 *qvec = sqlite3_malloc(qvecSize); + if (!qvec) { + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + return SQLITE_NOMEM; + } + + if (col->element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + diskann_quantize_vector((const f32 *)target_vector, col->dimensions, + cfg->quantizer_type, qvec); + } else { + size_t vbs = vector_column_byte_size(*col); + memcpy(qvec, target_vector, qvecSize < vbs ? qvecSize : vbs); + } + + diskann_node_set_neighbor(validity, neighborIds, qvecs, i, + target_rowid, qvec, + cfg->quantizer_type, col->dimensions); + sqlite3_free(qvec); + break; + } + } + + rc = diskann_node_write(p, vec_col_idx, node_rowid, + validity, validitySize, + neighborIds, neighborIdsSize, + qvecs, qvecsSize); + } else { + // Full: lazy replacement — use quantized distances to find the worst + // existing neighbor and replace it if target is closer. This avoids + // reading all neighbors' float vectors (the expensive RobustPrune path). + + // Quantize the node's vector and the target vector for comparison + void *nodeVector = NULL; + int nodeVecSize; + rc = diskann_vector_read(p, vec_col_idx, node_rowid, + &nodeVector, &nodeVecSize); + if (rc != SQLITE_OK) { + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + return rc; + } + + // Quantize target for node-level comparison + size_t qvecSize = diskann_quantized_vector_byte_size( + cfg->quantizer_type, col->dimensions); + u8 *targetQ = sqlite3_malloc(qvecSize); + u8 *nodeQ = sqlite3_malloc(qvecSize); + if (!targetQ || !nodeQ) { + sqlite3_free(targetQ); + sqlite3_free(nodeQ); + sqlite3_free(nodeVector); + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + return SQLITE_NOMEM; + } + + if (col->element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + diskann_quantize_vector((const f32 *)target_vector, col->dimensions, + cfg->quantizer_type, targetQ); + diskann_quantize_vector((const f32 *)nodeVector, col->dimensions, + cfg->quantizer_type, nodeQ); + } else { + memcpy(targetQ, target_vector, qvecSize); + memcpy(nodeQ, nodeVector, qvecSize); + } + + // Compute quantized distance from node to target + f32 targetDist = diskann_distance_quantized_precomputed( + nodeQ, targetQ, col->dimensions, + cfg->quantizer_type, col->distance_metric); + + // Find the worst (farthest) existing neighbor using quantized distances + int worstIdx = -1; + f32 worstDist = -1.0f; + for (int i = 0; i < cfg->n_neighbors; i++) { + if (!diskann_validity_get(validity, i)) continue; + const u8 *nqvec = diskann_neighbor_qvec_get( + qvecs, i, cfg->quantizer_type, col->dimensions); + f32 d = diskann_distance_quantized_precomputed( + nodeQ, nqvec, col->dimensions, + cfg->quantizer_type, col->distance_metric); + if (d > worstDist) { + worstDist = d; + worstIdx = i; + } + } + + // Replace worst neighbor if target is closer + if (worstIdx >= 0 && targetDist < worstDist) { + diskann_node_set_neighbor(validity, neighborIds, qvecs, worstIdx, + target_rowid, targetQ, + cfg->quantizer_type, col->dimensions); + rc = diskann_node_write(p, vec_col_idx, node_rowid, + validity, validitySize, + neighborIds, neighborIdsSize, + qvecs, qvecsSize); + } else { + rc = SQLITE_OK; // target is farther than all existing neighbors, skip + } + + sqlite3_free(targetQ); + sqlite3_free(nodeQ); + sqlite3_free(nodeVector); + } + + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + return rc; +} + +// ============================================================ +// DiskANN buffer helpers (for batched inserts) +// ============================================================ + +/** + * Insert a vector into the _diskann_buffer table. + */ +static int diskann_buffer_write(vec0_vtab *p, int vec_col_idx, + i64 rowid, const void *vector, int vectorSize) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_DISKANN_BUFFER_N_NAME + " (rowid, vector) VALUES (?, ?)", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + sqlite3_bind_int64(stmt, 1, rowid); + sqlite3_bind_blob(stmt, 2, vector, vectorSize, SQLITE_STATIC); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR; +} + +/** + * Delete a vector from the _diskann_buffer table. + */ +static int diskann_buffer_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "DELETE FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR; +} + +/** + * Check if a rowid exists in the _diskann_buffer table. + * Returns SQLITE_OK and sets *exists to 1 if found, 0 if not. + */ +static int diskann_buffer_exists(vec0_vtab *p, int vec_col_idx, + i64 rowid, int *exists) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "SELECT 1 FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + *exists = (rc == SQLITE_ROW) ? 1 : 0; + sqlite3_finalize(stmt); + return SQLITE_OK; +} + +/** + * Get the count of rows in the _diskann_buffer table. + */ +static int diskann_buffer_count(vec0_vtab *p, int vec_col_idx, i64 *count) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "SELECT count(*) FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME, + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + rc = sqlite3_step(stmt); + if (rc == SQLITE_ROW) { + *count = sqlite3_column_int64(stmt, 0); + sqlite3_finalize(stmt); + return SQLITE_OK; + } + sqlite3_finalize(stmt); + return SQLITE_ERROR; +} + +// Forward declaration: diskann_insert_graph does the actual graph insertion +static int diskann_insert_graph(vec0_vtab *p, int vec_col_idx, + i64 rowid, const void *vector); + +/** + * Flush all buffered vectors into the DiskANN graph. + * Iterates over _diskann_buffer rows and calls diskann_insert_graph for each. + */ +static int diskann_flush_buffer(vec0_vtab *p, int vec_col_idx) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "SELECT rowid, vector FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME, + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + + while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { + i64 rowid = sqlite3_column_int64(stmt, 0); + const void *vector = sqlite3_column_blob(stmt, 1); + // Note: vector is already written to _vectors table, so + // diskann_insert_graph will skip re-writing it (vector already exists). + // We call the graph-only insert path. + int insertRc = diskann_insert_graph(p, vec_col_idx, rowid, vector); + if (insertRc != SQLITE_OK) { + sqlite3_finalize(stmt); + return insertRc; + } + } + sqlite3_finalize(stmt); + + // Clear the buffer + zSql = sqlite3_mprintf( + "DELETE FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME, + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR; +} + +/** + * Insert a new vector into the DiskANN graph (graph-only path). + * The vector must already be written to _vectors table. + * This is the core graph insertion logic (Algorithm 2: LM-Insert). + */ +static int diskann_insert_graph(vec0_vtab *p, int vec_col_idx, + i64 rowid, const void *vector) { + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + struct Vec0DiskannConfig *cfg = &col->diskann; + int rc; + + // Handle first insert (empty graph) + i64 medoid; + int isEmpty; + rc = diskann_medoid_get(p, vec_col_idx, &medoid, &isEmpty); + if (rc != SQLITE_OK) return rc; + + if (isEmpty) { + u8 *validity, *neighborIds, *qvecs; + int validitySize, neighborIdsSize, qvecsSize; + rc = diskann_node_init(cfg->n_neighbors, cfg->quantizer_type, + col->dimensions, + &validity, &validitySize, + &neighborIds, &neighborIdsSize, + &qvecs, &qvecsSize); + if (rc != SQLITE_OK) return rc; + + rc = diskann_node_write(p, vec_col_idx, rowid, + validity, validitySize, + neighborIds, neighborIdsSize, + qvecs, qvecsSize); + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + if (rc != SQLITE_OK) return rc; + + return diskann_medoid_set(p, vec_col_idx, rowid, 0); + } + + // Search for nearest neighbors + int L = cfg->search_list_size_insert > 0 ? cfg->search_list_size_insert : cfg->search_list_size; + i64 *searchRowids = sqlite3_malloc(L * sizeof(i64)); + f32 *searchDistances = sqlite3_malloc(L * sizeof(f32)); + if (!searchRowids || !searchDistances) { + sqlite3_free(searchRowids); + sqlite3_free(searchDistances); + return SQLITE_NOMEM; + } + + int searchCount; + rc = diskann_search(p, vec_col_idx, vector, col->dimensions, + col->element_type, L, L, + searchRowids, searchDistances, &searchCount); + if (rc != SQLITE_OK) { + sqlite3_free(searchRowids); + sqlite3_free(searchDistances); + return rc; + } + + // RobustPrune to select neighbors for x + i64 *selectedNeighbors = sqlite3_malloc(cfg->n_neighbors * sizeof(i64)); + int selectedCount = 0; + if (!selectedNeighbors) { + sqlite3_free(searchRowids); + sqlite3_free(searchDistances); + return SQLITE_NOMEM; + } + + rc = diskann_robust_prune(p, vec_col_idx, rowid, vector, + searchRowids, searchDistances, searchCount, + cfg->alpha, cfg->n_neighbors, + selectedNeighbors, &selectedCount); + sqlite3_free(searchRowids); + sqlite3_free(searchDistances); + if (rc != SQLITE_OK) { + sqlite3_free(selectedNeighbors); + return rc; + } + + // Write x's node with selected neighbors + rc = diskann_write_pruned_neighbors(p, vec_col_idx, rowid, + selectedNeighbors, selectedCount); + if (rc != SQLITE_OK) { + sqlite3_free(selectedNeighbors); + return rc; + } + + // Add bidirectional edges + for (int i = 0; i < selectedCount; i++) { + diskann_add_reverse_edge(p, vec_col_idx, + selectedNeighbors[i], rowid, vector); + } + + sqlite3_free(selectedNeighbors); + return SQLITE_OK; +} + +/** + * Insert a new vector into the DiskANN index (Algorithm 2: LM-Insert). + * When buffer_threshold > 0, vectors are buffered and flushed in batch. + */ +static int diskann_insert(vec0_vtab *p, int vec_col_idx, + i64 rowid, const void *vector) { + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + struct Vec0DiskannConfig *cfg = &col->diskann; + int rc; + size_t vectorSize = vector_column_byte_size(*col); + + // 1. Write full-precision vector to _vectors table (always needed for queries) + rc = diskann_vector_write(p, vec_col_idx, rowid, vector, (int)vectorSize); + if (rc != SQLITE_OK) return rc; + + // 2. If buffering is enabled, write to buffer instead of graph + if (cfg->buffer_threshold > 0) { + rc = diskann_buffer_write(p, vec_col_idx, rowid, vector, (int)vectorSize); + if (rc != SQLITE_OK) return rc; + + i64 count; + rc = diskann_buffer_count(p, vec_col_idx, &count); + if (rc != SQLITE_OK) return rc; + + if (count >= cfg->buffer_threshold) { + return diskann_flush_buffer(p, vec_col_idx); + } + return SQLITE_OK; + } + + // 3. Legacy per-row insert directly into graph + return diskann_insert_graph(p, vec_col_idx, rowid, vector); +} + +/** + * Returns 1 if ALL vector columns in this table are DiskANN-indexed. + */ +// ============================================================ +// DiskANN delete (Algorithm 3 from LM-DiskANN paper) +// ============================================================ + +static int diskann_node_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "DELETE FROM " VEC0_SHADOW_DISKANN_NODES_N_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR; +} + +static int diskann_vector_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) { + sqlite3_stmt *stmt = NULL; + char *zSql = sqlite3_mprintf( + "DELETE FROM " VEC0_SHADOW_VECTORS_N_NAME " WHERE rowid = ?", + p->schemaName, p->tableName, vec_col_idx); + if (!zSql) return SQLITE_NOMEM; + int rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, NULL); + sqlite3_free(zSql); + if (rc != SQLITE_OK) return rc; + sqlite3_bind_int64(stmt, 1, rowid); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + return (rc == SQLITE_DONE) ? SQLITE_OK : SQLITE_ERROR; +} + +/** + * Repair graph after deleting a node. Following Algorithm 3 (LM-Delete): + * For each neighbor n of the deleted node, add deleted node's other neighbors + * to n's candidate set, then remove the deleted node from n's neighbor list. + * Uses simple slot replacement rather than full RobustPrune for performance. + */ +static int diskann_repair_reverse_edges( + vec0_vtab *p, int vec_col_idx, i64 deleted_rowid, + const i64 *deleted_neighbors, int deleted_neighbor_count) { + + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + struct Vec0DiskannConfig *cfg = &col->diskann; + int rc; + + // For each neighbor of the deleted node, fix their neighbor list + for (int dn = 0; dn < deleted_neighbor_count; dn++) { + i64 nodeRowid = deleted_neighbors[dn]; + + u8 *validity = NULL, *neighborIds = NULL, *qvecs = NULL; + int vs, nis, qs; + rc = diskann_node_read(p, vec_col_idx, nodeRowid, + &validity, &vs, &neighborIds, &nis, &qvecs, &qs); + if (rc != SQLITE_OK) continue; + + // Find and clear the deleted node's slot + int clearedSlot = -1; + for (int i = 0; i < cfg->n_neighbors; i++) { + if (diskann_validity_get(validity, i) && + diskann_neighbor_id_get(neighborIds, i) == deleted_rowid) { + diskann_node_clear_neighbor(validity, neighborIds, qvecs, i, + cfg->quantizer_type, col->dimensions); + clearedSlot = i; + break; + } + } + + if (clearedSlot >= 0) { + // Try to fill the cleared slot with one of the deleted node's other neighbors + for (int di = 0; di < deleted_neighbor_count; di++) { + i64 candidate = deleted_neighbors[di]; + if (candidate == nodeRowid || candidate == deleted_rowid) continue; + + // Check not already a neighbor + int alreadyNeighbor = 0; + for (int ni = 0; ni < cfg->n_neighbors; ni++) { + if (diskann_validity_get(validity, ni) && + diskann_neighbor_id_get(neighborIds, ni) == candidate) { + alreadyNeighbor = 1; + break; + } + } + if (alreadyNeighbor) continue; + + // Load, quantize, and set + void *candidateVec = NULL; + int cvs; + rc = diskann_vector_read(p, vec_col_idx, candidate, &candidateVec, &cvs); + if (rc != SQLITE_OK) continue; + + size_t qvecSize = diskann_quantized_vector_byte_size( + cfg->quantizer_type, col->dimensions); + u8 *qvec = sqlite3_malloc(qvecSize); + if (qvec) { + if (col->element_type == SQLITE_VEC_ELEMENT_TYPE_FLOAT32) { + diskann_quantize_vector((const f32 *)candidateVec, col->dimensions, + cfg->quantizer_type, qvec); + } else { + memcpy(qvec, candidateVec, + qvecSize < (size_t)cvs ? qvecSize : (size_t)cvs); + } + diskann_node_set_neighbor(validity, neighborIds, qvecs, clearedSlot, + candidate, qvec, + cfg->quantizer_type, col->dimensions); + sqlite3_free(qvec); + } + sqlite3_free(candidateVec); + break; + } + + diskann_node_write(p, vec_col_idx, nodeRowid, + validity, vs, neighborIds, nis, qvecs, qs); + } + + sqlite3_free(validity); + sqlite3_free(neighborIds); + sqlite3_free(qvecs); + } + + return SQLITE_OK; +} + +/** + * Delete a vector from the DiskANN graph (Algorithm 3: LM-Delete). + * If the vector is in the buffer (not yet flushed), just remove from buffer. + */ +static int diskann_delete(vec0_vtab *p, int vec_col_idx, i64 rowid) { + struct VectorColumnDefinition *col = &p->vector_columns[vec_col_idx]; + struct Vec0DiskannConfig *cfg = &col->diskann; + int rc; + + // Check if this rowid is in the buffer (not yet in graph) + if (cfg->buffer_threshold > 0) { + int inBuffer = 0; + rc = diskann_buffer_exists(p, vec_col_idx, rowid, &inBuffer); + if (rc != SQLITE_OK) return rc; + if (inBuffer) { + // Just remove from buffer and _vectors, no graph repair needed + rc = diskann_buffer_delete(p, vec_col_idx, rowid); + if (rc == SQLITE_OK) { + rc = diskann_vector_delete(p, vec_col_idx, rowid); + } + return rc; + } + } + + // 1. Read the node to get its neighbor list + u8 *delValidity = NULL, *delNeighborIds = NULL, *delQvecs = NULL; + int dvs, dnis, dqs; + rc = diskann_node_read(p, vec_col_idx, rowid, + &delValidity, &dvs, &delNeighborIds, &dnis, + &delQvecs, &dqs); + if (rc != SQLITE_OK) { + return SQLITE_OK; // Node doesn't exist, nothing to do + } + + i64 *deletedNeighbors = sqlite3_malloc(cfg->n_neighbors * sizeof(i64)); + int deletedNeighborCount = 0; + if (!deletedNeighbors) { + sqlite3_free(delValidity); + sqlite3_free(delNeighborIds); + sqlite3_free(delQvecs); + return SQLITE_NOMEM; + } + + for (int i = 0; i < cfg->n_neighbors; i++) { + if (diskann_validity_get(delValidity, i)) { + deletedNeighbors[deletedNeighborCount++] = + diskann_neighbor_id_get(delNeighborIds, i); + } + } + + sqlite3_free(delValidity); + sqlite3_free(delNeighborIds); + sqlite3_free(delQvecs); + + // 2. Repair reverse edges + rc = diskann_repair_reverse_edges(p, vec_col_idx, rowid, + deletedNeighbors, deletedNeighborCount); + sqlite3_free(deletedNeighbors); + + // 3. Delete node and vector + if (rc == SQLITE_OK) { + rc = diskann_node_delete(p, vec_col_idx, rowid); + } + if (rc == SQLITE_OK) { + rc = diskann_vector_delete(p, vec_col_idx, rowid); + } + + // 4. Handle medoid deletion + if (rc == SQLITE_OK) { + rc = diskann_medoid_handle_delete(p, vec_col_idx, rowid); + } + + return rc; +} + +static int vec0_all_columns_diskann(vec0_vtab *p) { + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) return 0; + } + return p->numVectorColumns > 0; +} + +// ============================================================================ +// Command dispatch +// ============================================================================ + +static int diskann_handle_command(vec0_vtab *p, const char *command) { + int col_idx = -1; + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) { col_idx = i; break; } + } + if (col_idx < 0) return SQLITE_EMPTY; + + struct Vec0DiskannConfig *cfg = &p->vector_columns[col_idx].diskann; + + if (strncmp(command, "search_list_size_search=", 24) == 0) { + int val = atoi(command + 24); + if (val < 1) { vtab_set_error(&p->base, "search_list_size_search must be >= 1"); return SQLITE_ERROR; } + cfg->search_list_size_search = val; + return SQLITE_OK; + } + if (strncmp(command, "search_list_size_insert=", 24) == 0) { + int val = atoi(command + 24); + if (val < 1) { vtab_set_error(&p->base, "search_list_size_insert must be >= 1"); return SQLITE_ERROR; } + cfg->search_list_size_insert = val; + return SQLITE_OK; + } + if (strncmp(command, "search_list_size=", 17) == 0) { + int val = atoi(command + 17); + if (val < 1) { vtab_set_error(&p->base, "search_list_size must be >= 1"); return SQLITE_ERROR; } + cfg->search_list_size = val; + return SQLITE_OK; + } + return SQLITE_EMPTY; +} + +#ifdef SQLITE_VEC_TEST +// Expose internal DiskANN data structures and functions for unit testing. + +int _test_diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity) { + return diskann_candidate_list_init(list, capacity); +} +void _test_diskann_candidate_list_free(struct DiskannCandidateList *list) { + diskann_candidate_list_free(list); +} +int _test_diskann_candidate_list_insert(struct DiskannCandidateList *list, long long rowid, float distance) { + return diskann_candidate_list_insert(list, (i64)rowid, (f32)distance); +} +int _test_diskann_candidate_list_next_unvisited(const struct DiskannCandidateList *list) { + return diskann_candidate_list_next_unvisited(list); +} +int _test_diskann_candidate_list_count(const struct DiskannCandidateList *list) { + return list->count; +} +long long _test_diskann_candidate_list_rowid(const struct DiskannCandidateList *list, int i) { + return (long long)list->items[i].rowid; +} +float _test_diskann_candidate_list_distance(const struct DiskannCandidateList *list, int i) { + return (float)list->items[i].distance; +} +void _test_diskann_candidate_list_set_visited(struct DiskannCandidateList *list, int i) { + list->items[i].visited = 1; +} + +int _test_diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity) { + return diskann_visited_set_init(set, capacity); +} +void _test_diskann_visited_set_free(struct DiskannVisitedSet *set) { + diskann_visited_set_free(set); +} +int _test_diskann_visited_set_contains(const struct DiskannVisitedSet *set, long long rowid) { + return diskann_visited_set_contains(set, (i64)rowid); +} +int _test_diskann_visited_set_insert(struct DiskannVisitedSet *set, long long rowid) { + return diskann_visited_set_insert(set, (i64)rowid); +} +#endif /* SQLITE_VEC_TEST */ + diff --git a/sqlite-vec-rescore.c b/sqlite-vec-rescore.c index a45f52f..ef4e692 100644 --- a/sqlite-vec-rescore.c +++ b/sqlite-vec-rescore.c @@ -156,21 +156,11 @@ static void rescore_quantize_float_to_bit(const float *src, uint8_t *dst, static void rescore_quantize_float_to_int8(const float *src, int8_t *dst, size_t dimensions) { - float vmin = src[0], vmax = src[0]; - for (size_t i = 1; i < dimensions; i++) { - if (src[i] < vmin) vmin = src[i]; - if (src[i] > vmax) vmax = src[i]; - } - float range = vmax - vmin; - if (range == 0.0f) { - memset(dst, 0, dimensions); - return; - } - float scale = 255.0f / range; + float step = 2.0f / 255.0f; for (size_t i = 0; i < dimensions; i++) { - float v = (src[i] - vmin) * scale - 128.0f; - if (v < -128.0f) v = -128.0f; - if (v > 127.0f) v = 127.0f; + float v = (src[i] - (-1.0f)) / step - 128.0f; + if (!(v <= 127.0f)) v = 127.0f; + if (!(v >= -128.0f)) v = -128.0f; dst[i] = (int8_t)v; } } diff --git a/sqlite-vec.c b/sqlite-vec.c index 015792b..5ca7834 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -61,6 +61,10 @@ SQLITE_EXTENSION_INIT1 #define LONGDOUBLE_TYPE long double #endif +#ifndef SQLITE_VEC_ENABLE_DISKANN +#define SQLITE_VEC_ENABLE_DISKANN 1 +#endif + #ifndef _WIN32 #ifndef __EMSCRIPTEN__ #ifndef __COSMOPOLITAN__ @@ -2544,6 +2548,7 @@ enum Vec0IndexType { VEC0_INDEX_TYPE_RESCORE = 2, #endif VEC0_INDEX_TYPE_IVF = 3, + VEC0_INDEX_TYPE_DISKANN = 4, }; #if SQLITE_VEC_ENABLE_RESCORE @@ -2575,6 +2580,75 @@ struct Vec0IvfConfig { struct Vec0IvfConfig { char _unused; }; #endif +// ============================================================ +// DiskANN types and constants +// ============================================================ + +#define VEC0_DISKANN_DEFAULT_N_NEIGHBORS 72 +#define VEC0_DISKANN_MAX_N_NEIGHBORS 256 +#define VEC0_DISKANN_DEFAULT_SEARCH_LIST_SIZE 128 +#define VEC0_DISKANN_DEFAULT_ALPHA 1.2f + +/** + * Quantizer type used for compressing neighbor vectors in the DiskANN graph. + */ +enum Vec0DiskannQuantizerType { + VEC0_DISKANN_QUANTIZER_BINARY = 1, // 1 bit per dimension (1/32 compression) + VEC0_DISKANN_QUANTIZER_INT8 = 2, // 1 byte per dimension (1/4 compression) +}; + +/** + * Configuration for a DiskANN index on a single vector column. + * Parsed from `INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=72)`. + */ +struct Vec0DiskannConfig { + // Quantizer type for neighbor vectors + enum Vec0DiskannQuantizerType quantizer_type; + + // Maximum number of neighbors per node (R in the paper). Must be divisible by 8. + int n_neighbors; + + // Search list size (L in the paper) — unified default for both insert and query. + int search_list_size; + + // Per-path overrides (0 = fall back to search_list_size). + int search_list_size_search; + int search_list_size_insert; + + // Alpha parameter for RobustPrune (distance scaling factor, typically 1.0-1.5) + f32 alpha; + + // Buffer threshold for batched inserts. When > 0, inserts go into a flat + // buffer table and are flushed into the graph when the buffer reaches this + // size. 0 = disabled (legacy per-row insert behavior). + int buffer_threshold; +}; + +/** + * Represents a single candidate during greedy beam search. + * Used in priority queues / sorted arrays during LM-Search. + */ +struct Vec0DiskannCandidate { + i64 rowid; + f32 distance; + int visited; // 1 if this candidate's neighbors have been explored +}; + +/** + * Returns the byte size of a quantized vector for the given quantizer type + * and number of dimensions. + */ +size_t diskann_quantized_vector_byte_size( + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions) { + switch (quantizer_type) { + case VEC0_DISKANN_QUANTIZER_BINARY: + return dimensions / CHAR_BIT; // 1 bit per dimension + case VEC0_DISKANN_QUANTIZER_INT8: + return dimensions * sizeof(i8); // 1 byte per dimension + } + return 0; +} + struct VectorColumnDefinition { char *name; int name_length; @@ -2586,6 +2660,7 @@ struct VectorColumnDefinition { struct Vec0RescoreConfig rescore; #endif struct Vec0IvfConfig ivf; + struct Vec0DiskannConfig diskann; }; struct Vec0PartitionColumnDefinition { @@ -2743,6 +2818,126 @@ static int vec0_parse_ivf_options(struct Vec0Scanner *scanner, struct Vec0IvfConfig *config); #endif +/** + * Parse the options inside diskann(...) parentheses. + * Scanner should be positioned right before the '(' token. + * + * Recognized options: + * neighbor_quantizer = binary | int8 (required) + * n_neighbors = (optional, default 72) + * search_list_size = (optional, default 128) + */ +static int vec0_parse_diskann_options(struct Vec0Scanner *scanner, + struct Vec0DiskannConfig *config) { + int rc; + struct Vec0Token token; + int hasQuantizer = 0; + + // Set defaults + config->n_neighbors = VEC0_DISKANN_DEFAULT_N_NEIGHBORS; + config->search_list_size = VEC0_DISKANN_DEFAULT_SEARCH_LIST_SIZE; + config->search_list_size_search = 0; + config->search_list_size_insert = 0; + config->alpha = VEC0_DISKANN_DEFAULT_ALPHA; + config->buffer_threshold = 0; + int hasSearchListSize = 0; + int hasSearchListSizeSplit = 0; + + // Expect '(' + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_LPAREN) { + return SQLITE_ERROR; + } + + while (1) { + // key + rc = vec0_scanner_next(scanner, &token); + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) { + break; // empty parens or trailing comma + } + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_IDENTIFIER) { + return SQLITE_ERROR; + } + char *optKey = token.start; + int optKeyLen = token.end - token.start; + + // '=' + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_EQ) { + return SQLITE_ERROR; + } + + // value (identifier or digit) + rc = vec0_scanner_next(scanner, &token); + if (rc != VEC0_TOKEN_RESULT_SOME) { + return SQLITE_ERROR; + } + char *optVal = token.start; + int optValLen = token.end - token.start; + + if (sqlite3_strnicmp(optKey, "neighbor_quantizer", optKeyLen) == 0) { + if (sqlite3_strnicmp(optVal, "binary", optValLen) == 0) { + config->quantizer_type = VEC0_DISKANN_QUANTIZER_BINARY; + } else if (sqlite3_strnicmp(optVal, "int8", optValLen) == 0) { + config->quantizer_type = VEC0_DISKANN_QUANTIZER_INT8; + } else { + return SQLITE_ERROR; // unknown quantizer + } + hasQuantizer = 1; + } else if (sqlite3_strnicmp(optKey, "n_neighbors", optKeyLen) == 0) { + config->n_neighbors = atoi(optVal); + if (config->n_neighbors <= 0 || (config->n_neighbors % 8) != 0 || + config->n_neighbors > VEC0_DISKANN_MAX_N_NEIGHBORS) { + return SQLITE_ERROR; + } + } else if (sqlite3_strnicmp(optKey, "search_list_size_search", optKeyLen) == 0 && optKeyLen == 23) { + config->search_list_size_search = atoi(optVal); + if (config->search_list_size_search <= 0) { + return SQLITE_ERROR; + } + hasSearchListSizeSplit = 1; + } else if (sqlite3_strnicmp(optKey, "search_list_size_insert", optKeyLen) == 0 && optKeyLen == 23) { + config->search_list_size_insert = atoi(optVal); + if (config->search_list_size_insert <= 0) { + return SQLITE_ERROR; + } + hasSearchListSizeSplit = 1; + } else if (sqlite3_strnicmp(optKey, "search_list_size", optKeyLen) == 0) { + config->search_list_size = atoi(optVal); + if (config->search_list_size <= 0) { + return SQLITE_ERROR; + } + hasSearchListSize = 1; + } else if (sqlite3_strnicmp(optKey, "buffer_threshold", optKeyLen) == 0) { + config->buffer_threshold = atoi(optVal); + if (config->buffer_threshold < 0) { + return SQLITE_ERROR; + } + } else { + return SQLITE_ERROR; // unknown option + } + + // Expect ',' or ')' + rc = vec0_scanner_next(scanner, &token); + if (rc == VEC0_TOKEN_RESULT_SOME && token.token_type == TOKEN_TYPE_RPAREN) { + break; + } + if (rc != VEC0_TOKEN_RESULT_SOME || token.token_type != TOKEN_TYPE_COMMA) { + return SQLITE_ERROR; + } + } + + if (!hasQuantizer) { + return SQLITE_ERROR; // neighbor_quantizer is required + } + + if (hasSearchListSize && hasSearchListSizeSplit) { + return SQLITE_ERROR; // cannot mix search_list_size with search_list_size_search/insert + } + + return SQLITE_OK; +} + int vec0_parse_vector_column(const char *source, int source_length, struct VectorColumnDefinition *outColumn) { // parses a vector column definition like so: @@ -2763,8 +2958,9 @@ int vec0_parse_vector_column(const char *source, int source_length, #endif struct Vec0IvfConfig ivfConfig; memset(&ivfConfig, 0, sizeof(ivfConfig)); + struct Vec0DiskannConfig diskannConfig; + memset(&diskannConfig, 0, sizeof(diskannConfig)); int dimensions; - vec0_scanner_init(&scanner, source, source_length); // starts with an identifier @@ -2931,6 +3127,16 @@ int vec0_parse_vector_column(const char *source, int source_length, } #else return SQLITE_ERROR; // IVF not compiled in +#endif + } else if (sqlite3_strnicmp(token.start, "diskann", indexNameLen) == 0) { +#if SQLITE_VEC_ENABLE_DISKANN + indexType = VEC0_INDEX_TYPE_DISKANN; + rc = vec0_parse_diskann_options(&scanner, &diskannConfig); + if (rc != SQLITE_OK) { + return rc; + } +#else + return SQLITE_ERROR; #endif } else { // unknown index type @@ -2956,6 +3162,7 @@ int vec0_parse_vector_column(const char *source, int source_length, outColumn->rescore = rescoreConfig; #endif outColumn->ivf = ivfConfig; + outColumn->diskann = diskannConfig; return SQLITE_OK; } @@ -3154,6 +3361,7 @@ static sqlite3_module vec_eachModule = { #pragma endregion + #pragma region vec0 virtual table #define VEC0_COLUMN_ID 0 @@ -3214,6 +3422,9 @@ static sqlite3_module vec_eachModule = { #define VEC0_SHADOW_AUXILIARY_NAME "\"%w\".\"%w_auxiliary\"" #define VEC0_SHADOW_METADATA_N_NAME "\"%w\".\"%w_metadatachunks%02d\"" +#define VEC0_SHADOW_VECTORS_N_NAME "\"%w\".\"%w_vectors%02d\"" +#define VEC0_SHADOW_DISKANN_NODES_N_NAME "\"%w\".\"%w_diskann_nodes%02d\"" +#define VEC0_SHADOW_DISKANN_BUFFER_N_NAME "\"%w\".\"%w_diskann_buffer%02d\"" #define VEC0_SHADOW_METADATA_TEXT_DATA_NAME "\"%w\".\"%w_metadatatext%02d\"" #define VEC_INTERAL_ERROR "Internal sqlite-vec error: " @@ -3388,6 +3599,24 @@ struct vec0_vtab { * Must be cleaned up with sqlite3_finalize(). */ sqlite3_stmt *stmtRowidsGetChunkPosition; + + // === DiskANN additions === +#if SQLITE_VEC_ENABLE_DISKANN + // Shadow table names for DiskANN, per vector column + // e.g., "{schema}"."{table}_vectors{00..15}" + char *shadowVectorsNames[VEC0_MAX_VECTOR_COLUMNS]; + + // e.g., "{schema}"."{table}_diskann_nodes{00..15}" + char *shadowDiskannNodesNames[VEC0_MAX_VECTOR_COLUMNS]; + + // Prepared statements for DiskANN operations (per vector column) + // These will be lazily prepared on first use. + sqlite3_stmt *stmtDiskannNodeRead[VEC0_MAX_VECTOR_COLUMNS]; + sqlite3_stmt *stmtDiskannNodeWrite[VEC0_MAX_VECTOR_COLUMNS]; + sqlite3_stmt *stmtDiskannNodeInsert[VEC0_MAX_VECTOR_COLUMNS]; + sqlite3_stmt *stmtVectorsRead[VEC0_MAX_VECTOR_COLUMNS]; + sqlite3_stmt *stmtVectorsInsert[VEC0_MAX_VECTOR_COLUMNS]; +#endif }; #if SQLITE_VEC_ENABLE_RESCORE @@ -3427,6 +3656,13 @@ void vec0_free_resources(vec0_vtab *p) { sqlite3_finalize(p->stmtIvfRowidMapLookup[i]); p->stmtIvfRowidMapLookup[i] = NULL; sqlite3_finalize(p->stmtIvfRowidMapDelete[i]); p->stmtIvfRowidMapDelete[i] = NULL; sqlite3_finalize(p->stmtIvfCentroidsAll[i]); p->stmtIvfCentroidsAll[i] = NULL; +#if SQLITE_VEC_ENABLE_DISKANN + sqlite3_finalize(p->stmtDiskannNodeRead[i]); p->stmtDiskannNodeRead[i] = NULL; + sqlite3_finalize(p->stmtDiskannNodeWrite[i]); p->stmtDiskannNodeWrite[i] = NULL; + sqlite3_finalize(p->stmtDiskannNodeInsert[i]); p->stmtDiskannNodeInsert[i] = NULL; + sqlite3_finalize(p->stmtVectorsRead[i]); p->stmtVectorsRead[i] = NULL; + sqlite3_finalize(p->stmtVectorsInsert[i]); p->stmtVectorsInsert[i] = NULL; +#endif } #endif } @@ -3464,6 +3700,13 @@ void vec0_free(vec0_vtab *p) { p->shadowRescoreVectorsNames[i] = NULL; #endif +#if SQLITE_VEC_ENABLE_DISKANN + sqlite3_free(p->shadowVectorsNames[i]); + p->shadowVectorsNames[i] = NULL; + sqlite3_free(p->shadowDiskannNodesNames[i]); + p->shadowDiskannNodesNames[i] = NULL; +#endif + sqlite3_free(p->vector_columns[i].name); p->vector_columns[i].name = NULL; } @@ -3484,6 +3727,12 @@ void vec0_free(vec0_vtab *p) { } } +#if SQLITE_VEC_ENABLE_DISKANN +#include "sqlite-vec-diskann.c" +#else +static int vec0_all_columns_diskann(vec0_vtab *p) { (void)p; return 0; } +#endif + int vec0_num_defined_user_columns(vec0_vtab *p) { return p->numVectorColumns + p->numPartitionColumns + p->numAuxiliaryColumns + p->numMetadataColumns; } @@ -3753,6 +4002,25 @@ int vec0_get_vector_data(vec0_vtab *pVtab, i64 rowid, int vector_column_idx, void **outVector, int *outVectorSize) { vec0_vtab *p = pVtab; int rc, brc; + +#if SQLITE_VEC_ENABLE_DISKANN + // DiskANN fast path: read from _vectors table + if (p->vector_columns[vector_column_idx].index_type == VEC0_INDEX_TYPE_DISKANN) { + void *vec = NULL; + int vecSize; + rc = diskann_vector_read(p, vector_column_idx, rowid, &vec, &vecSize); + if (rc != SQLITE_OK) { + vtab_set_error(&pVtab->base, + "Could not fetch vector data for %lld from DiskANN vectors table", + rowid); + return SQLITE_ERROR; + } + *outVector = vec; + if (outVectorSize) *outVectorSize = vecSize; + return SQLITE_OK; + } +#endif + i64 chunk_id; i64 chunk_offset; @@ -4653,6 +4921,26 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, (i64)vecColumn.dimensions, SQLITE_VEC_VEC0_MAX_DIMENSIONS); goto error; } + + // DiskANN validation + if (vecColumn.index_type == VEC0_INDEX_TYPE_DISKANN) { + if (vecColumn.element_type == SQLITE_VEC_ELEMENT_TYPE_BIT) { + sqlite3_free(vecColumn.name); + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "DiskANN index is not supported on bit vector columns"); + goto error; + } + if (vecColumn.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY && + (vecColumn.dimensions % CHAR_BIT) != 0) { + sqlite3_free(vecColumn.name); + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "DiskANN with binary quantizer requires dimensions divisible by 8"); + goto error; + } + } + pNew->user_column_kinds[user_column_idx] = SQLITE_VEC0_USER_COLUMN_KIND_VECTOR; pNew->user_column_idxs[user_column_idx] = numVectorColumns; memcpy(&pNew->vector_columns[numVectorColumns], &vecColumn, sizeof(vecColumn)); @@ -4881,6 +5169,31 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } } + // DiskANN columns cannot coexist with aux/metadata/partition columns + for (int i = 0; i < numVectorColumns; i++) { + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) { + if (numAuxiliaryColumns > 0) { + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "Auxiliary columns are not supported with DiskANN-indexed vector columns"); + goto error; + } + if (numMetadataColumns > 0) { + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "Metadata columns are not supported with DiskANN-indexed vector columns"); + goto error; + } + if (numPartitionColumns > 0) { + *pzErr = sqlite3_mprintf( + VEC_CONSTRUCTOR_ERROR + "Partition key columns are not supported with DiskANN-indexed vector columns"); + goto error; + } + break; + } + } + sqlite3_str *createStr = sqlite3_str_new(NULL); sqlite3_str_appendall(createStr, "CREATE TABLE x("); if (pkColumnName) { @@ -4984,6 +5297,20 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, goto error; } } +#endif +#if SQLITE_VEC_ENABLE_DISKANN + if (pNew->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) { + pNew->shadowVectorsNames[i] = + sqlite3_mprintf("%s_vectors%02d", tableName, i); + if (!pNew->shadowVectorsNames[i]) { + goto error; + } + pNew->shadowDiskannNodesNames[i] = + sqlite3_mprintf("%s_diskann_nodes%02d", tableName, i); + if (!pNew->shadowDiskannNodesNames[i]) { + goto error; + } + } #endif } #if SQLITE_VEC_EXPERIMENTAL_IVF_ENABLE @@ -5060,7 +5387,32 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } sqlite3_finalize(stmt); - +#if SQLITE_VEC_ENABLE_DISKANN + // Seed medoid entries for DiskANN-indexed columns + for (int i = 0; i < pNew->numVectorColumns; i++) { + if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) { + continue; + } + char *key = sqlite3_mprintf("diskann_medoid_%02d", i); + char *zInsert = sqlite3_mprintf( + "INSERT INTO " VEC0_SHADOW_INFO_NAME "(key, value) VALUES (?1, ?2)", + pNew->schemaName, pNew->tableName); + rc = sqlite3_prepare_v2(db, zInsert, -1, &stmt, NULL); + sqlite3_free(zInsert); + if (rc != SQLITE_OK) { + sqlite3_free(key); + sqlite3_finalize(stmt); + goto error; + } + sqlite3_bind_text(stmt, 1, key, -1, sqlite3_free); + sqlite3_bind_null(stmt, 2); // NULL means empty graph + if (sqlite3_step(stmt) != SQLITE_DONE) { + sqlite3_finalize(stmt); + goto error; + } + sqlite3_finalize(stmt); + } +#endif // create the _chunks shadow table char *zCreateShadowChunks = NULL; @@ -5118,7 +5470,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, for (int i = 0; i < pNew->numVectorColumns; i++) { #if SQLITE_VEC_ENABLE_RESCORE - // Rescore and IVF columns don't use _vector_chunks + // Non-FLAT columns don't use _vector_chunks if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT) continue; #endif @@ -5159,6 +5511,84 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv, } #endif +#if SQLITE_VEC_ENABLE_DISKANN + // Create DiskANN shadow tables for indexed vector columns + for (int i = 0; i < pNew->numVectorColumns; i++) { + if (pNew->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) { + continue; + } + + // Create _vectors{NN} table + { + char *zSql = sqlite3_mprintf( + "CREATE TABLE " VEC0_SHADOW_VECTORS_N_NAME + " (rowid INTEGER PRIMARY KEY, vector BLOB NOT NULL);", + pNew->schemaName, pNew->tableName, i); + if (!zSql) { + goto error; + } + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + sqlite3_finalize(stmt); + *pzErr = sqlite3_mprintf( + "Could not create '_vectors%02d' shadow table: %s", i, + sqlite3_errmsg(db)); + goto error; + } + sqlite3_finalize(stmt); + } + + // Create _diskann_nodes{NN} table + { + char *zSql = sqlite3_mprintf( + "CREATE TABLE " VEC0_SHADOW_DISKANN_NODES_N_NAME " (" + "rowid INTEGER PRIMARY KEY, " + "neighbors_validity BLOB NOT NULL, " + "neighbor_ids BLOB NOT NULL, " + "neighbor_quantized_vectors BLOB NOT NULL" + ");", + pNew->schemaName, pNew->tableName, i); + if (!zSql) { + goto error; + } + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + sqlite3_finalize(stmt); + *pzErr = sqlite3_mprintf( + "Could not create '_diskann_nodes%02d' shadow table: %s", i, + sqlite3_errmsg(db)); + goto error; + } + sqlite3_finalize(stmt); + } + + // Create _diskann_buffer{NN} table (for batched inserts) + { + char *zSql = sqlite3_mprintf( + "CREATE TABLE " VEC0_SHADOW_DISKANN_BUFFER_N_NAME " (" + "rowid INTEGER PRIMARY KEY, " + "vector BLOB NOT NULL" + ");", + pNew->schemaName, pNew->tableName, i); + if (!zSql) { + goto error; + } + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, 0); + sqlite3_free(zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + sqlite3_finalize(stmt); + *pzErr = sqlite3_mprintf( + "Could not create '_diskann_buffer%02d' shadow table: %s", i, + sqlite3_errmsg(db)); + goto error; + } + sqlite3_finalize(stmt); + } + } +#endif + // See SHADOW_TABLE_ROWID_QUIRK in vec0_new_chunk() — same "rowid PRIMARY KEY" // without INTEGER type issue applies here. for (int i = 0; i < pNew->numMetadataColumns; i++) { @@ -5293,6 +5723,45 @@ static int vec0Destroy(sqlite3_vtab *pVtab) { sqlite3_finalize(stmt); for (int i = 0; i < p->numVectorColumns; i++) { +#if SQLITE_VEC_ENABLE_DISKANN + if (p->vector_columns[i].index_type == VEC0_INDEX_TYPE_DISKANN) { + // Drop DiskANN shadow tables + zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_VECTORS_N_NAME, + p->schemaName, p->tableName, i); + if (zSql) { + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free((void *)zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + rc = SQLITE_ERROR; + goto done; + } + sqlite3_finalize(stmt); + } + zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_DISKANN_NODES_N_NAME, + p->schemaName, p->tableName, i); + if (zSql) { + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free((void *)zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + rc = SQLITE_ERROR; + goto done; + } + sqlite3_finalize(stmt); + } + zSql = sqlite3_mprintf("DROP TABLE IF EXISTS " VEC0_SHADOW_DISKANN_BUFFER_N_NAME, + p->schemaName, p->tableName, i); + if (zSql) { + rc = sqlite3_prepare_v2(p->db, zSql, -1, &stmt, 0); + sqlite3_free((void *)zSql); + if ((rc != SQLITE_OK) || (sqlite3_step(stmt) != SQLITE_DONE)) { + rc = SQLITE_ERROR; + goto done; + } + sqlite3_finalize(stmt); + } + continue; + } +#endif #if SQLITE_VEC_ENABLE_RESCORE if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_FLAT) continue; @@ -7088,6 +7557,171 @@ cleanup: #include "sqlite-vec-rescore.c" #endif +#if SQLITE_VEC_ENABLE_DISKANN +/** + * Handle a KNN query using the DiskANN graph search. + */ +static int vec0Filter_knn_diskann( + vec0_cursor *pCur, vec0_vtab *p, int idxNum, + const char *idxStr, int argc, sqlite3_value **argv) { + + int rc; + int vectorColumnIdx = idxNum; + struct VectorColumnDefinition *vector_column = &p->vector_columns[vectorColumnIdx]; + struct vec0_query_knn_data *knn_data; + + knn_data = sqlite3_malloc(sizeof(*knn_data)); + if (!knn_data) return SQLITE_NOMEM; + memset(knn_data, 0, sizeof(*knn_data)); + + // Parse query_idx and k_idx from idxStr + int query_idx = -1; + int k_idx = -1; + for (int i = 0; i < argc; i++) { + if (idxStr[1 + (i * 4)] == VEC0_IDXSTR_KIND_KNN_MATCH) { + query_idx = i; + } + if (idxStr[1 + (i * 4)] == VEC0_IDXSTR_KIND_KNN_K) { + k_idx = i; + } + } + assert(query_idx >= 0); + assert(k_idx >= 0); + + // Extract query vector + void *queryVector; + size_t dimensions; + enum VectorElementType elementType; + vector_cleanup queryVectorCleanup = vector_cleanup_noop; + char *pzError; + + rc = vector_from_value(argv[query_idx], &queryVector, &dimensions, + &elementType, &queryVectorCleanup, &pzError); + if (rc != SQLITE_OK) { + vtab_set_error(&p->base, "Invalid query vector: %z", pzError); + sqlite3_free(knn_data); + return SQLITE_ERROR; + } + + if (elementType != vector_column->element_type || + dimensions != vector_column->dimensions) { + vtab_set_error(&p->base, "Query vector type/dimension mismatch"); + queryVectorCleanup(queryVector); + sqlite3_free(knn_data); + return SQLITE_ERROR; + } + + i64 k = sqlite3_value_int64(argv[k_idx]); + if (k <= 0) { + knn_data->k = 0; + knn_data->k_used = 0; + pCur->knn_data = knn_data; + pCur->query_plan = VEC0_QUERY_PLAN_KNN; + queryVectorCleanup(queryVector); + return SQLITE_OK; + } + + // Run DiskANN search + i64 *resultRowids = sqlite3_malloc(k * sizeof(i64)); + f32 *resultDistances = sqlite3_malloc(k * sizeof(f32)); + if (!resultRowids || !resultDistances) { + sqlite3_free(resultRowids); + sqlite3_free(resultDistances); + queryVectorCleanup(queryVector); + sqlite3_free(knn_data); + return SQLITE_NOMEM; + } + + int resultCount; + rc = diskann_search(p, vectorColumnIdx, queryVector, dimensions, + elementType, (int)k, 0, + resultRowids, resultDistances, &resultCount); + + if (rc != SQLITE_OK) { + queryVectorCleanup(queryVector); + sqlite3_free(resultRowids); + sqlite3_free(resultDistances); + sqlite3_free(knn_data); + return rc; + } + + // Scan _diskann_buffer for any buffered (unflushed) vectors and merge + // with graph results. This ensures no recall loss for buffered vectors. + { + sqlite3_stmt *bufStmt = NULL; + char *zSql = sqlite3_mprintf( + "SELECT rowid, vector FROM " VEC0_SHADOW_DISKANN_BUFFER_N_NAME, + p->schemaName, p->tableName, vectorColumnIdx); + if (!zSql) { + queryVectorCleanup(queryVector); + sqlite3_free(resultRowids); + sqlite3_free(resultDistances); + sqlite3_free(knn_data); + return SQLITE_NOMEM; + } + int bufRc = sqlite3_prepare_v2(p->db, zSql, -1, &bufStmt, NULL); + sqlite3_free(zSql); + if (bufRc == SQLITE_OK) { + while (sqlite3_step(bufStmt) == SQLITE_ROW) { + i64 bufRowid = sqlite3_column_int64(bufStmt, 0); + const void *bufVec = sqlite3_column_blob(bufStmt, 1); + f32 dist = vec0_distance_full( + queryVector, bufVec, dimensions, elementType, + vector_column->distance_metric); + + // Check if this buffer vector should replace the worst graph result + if (resultCount < (int)k) { + // Still have room, just add it + resultRowids[resultCount] = bufRowid; + resultDistances[resultCount] = dist; + resultCount++; + } else { + // Find worst (largest distance) in results + int worstIdx = 0; + for (int wi = 1; wi < resultCount; wi++) { + if (resultDistances[wi] > resultDistances[worstIdx]) { + worstIdx = wi; + } + } + if (dist < resultDistances[worstIdx]) { + resultRowids[worstIdx] = bufRowid; + resultDistances[worstIdx] = dist; + } + } + } + sqlite3_finalize(bufStmt); + } + } + + queryVectorCleanup(queryVector); + + // Sort results by distance (ascending) + for (int si = 0; si < resultCount - 1; si++) { + for (int sj = si + 1; sj < resultCount; sj++) { + if (resultDistances[sj] < resultDistances[si]) { + f32 tmpD = resultDistances[si]; + resultDistances[si] = resultDistances[sj]; + resultDistances[sj] = tmpD; + i64 tmpR = resultRowids[si]; + resultRowids[si] = resultRowids[sj]; + resultRowids[sj] = tmpR; + } + } + } + + knn_data->k = resultCount; + knn_data->k_used = resultCount; + knn_data->rowids = resultRowids; + knn_data->distances = resultDistances; + knn_data->current_idx = 0; + + pCur->knn_data = knn_data; + pCur->query_plan = VEC0_QUERY_PLAN_KNN; + + return SQLITE_OK; +} +#endif /* SQLITE_VEC_ENABLE_DISKANN */ + int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { assert(argc == (strlen(idxStr)-1) / 4); @@ -7098,6 +7732,13 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum, struct VectorColumnDefinition *vector_column = &p->vector_columns[vectorColumnIdx]; +#if SQLITE_VEC_ENABLE_DISKANN + // DiskANN dispatch + if (vector_column->index_type == VEC0_INDEX_TYPE_DISKANN) { + return vec0Filter_knn_diskann(pCur, p, idxNum, idxStr, argc, argv); + } +#endif + struct Array *arrayRowidsIn = NULL; sqlite3_stmt *stmtChunks = NULL; void *queryVector; @@ -8567,24 +9208,37 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, goto cleanup; } - // Step #2: Find the next "available" position in the _chunks table for this - // row. - rc = vec0Update_InsertNextAvailableStep(p, partitionKeyValues, - &chunk_rowid, &chunk_offset, - &blobChunksValidity, - &bufferChunksValidity); - if (rc != SQLITE_OK) { - goto cleanup; + if (!vec0_all_columns_diskann(p)) { + // Step #2: Find the next "available" position in the _chunks table for this + // row. + rc = vec0Update_InsertNextAvailableStep(p, partitionKeyValues, + &chunk_rowid, &chunk_offset, + &blobChunksValidity, + &bufferChunksValidity); + if (rc != SQLITE_OK) { + goto cleanup; + } + + // Step #3: With the next available chunk position, write out all the vectors + // to their specified location. + rc = vec0Update_InsertWriteFinalStep(p, chunk_rowid, chunk_offset, rowid, + vectorDatas, blobChunksValidity, + bufferChunksValidity); + if (rc != SQLITE_OK) { + goto cleanup; + } } - // Step #3: With the next available chunk position, write out all the vectors - // to their specified location. - rc = vec0Update_InsertWriteFinalStep(p, chunk_rowid, chunk_offset, rowid, - vectorDatas, blobChunksValidity, - bufferChunksValidity); - if (rc != SQLITE_OK) { - goto cleanup; +#if SQLITE_VEC_ENABLE_DISKANN + // Step #4: Insert into DiskANN graph for indexed vector columns + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) continue; + rc = diskann_insert(p, i, rowid, vectorDatas[i]); + if (rc != SQLITE_OK) { + goto cleanup; + } } +#endif #if SQLITE_VEC_ENABLE_RESCORE rc = rescore_on_insert(p, chunk_rowid, chunk_offset, rowid, vectorDatas); @@ -9126,29 +9780,43 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) { // 4. Zero out vector data in all vector column chunks // 5. Delete value in _rowids table - // 1. get chunk_id and chunk_offset from _rowids - rc = vec0_get_chunk_position(p, rowid, NULL, &chunk_id, &chunk_offset); - if (rc != SQLITE_OK) { - return rc; +#if SQLITE_VEC_ENABLE_DISKANN + // DiskANN graph deletion for indexed columns + for (int i = 0; i < p->numVectorColumns; i++) { + if (p->vector_columns[i].index_type != VEC0_INDEX_TYPE_DISKANN) continue; + rc = diskann_delete(p, i, rowid); + if (rc != SQLITE_OK) { + return rc; + } + } +#endif + + if (!vec0_all_columns_diskann(p)) { + // 1. get chunk_id and chunk_offset from _rowids + rc = vec0_get_chunk_position(p, rowid, NULL, &chunk_id, &chunk_offset); + if (rc != SQLITE_OK) { + return rc; + } + + // 2. clear validity bit + rc = vec0Update_Delete_ClearValidity(p, chunk_id, chunk_offset); + if (rc != SQLITE_OK) { + return rc; + } + + // 3. zero out rowid in chunks.rowids + rc = vec0Update_Delete_ClearRowid(p, chunk_id, chunk_offset); + if (rc != SQLITE_OK) { + return rc; + } + + // 4. zero out any data in vector chunks tables + rc = vec0Update_Delete_ClearVectors(p, chunk_id, chunk_offset); + if (rc != SQLITE_OK) { + return rc; + } } - // 2. clear validity bit - rc = vec0Update_Delete_ClearValidity(p, chunk_id, chunk_offset); - if (rc != SQLITE_OK) { - return rc; - } - - // 3. zero out rowid in chunks.rowids - rc = vec0Update_Delete_ClearRowid(p, chunk_id, chunk_offset); - if (rc != SQLITE_OK) { - return rc; - } - - // 4. zero out any data in vector chunks tables - rc = vec0Update_Delete_ClearVectors(p, chunk_id, chunk_offset); - if (rc != SQLITE_OK) { - return rc; - } #if SQLITE_VEC_ENABLE_RESCORE // 4b. zero out quantized data in rescore chunk tables, delete from rescore vectors @@ -9172,20 +9840,22 @@ int vec0Update_Delete(sqlite3_vtab *pVTab, sqlite3_value *idValue) { } } - // 7. delete metadata - for(int i = 0; i < p->numMetadataColumns; i++) { - rc = vec0Update_Delete_ClearMetadata(p, i, rowid, chunk_id, chunk_offset); - if (rc != SQLITE_OK) { - return rc; + // 7. delete metadata and reclaim chunk (only when using chunk-based storage) + if (!vec0_all_columns_diskann(p)) { + for(int i = 0; i < p->numMetadataColumns; i++) { + rc = vec0Update_Delete_ClearMetadata(p, i, rowid, chunk_id, chunk_offset); + if (rc != SQLITE_OK) { + return rc; + } } - } - // 8. reclaim chunk if fully empty - { - int chunkDeleted; - rc = vec0Update_Delete_DeleteChunkIfEmpty(p, chunk_id, &chunkDeleted); - if (rc != SQLITE_OK) { - return rc; + // 8. reclaim chunk if fully empty + { + int chunkDeleted; + rc = vec0Update_Delete_DeleteChunkIfEmpty(p, chunk_id, &chunkDeleted); + if (rc != SQLITE_OK) { + return rc; + } } } @@ -9481,8 +10151,12 @@ static int vec0Update(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, const char *cmd = (const char *)sqlite3_value_text(idVal); vec0_vtab *p = (vec0_vtab *)pVTab; int cmdRc = ivf_handle_command(p, cmd, argc, argv); +#if SQLITE_VEC_ENABLE_DISKANN + if (cmdRc == SQLITE_EMPTY) + cmdRc = diskann_handle_command(p, cmd); +#endif if (cmdRc != SQLITE_EMPTY) return cmdRc; // handled (or error) - // SQLITE_EMPTY means not an IVF command — fall through to normal insert + // SQLITE_EMPTY means not a recognized command — fall through to normal insert } #endif return vec0Update_Insert(pVTab, argc, argv, pRowid); @@ -9638,9 +10312,16 @@ static sqlite3_module vec0Module = { #define SQLITE_VEC_DEBUG_BUILD_IVF "" #endif +#if SQLITE_VEC_ENABLE_DISKANN +#define SQLITE_VEC_DEBUG_BUILD_DISKANN "diskann" +#else +#define SQLITE_VEC_DEBUG_BUILD_DISKANN "" +#endif + #define SQLITE_VEC_DEBUG_BUILD \ SQLITE_VEC_DEBUG_BUILD_AVX " " SQLITE_VEC_DEBUG_BUILD_NEON " " \ - SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF + SQLITE_VEC_DEBUG_BUILD_RESCORE " " SQLITE_VEC_DEBUG_BUILD_IVF " " \ + SQLITE_VEC_DEBUG_BUILD_DISKANN #define SQLITE_VEC_DEBUG_STRING \ "Version: " SQLITE_VEC_VERSION "\n" \ diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile index a3405a4..202dc2b 100644 --- a/tests/fuzz/Makefile +++ b/tests/fuzz/Makefile @@ -26,7 +26,7 @@ FUZZ_LDFLAGS ?= $(shell \ echo "-Wl,-ld_classic"; \ fi) -FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -g $(FUZZ_LDFLAGS) +FUZZ_CFLAGS = $(FUZZ_SANITIZERS) -I ../../ -I ../../vendor -DSQLITE_CORE -DSQLITE_VEC_ENABLE_DISKANN=1 -g $(FUZZ_LDFLAGS) FUZZ_SRCS = ../../vendor/sqlite3.c ../../sqlite-vec.c TARGET_DIR = ./targets @@ -115,6 +115,34 @@ $(TARGET_DIR)/ivf_cell_overflow: ivf-cell-overflow.c $(FUZZ_SRCS) | $(TARGET_DIR $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ $(TARGET_DIR)/ivf_rescore: ivf-rescore.c $(FUZZ_SRCS) | $(TARGET_DIR) +$(TARGET_DIR)/diskann_operations: diskann-operations.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_create: diskann-create.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_graph_corrupt: diskann-graph-corrupt.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_deep_search: diskann-deep-search.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_blob_truncate: diskann-blob-truncate.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_delete_stress: diskann-delete-stress.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_buffer_flush: diskann-buffer-flush.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_int8_quant: diskann-int8-quant.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_prune_direct: diskann-prune-direct.c $(FUZZ_SRCS) | $(TARGET_DIR) + $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ + +$(TARGET_DIR)/diskann_command_inject: diskann-command-inject.c $(FUZZ_SRCS) | $(TARGET_DIR) $(FUZZ_CC) $(FUZZ_CFLAGS) $(FUZZ_SRCS) $< -o $@ FUZZ_TARGETS = vec0_create exec json numpy \ @@ -127,6 +155,11 @@ FUZZ_TARGETS = vec0_create exec json numpy \ ivf_create ivf_operations \ ivf_quantize ivf_kmeans ivf_shadow_corrupt \ ivf_knn_deep ivf_cell_overflow ivf_rescore + diskann_operations diskann_create diskann_graph_corrupt \ + diskann_deep_search diskann_blob_truncate \ + diskann_delete_stress diskann_buffer_flush \ + diskann_int8_quant diskann_prune_direct \ + diskann_command_inject all: $(addprefix $(TARGET_DIR)/,$(FUZZ_TARGETS)) diff --git a/tests/fuzz/diskann-blob-truncate.c b/tests/fuzz/diskann-blob-truncate.c new file mode 100644 index 0000000..903a0d7 --- /dev/null +++ b/tests/fuzz/diskann-blob-truncate.c @@ -0,0 +1,250 @@ +/** + * Fuzz target for DiskANN shadow table blob size mismatches. + * + * The critical vulnerability: diskann_node_read() copies whatever blob size + * SQLite returns, but diskann_search/insert/delete index into those blobs + * using cfg->n_neighbors * sizeof(i64) etc. If the blob is truncated, + * extended, or has wrong size, this causes out-of-bounds reads/writes. + * + * This fuzzer: + * 1. Creates a valid DiskANN graph with several nodes + * 2. Uses fuzz data to directly write malformed blobs to shadow tables: + * - Truncated neighbor_ids (fewer bytes than n_neighbors * 8) + * - Truncated validity bitmaps + * - Oversized blobs with garbage trailing data + * - Zero-length blobs + * - Blobs with valid headers but corrupted neighbor rowids + * 3. Runs INSERT, DELETE, and KNN operations that traverse the corrupted graph + * + * Key code paths targeted: + * - diskann_node_read with mismatched blob sizes + * - diskann_validity_get / diskann_neighbor_id_get on truncated blobs + * - diskann_add_reverse_edge reading corrupted neighbor data + * - diskann_repair_reverse_edges traversing corrupted neighbor lists + * - diskann_search iterating neighbors from corrupted blobs + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) { + if (*size == 0) return def; + uint8_t b = **data; + (*data)++; + (*size)--; + return b; +} + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 32) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* Use binary quantizer, float[16], n_neighbors=8 for predictable blob sizes: + * validity: 8/8 = 1 byte + * neighbor_ids: 8 * 8 = 64 bytes + * qvecs: 8 * (16/8) = 16 bytes (binary: 2 bytes per qvec) + */ + rc = sqlite3_exec(db, + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))", + NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + /* Insert 12 vectors to create a valid graph structure */ + { + sqlite3_stmt *stmt; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL); + for (int i = 1; i <= 12; i++) { + float vec[16]; + for (int j = 0; j < 16; j++) { + vec[j] = (float)i * 0.1f + (float)j * 0.01f; + } + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, i); + sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmt); + } + sqlite3_finalize(stmt); + } + + /* Now corrupt shadow table blobs using fuzz data */ + const char *columns[] = { + "neighbors_validity", + "neighbor_ids", + "neighbor_quantized_vectors" + }; + + /* Expected sizes for n_neighbors=8, dims=16, binary quantizer */ + int expected_sizes[] = {1, 64, 16}; + + while (size >= 4) { + int target_row = (fuzz_byte(&data, &size, 0) % 12) + 1; + int col_idx = fuzz_byte(&data, &size, 0) % 3; + uint8_t corrupt_mode = fuzz_byte(&data, &size, 0) % 6; + uint8_t extra = fuzz_byte(&data, &size, 0); + + char sqlbuf[256]; + snprintf(sqlbuf, sizeof(sqlbuf), + "UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?", + columns[col_idx]); + + sqlite3_stmt *writeStmt; + rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL); + if (rc != SQLITE_OK) continue; + + int expected = expected_sizes[col_idx]; + unsigned char *blob = NULL; + int blob_size = 0; + + switch (corrupt_mode) { + case 0: { + /* Truncated blob: 0 to expected-1 bytes */ + blob_size = extra % expected; + if (blob_size == 0) blob_size = 0; /* zero-length is interesting */ + blob = sqlite3_malloc(blob_size > 0 ? blob_size : 1); + if (!blob) { sqlite3_finalize(writeStmt); continue; } + for (int i = 0; i < blob_size; i++) { + blob[i] = fuzz_byte(&data, &size, 0); + } + break; + } + case 1: { + /* Oversized blob: expected + extra bytes */ + blob_size = expected + (extra % 64); + blob = sqlite3_malloc(blob_size); + if (!blob) { sqlite3_finalize(writeStmt); continue; } + for (int i = 0; i < blob_size; i++) { + blob[i] = fuzz_byte(&data, &size, 0xFF); + } + break; + } + case 2: { + /* Zero-length blob */ + blob_size = 0; + blob = NULL; + sqlite3_bind_zeroblob(writeStmt, 1, 0); + sqlite3_bind_int64(writeStmt, 2, target_row); + sqlite3_step(writeStmt); + sqlite3_finalize(writeStmt); + continue; + } + case 3: { + /* Correct size but all-ones validity (all slots "valid") with + * garbage neighbor IDs -- forces reading non-existent nodes */ + blob_size = expected; + blob = sqlite3_malloc(blob_size); + if (!blob) { sqlite3_finalize(writeStmt); continue; } + memset(blob, 0xFF, blob_size); + break; + } + case 4: { + /* neighbor_ids with very large rowid values (near INT64_MAX) */ + blob_size = expected; + blob = sqlite3_malloc(blob_size); + if (!blob) { sqlite3_finalize(writeStmt); continue; } + memset(blob, 0x7F, blob_size); /* fills with large positive values */ + break; + } + case 5: { + /* neighbor_ids with negative rowid values (rowid=0 is sentinel) */ + blob_size = expected; + blob = sqlite3_malloc(blob_size); + if (!blob) { sqlite3_finalize(writeStmt); continue; } + memset(blob, 0x80, blob_size); /* fills with large negative values */ + /* Flip some bytes from fuzz data */ + for (int i = 0; i < blob_size && size > 0; i++) { + blob[i] ^= fuzz_byte(&data, &size, 0); + } + break; + } + } + + if (blob) { + sqlite3_bind_blob(writeStmt, 1, blob, blob_size, SQLITE_TRANSIENT); + } else { + sqlite3_bind_blob(writeStmt, 1, "", 0, SQLITE_STATIC); + } + sqlite3_bind_int64(writeStmt, 2, target_row); + sqlite3_step(writeStmt); + sqlite3_finalize(writeStmt); + sqlite3_free(blob); + } + + /* Exercise the corrupted graph with various operations */ + + /* KNN query */ + { + float qvec[16]; + for (int j = 0; j < 16; j++) qvec[j] = (float)j * 0.1f; + sqlite3_stmt *knnStmt; + rc = sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5", + -1, &knnStmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knnStmt) == SQLITE_ROW) {} + sqlite3_finalize(knnStmt); + } + } + + /* Insert into corrupted graph (triggers add_reverse_edge on corrupted nodes) */ + { + float vec[16]; + for (int j = 0; j < 16; j++) vec[j] = 0.5f; + sqlite3_stmt *stmt; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL); + if (stmt) { + sqlite3_bind_int64(stmt, 1, 100); + sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + } + } + + /* Delete from corrupted graph (triggers repair_reverse_edges) */ + { + sqlite3_stmt *stmt; + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmt, NULL); + if (stmt) { + sqlite3_bind_int64(stmt, 1, 5); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + } + } + + /* Another KNN to traverse the post-mutation graph */ + { + float qvec[16]; + for (int j = 0; j < 16; j++) qvec[j] = -0.5f + (float)j * 0.07f; + sqlite3_stmt *knnStmt; + rc = sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 12", + -1, &knnStmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knnStmt) == SQLITE_ROW) {} + sqlite3_finalize(knnStmt); + } + } + + /* Full scan */ + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-buffer-flush.c b/tests/fuzz/diskann-buffer-flush.c new file mode 100644 index 0000000..f10e100 --- /dev/null +++ b/tests/fuzz/diskann-buffer-flush.c @@ -0,0 +1,164 @@ +/** + * Fuzz target for DiskANN buffered insert and flush paths. + * + * When buffer_threshold > 0, inserts go into a flat buffer table and + * are flushed into the graph in batch. This fuzzer exercises: + * + * - diskann_buffer_write / diskann_buffer_delete / diskann_buffer_exists + * - diskann_flush_buffer (batch graph insertion) + * - diskann_insert with buffer_threshold (batching logic) + * - Buffer-graph merge in vec0Filter_knn_diskann (unflushed vectors + * must be scanned during KNN and merged with graph results) + * - Delete of a buffered (not yet flushed) vector + * - Delete of a graph vector while buffer has pending inserts + * - Interaction: insert to buffer, query (triggers buffer scan), flush, + * query again (now from graph) + * + * The buffer merge path in vec0Filter_knn_diskann is particularly + * interesting because it does a brute-force scan of buffer vectors and + * merges with the top-k from graph search. + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) { + if (*size == 0) return def; + uint8_t b = **data; + (*data)++; + (*size)--; + return b; +} + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 16) return 0; + + int rc; + sqlite3 *db; + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* buffer_threshold: small (3-8) to trigger frequent flushes */ + int buf_threshold = 3 + (fuzz_byte(&data, &size, 0) % 6); + int dims = 8; + + char sql[512]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] INDEXED BY diskann(" + "neighbor_quantizer=binary, n_neighbors=8, " + "search_list_size=16, buffer_threshold=%d" + "))", dims, buf_threshold); + + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?", + -1, &stmtKnn, NULL); + + if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup; + + float vec[8]; + int next_rowid = 1; + + while (size >= 2) { + uint8_t op = fuzz_byte(&data, &size, 0) % 6; + uint8_t param = fuzz_byte(&data, &size, 0); + + switch (op) { + case 0: { /* Insert: accumulates in buffer until threshold */ + int64_t rowid = next_rowid++; + if (next_rowid > 64) next_rowid = 1; /* wrap around for reuse */ + for (int j = 0; j < dims; j++) { + vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 1: { /* KNN query while buffer may have unflushed vectors */ + for (int j = 0; j < dims; j++) { + vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f; + } + int k = (param % 10) + 1; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, k); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 2: { /* Delete a potentially-buffered vector */ + int64_t rowid = (int64_t)(param % 64) + 1; + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 3: { /* Insert several at once to trigger flush mid-batch */ + for (int i = 0; i < buf_threshold + 1 && size >= 2; i++) { + int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 64) + 1; + for (int j = 0; j < dims; j++) { + vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + } + break; + } + case 4: { /* Insert then immediately delete (still in buffer) */ + int64_t rowid = (int64_t)(param % 64) + 1; + for (int j = 0; j < dims; j++) vec[j] = 0.1f * param; + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 5: { /* Query with k=0 and k=1 (boundary) */ + for (int j = 0; j < dims; j++) vec[j] = 0.0f; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, param % 2); /* k=0 or k=1 */ + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + } + } + + /* Final query to exercise post-operation state */ + { + float qvec[8] = {1.0f, -1.0f, 0.5f, -0.5f, 0.0f, 0.0f, 0.0f, 0.0f}; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, 20); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + } + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-command-inject.c b/tests/fuzz/diskann-command-inject.c new file mode 100644 index 0000000..ef62884 --- /dev/null +++ b/tests/fuzz/diskann-command-inject.c @@ -0,0 +1,158 @@ +/** + * Fuzz target for DiskANN runtime command dispatch (diskann_handle_command). + * + * The command handler parses strings like "search_list_size_search=42" and + * modifies live DiskANN config. This fuzzer exercises: + * + * - atoi on fuzz-controlled strings (integer overflow, negative, non-numeric) + * - strncmp boundary with fuzz data (near-matches to valid commands) + * - Changing search_list_size mid-operation (affects subsequent queries) + * - Setting search_list_size to 1 (minimum - single-candidate beam search) + * - Setting search_list_size very large (memory pressure) + * - Interleaving command changes with inserts and queries + * + * Also tests the UPDATE v SET command = ? path through the vtable. + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) { + if (*size == 0) return def; + uint8_t b = **data; + (*data)++; + (*size)--; + return b; +} + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 20) return 0; + + int rc; + sqlite3 *db; + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + rc = sqlite3_exec(db, + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))", + NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + /* Insert some vectors first */ + { + sqlite3_stmt *stmt; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL); + for (int i = 1; i <= 8; i++) { + float vec[8]; + for (int j = 0; j < 8; j++) vec[j] = (float)i * 0.1f + (float)j * 0.01f; + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, i); + sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmt); + } + sqlite3_finalize(stmt); + } + + sqlite3_stmt *stmtCmd = NULL; + sqlite3_stmt *stmtInsert = NULL; + sqlite3_stmt *stmtKnn = NULL; + + /* Commands are dispatched via INSERT INTO t(rowid) VALUES ('cmd_string') */ + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid) VALUES (?)", -1, &stmtCmd, NULL); + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?", + -1, &stmtKnn, NULL); + + if (!stmtCmd || !stmtInsert || !stmtKnn) goto cleanup; + + /* Fuzz-driven command + operation interleaving */ + while (size >= 2) { + uint8_t op = fuzz_byte(&data, &size, 0) % 5; + + switch (op) { + case 0: { /* Send fuzz command string */ + int cmd_len = fuzz_byte(&data, &size, 0) % 64; + char cmd[65]; + for (int i = 0; i < cmd_len && size > 0; i++) { + cmd[i] = (char)fuzz_byte(&data, &size, 0); + } + cmd[cmd_len] = '\0'; + sqlite3_reset(stmtCmd); + sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT); + sqlite3_step(stmtCmd); /* May fail -- that's expected */ + break; + } + case 1: { /* Send valid-looking command with fuzz value */ + const char *prefixes[] = { + "search_list_size=", + "search_list_size_search=", + "search_list_size_insert=", + }; + int prefix_idx = fuzz_byte(&data, &size, 0) % 3; + int val = (int)(int8_t)fuzz_byte(&data, &size, 0); + + char cmd[128]; + snprintf(cmd, sizeof(cmd), "%s%d", prefixes[prefix_idx], val); + sqlite3_reset(stmtCmd); + sqlite3_bind_text(stmtCmd, 1, cmd, -1, SQLITE_TRANSIENT); + sqlite3_step(stmtCmd); + break; + } + case 2: { /* KNN query (uses whatever search_list_size is set) */ + float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + qvec[0] = (float)((int8_t)fuzz_byte(&data, &size, 127)) / 10.0f; + int k = fuzz_byte(&data, &size, 3) % 10 + 1; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, k); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 3: { /* Insert (uses whatever search_list_size_insert is set) */ + int64_t rowid = (int64_t)(fuzz_byte(&data, &size, 0) % 32) + 1; + float vec[8]; + for (int j = 0; j < 8; j++) { + vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 4: { /* Set search_list_size to extreme values */ + const char *extreme_cmds[] = { + "search_list_size=1", + "search_list_size=2", + "search_list_size=1000", + "search_list_size_search=1", + "search_list_size_insert=1", + }; + int idx = fuzz_byte(&data, &size, 0) % 5; + sqlite3_reset(stmtCmd); + sqlite3_bind_text(stmtCmd, 1, extreme_cmds[idx], -1, SQLITE_STATIC); + sqlite3_step(stmtCmd); + break; + } + } + } + +cleanup: + sqlite3_finalize(stmtCmd); + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtKnn); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-create.c b/tests/fuzz/diskann-create.c new file mode 100644 index 0000000..1b40a84 --- /dev/null +++ b/tests/fuzz/diskann-create.c @@ -0,0 +1,44 @@ +/** + * Fuzz target for DiskANN CREATE TABLE config parsing. + * Feeds fuzz data as the INDEXED BY diskann(...) option string. + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size > 4096) return 0; /* Limit input size */ + + int rc; + sqlite3 *db; + sqlite3_stmt *stmt; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + sqlite3_str *s = sqlite3_str_new(NULL); + assert(s); + sqlite3_str_appendall(s, + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[64] INDEXED BY diskann("); + sqlite3_str_appendf(s, "%.*s", (int)size, data); + sqlite3_str_appendall(s, "))"); + const char *zSql = sqlite3_str_finish(s); + assert(zSql); + + rc = sqlite3_prepare_v2(db, zSql, -1, &stmt, NULL); + sqlite3_free((char *)zSql); + if (rc == SQLITE_OK) { + sqlite3_step(stmt); + } + sqlite3_finalize(stmt); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-deep-search.c b/tests/fuzz/diskann-deep-search.c new file mode 100644 index 0000000..35d548c --- /dev/null +++ b/tests/fuzz/diskann-deep-search.c @@ -0,0 +1,187 @@ +/** + * Fuzz target for DiskANN greedy beam search deep paths. + * + * Builds a graph with enough nodes to force multi-hop traversal, then + * uses fuzz data to control: query vector values, k, search_list_size + * overrides, and interleaved insert/delete/query sequences that stress + * the candidate list growth, visited set hash collisions, and the + * re-ranking logic. + * + * Key code paths targeted: + * - diskann_candidate_list_insert (sorted insert, dedup, eviction) + * - diskann_visited_set (hash collisions, capacity) + * - diskann_search (full beam search loop, re-ranking with exact dist) + * - diskann_distance_quantized_precomputed (both binary and int8) + * - Buffer merge in vec0Filter_knn_diskann + */ +#include +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/* Consume one byte from fuzz input, or return default. */ +static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) { + if (*size == 0) return def; + uint8_t b = **data; + (*data)++; + (*size)--; + return b; +} + +static uint16_t fuzz_u16(const uint8_t **data, size_t *size) { + uint8_t lo = fuzz_byte(data, size, 0); + uint8_t hi = fuzz_byte(data, size, 0); + return (uint16_t)hi << 8 | lo; +} + +static float fuzz_float(const uint8_t **data, size_t *size) { + return (float)((int8_t)fuzz_byte(data, size, 0)) / 10.0f; +} + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 32) return 0; + + /* Use first bytes to pick quantizer type and dimensions */ + uint8_t quantizer_choice = fuzz_byte(&data, &size, 0) % 2; + const char *quantizer = quantizer_choice ? "int8" : "binary"; + + /* Dimensions must be divisible by 8. Pick from {8, 16, 32} */ + int dim_choices[] = {8, 16, 32}; + int dims = dim_choices[fuzz_byte(&data, &size, 0) % 3]; + + /* n_neighbors: 8 or 16 -- small to force full-neighbor scenarios quickly */ + int n_neighbors = (fuzz_byte(&data, &size, 0) % 2) ? 16 : 8; + + /* search_list_size: small so beam search terminates quickly but still exercises loops */ + int search_list_size = 8 + (fuzz_byte(&data, &size, 0) % 24); + + /* alpha: vary to test RobustPrune pruning logic */ + float alpha_choices[] = {1.0f, 1.2f, 1.5f, 2.0f}; + float alpha = alpha_choices[fuzz_byte(&data, &size, 0) % 4]; + + int rc; + sqlite3 *db; + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + char sql[512]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] INDEXED BY diskann(" + "neighbor_quantizer=%s, n_neighbors=%d, " + "search_list_size=%d" + "))", dims, quantizer, n_neighbors, search_list_size); + + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + + char knn_sql[256]; + snprintf(knn_sql, sizeof(knn_sql), + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?"); + sqlite3_prepare_v2(db, knn_sql, -1, &stmtKnn, NULL); + + if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup; + + /* Phase 1: Seed the graph with enough nodes to create multi-hop structure. + * Insert 2*n_neighbors nodes so the graph is dense enough for search + * to actually traverse multiple hops. */ + int seed_count = n_neighbors * 2; + if (seed_count > 64) seed_count = 64; /* Bound for performance */ + { + float *vec = malloc(dims * sizeof(float)); + if (!vec) goto cleanup; + for (int i = 1; i <= seed_count; i++) { + for (int j = 0; j < dims; j++) { + vec[j] = fuzz_float(&data, &size); + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, i); + sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + } + free(vec); + } + + /* Phase 2: Fuzz-driven operations on the seeded graph */ + float *vec = malloc(dims * sizeof(float)); + if (!vec) goto cleanup; + + while (size >= 2) { + uint8_t op = fuzz_byte(&data, &size, 0) % 5; + uint8_t param = fuzz_byte(&data, &size, 0); + + switch (op) { + case 0: { /* INSERT with fuzz-controlled vector and rowid */ + int64_t rowid = (int64_t)(param % 128) + 1; + for (int j = 0; j < dims; j++) { + vec[j] = fuzz_float(&data, &size); + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 1: { /* DELETE */ + int64_t rowid = (int64_t)(param % 128) + 1; + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 2: { /* KNN with fuzz query vector and variable k */ + for (int j = 0; j < dims; j++) { + vec[j] = fuzz_float(&data, &size); + } + int k = (param % 20) + 1; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, k); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 3: { /* KNN with k > number of nodes (boundary) */ + for (int j = 0; j < dims; j++) { + vec[j] = fuzz_float(&data, &size); + } + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, 1000); /* k >> graph size */ + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 4: { /* INSERT duplicate rowid (triggers OR REPLACE path) */ + int64_t rowid = (int64_t)(param % 32) + 1; + for (int j = 0; j < dims; j++) { + vec[j] = (float)(param + j) / 50.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + } + } + free(vec); + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-delete-stress.c b/tests/fuzz/diskann-delete-stress.c new file mode 100644 index 0000000..d10a7ff --- /dev/null +++ b/tests/fuzz/diskann-delete-stress.c @@ -0,0 +1,175 @@ +/** + * Fuzz target for DiskANN delete path and graph connectivity maintenance. + * + * The delete path is the most complex graph mutation: + * 1. Read deleted node's neighbor list + * 2. For each neighbor, remove deleted node from their list + * 3. Try to fill the gap with one of deleted node's other neighbors + * 4. Handle medoid deletion (pick new medoid) + * + * Edge cases this targets: + * - Delete the medoid (entry point) -- forces medoid reassignment + * - Delete all nodes except one -- graph degenerates + * - Delete nodes in a chain -- cascading dangling edges + * - Re-insert at deleted rowids -- stale graph edges to old data + * - Delete nonexistent rowids -- should be no-op + * - Insert-delete-insert same rowid rapidly + * - Delete when graph has exactly n_neighbors entries (full nodes) + * + * Key code paths: + * - diskann_delete -> diskann_repair_reverse_edges + * - diskann_medoid_handle_delete + * - diskann_node_clear_neighbor + * - Interaction between delete and concurrent search + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) { + if (*size == 0) return def; + uint8_t b = **data; + (*data)++; + (*size)--; + return b; +} + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 20) return 0; + + int rc; + sqlite3 *db; + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* int8 quantizer to exercise that distance code path */ + uint8_t quant = fuzz_byte(&data, &size, 0) % 2; + const char *qname = quant ? "int8" : "binary"; + + char sql[256]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[8] INDEXED BY diskann(neighbor_quantizer=%s, n_neighbors=8))", + qname); + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_stmt *stmtInsert = NULL, *stmtDelete = NULL, *stmtKnn = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?", + -1, &stmtKnn, NULL); + + if (!stmtInsert || !stmtDelete || !stmtKnn) goto cleanup; + + /* Phase 1: Build a graph of exactly n_neighbors+2 = 10 nodes. + * This makes every node nearly full, maximizing the chance that + * inserts trigger the "full node" path in add_reverse_edge. */ + for (int i = 1; i <= 10; i++) { + float vec[8]; + for (int j = 0; j < 8; j++) { + vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(i*13+j*7))) / 20.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, i); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + } + + /* Phase 2: Fuzz-driven delete-heavy workload */ + while (size >= 2) { + uint8_t op = fuzz_byte(&data, &size, 0); + uint8_t param = fuzz_byte(&data, &size, 0); + + switch (op % 6) { + case 0: /* Delete existing node */ + case 1: { /* (weighted toward deletes) */ + int64_t rowid = (int64_t)(param % 16) + 1; + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 2: { /* Delete then immediately re-insert same rowid */ + int64_t rowid = (int64_t)(param % 10) + 1; + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + + float vec[8]; + for (int j = 0; j < 8; j++) { + vec[j] = (float)((int8_t)fuzz_byte(&data, &size, (uint8_t)(rowid+j))) / 15.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 3: { /* KNN query on potentially sparse/empty graph */ + float qvec[8]; + for (int j = 0; j < 8; j++) { + qvec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f; + } + int k = (param % 15) + 1; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, k); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 4: { /* Insert new node */ + int64_t rowid = (int64_t)(param % 32) + 1; + float vec[8]; + for (int j = 0; j < 8; j++) { + vec[j] = (float)((int8_t)fuzz_byte(&data, &size, 0)) / 10.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 5: { /* Delete ALL remaining nodes, then insert fresh */ + for (int i = 1; i <= 32; i++) { + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, i); + sqlite3_step(stmtDelete); + } + /* Now insert one node into empty graph */ + float vec[8] = {1.0f, 0, 0, 0, 0, 0, 0, 0}; + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, 1); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + } + } + + /* Final KNN on whatever state the graph is in */ + { + float qvec[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, 10); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + } + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-graph-corrupt.c b/tests/fuzz/diskann-graph-corrupt.c new file mode 100644 index 0000000..a8dbc19 --- /dev/null +++ b/tests/fuzz/diskann-graph-corrupt.c @@ -0,0 +1,123 @@ +/** + * Fuzz target for DiskANN shadow table corruption resilience. + * Creates and populates a DiskANN table, then corrupts shadow table blobs + * using fuzz data and runs queries. + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 16) return 0; + + int rc; + sqlite3 *db; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + rc = sqlite3_exec(db, + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))", + NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + /* Insert a few vectors to create graph structure */ + { + sqlite3_stmt *stmt; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmt, NULL); + for (int i = 1; i <= 10; i++) { + float vec[8]; + for (int j = 0; j < 8; j++) { + vec[j] = (float)i * 0.1f + (float)j * 0.01f; + } + sqlite3_reset(stmt); + sqlite3_bind_int64(stmt, 1, i); + sqlite3_bind_blob(stmt, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmt); + } + sqlite3_finalize(stmt); + } + + /* Corrupt shadow table data using fuzz bytes */ + size_t offset = 0; + + /* Determine which row and column to corrupt */ + int target_row = (data[offset++] % 10) + 1; + int corrupt_type = data[offset++] % 3; /* 0=validity, 1=neighbor_ids, 2=qvecs */ + + const char *column_name; + switch (corrupt_type) { + case 0: column_name = "neighbors_validity"; break; + case 1: column_name = "neighbor_ids"; break; + default: column_name = "neighbor_quantized_vectors"; break; + } + + /* Read the blob, corrupt it, write it back */ + { + sqlite3_stmt *readStmt; + char sqlbuf[256]; + snprintf(sqlbuf, sizeof(sqlbuf), + "SELECT %s FROM v_diskann_nodes00 WHERE rowid = ?", column_name); + rc = sqlite3_prepare_v2(db, sqlbuf, -1, &readStmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_int64(readStmt, 1, target_row); + if (sqlite3_step(readStmt) == SQLITE_ROW) { + const void *blob = sqlite3_column_blob(readStmt, 0); + int blobSize = sqlite3_column_bytes(readStmt, 0); + if (blob && blobSize > 0) { + unsigned char *corrupt = sqlite3_malloc(blobSize); + if (corrupt) { + memcpy(corrupt, blob, blobSize); + /* Apply fuzz bytes as XOR corruption */ + size_t remaining = size - offset; + for (size_t i = 0; i < remaining && i < (size_t)blobSize; i++) { + corrupt[i % blobSize] ^= data[offset + i]; + } + /* Write back */ + sqlite3_stmt *writeStmt; + snprintf(sqlbuf, sizeof(sqlbuf), + "UPDATE v_diskann_nodes00 SET %s = ? WHERE rowid = ?", column_name); + rc = sqlite3_prepare_v2(db, sqlbuf, -1, &writeStmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(writeStmt, 1, corrupt, blobSize, SQLITE_TRANSIENT); + sqlite3_bind_int64(writeStmt, 2, target_row); + sqlite3_step(writeStmt); + sqlite3_finalize(writeStmt); + } + sqlite3_free(corrupt); + } + } + } + sqlite3_finalize(readStmt); + } + } + + /* Run queries on corrupted graph -- should not crash */ + { + float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + sqlite3_stmt *knnStmt; + rc = sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 5", + -1, &knnStmt, NULL); + if (rc == SQLITE_OK) { + sqlite3_bind_blob(knnStmt, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(knnStmt) == SQLITE_ROW) {} + sqlite3_finalize(knnStmt); + } + } + + /* Full scan */ + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-int8-quant.c b/tests/fuzz/diskann-int8-quant.c new file mode 100644 index 0000000..f1bd31d --- /dev/null +++ b/tests/fuzz/diskann-int8-quant.c @@ -0,0 +1,164 @@ +/** + * Fuzz target for DiskANN int8 quantizer edge cases. + * + * The binary quantizer is simple (sign bit), but the int8 quantizer has + * interesting arithmetic: + * i8_val = (i8)(((src - (-1.0f)) / step) - 128.0f) + * where step = 2.0f / 255.0f + * + * Edge cases in this formula: + * - src values outside [-1, 1] cause clamping issues (no explicit clamp!) + * - src = NaN, +Inf, -Inf (from corrupted vectors or div-by-zero) + * - src very close to boundaries (-1.0, 1.0) -- rounding + * - The cast to i8 can overflow for extreme src values + * + * Also exercises int8 distance functions: + * - distance_l2_sqr_int8: accumulates squared differences, possible overflow + * - distance_cosine_int8: dot product with normalization + * - distance_l1_int8: absolute differences + * + * This fuzzer also tests the cosine distance metric path which the + * other fuzzers (using L2 default) don't cover. + */ +#include +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) { + if (*size == 0) return def; + uint8_t b = **data; + (*data)++; + (*size)--; + return b; +} + +static float fuzz_extreme_float(const uint8_t **data, size_t *size) { + uint8_t mode = fuzz_byte(data, size, 0) % 8; + uint8_t raw = fuzz_byte(data, size, 0); + switch (mode) { + case 0: return (float)((int8_t)raw) / 10.0f; /* Normal range */ + case 1: return (float)((int8_t)raw) * 100.0f; /* Large values */ + case 2: return (float)((int8_t)raw) / 1000.0f; /* Tiny values near 0 */ + case 3: return -1.0f; /* Exact boundary */ + case 4: return 1.0f; /* Exact boundary */ + case 5: return 0.0f; /* Zero */ + case 6: return (float)raw / 255.0f; /* [0, 1] range */ + case 7: return -(float)raw / 255.0f; /* [-1, 0] range */ + } + return 0.0f; +} + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 40) return 0; + + int rc; + sqlite3 *db; + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + /* Test both distance metrics with int8 quantizer */ + uint8_t metric_choice = fuzz_byte(&data, &size, 0) % 2; + const char *metric = metric_choice ? "cosine" : "L2"; + + int dims = 8 + (fuzz_byte(&data, &size, 0) % 3) * 8; /* 8, 16, or 24 */ + + char sql[512]; + snprintf(sql, sizeof(sql), + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[%d] distance_metric=%s " + "INDEXED BY diskann(neighbor_quantizer=int8, n_neighbors=8, search_list_size=16))", + dims, metric); + + rc = sqlite3_exec(db, sql, NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_stmt *stmtInsert = NULL, *stmtKnn = NULL, *stmtDelete = NULL; + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = ?", + -1, &stmtKnn, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + + if (!stmtInsert || !stmtKnn || !stmtDelete) goto cleanup; + + /* Insert vectors with extreme float values to stress quantization */ + float *vec = malloc(dims * sizeof(float)); + if (!vec) goto cleanup; + + for (int i = 1; i <= 16; i++) { + for (int j = 0; j < dims; j++) { + vec[j] = fuzz_extreme_float(&data, &size); + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, i); + sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + } + + /* Fuzz-driven operations */ + while (size >= 2) { + uint8_t op = fuzz_byte(&data, &size, 0) % 4; + uint8_t param = fuzz_byte(&data, &size, 0); + + switch (op) { + case 0: { /* KNN with extreme query values */ + for (int j = 0; j < dims; j++) { + vec[j] = fuzz_extreme_float(&data, &size); + } + int k = (param % 10) + 1; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, k); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 1: { /* Insert with extreme values */ + int64_t rowid = (int64_t)(param % 32) + 1; + for (int j = 0; j < dims; j++) { + vec[j] = fuzz_extreme_float(&data, &size); + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 2: { /* Delete */ + int64_t rowid = (int64_t)(param % 32) + 1; + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 3: { /* KNN with all-zero or all-same-value query */ + float val = (param % 3 == 0) ? 0.0f : + (param % 3 == 1) ? 1.0f : -1.0f; + for (int j = 0; j < dims; j++) vec[j] = val; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, vec, dims * sizeof(float), SQLITE_TRANSIENT); + sqlite3_bind_int(stmtKnn, 2, 5); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + } + } + + free(vec); + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtKnn); + sqlite3_finalize(stmtDelete); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-operations.c b/tests/fuzz/diskann-operations.c new file mode 100644 index 0000000..b36620b --- /dev/null +++ b/tests/fuzz/diskann-operations.c @@ -0,0 +1,100 @@ +/** + * Fuzz target for DiskANN insert/delete/query operation sequences. + * Uses fuzz bytes to drive random operations on a DiskANN-indexed table. + */ +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 6) return 0; + + int rc; + sqlite3 *db; + sqlite3_stmt *stmtInsert = NULL; + sqlite3_stmt *stmtDelete = NULL; + sqlite3_stmt *stmtKnn = NULL; + sqlite3_stmt *stmtScan = NULL; + + rc = sqlite3_open(":memory:", &db); + assert(rc == SQLITE_OK); + rc = sqlite3_vec_init(db, NULL, NULL); + assert(rc == SQLITE_OK); + + rc = sqlite3_exec(db, + "CREATE VIRTUAL TABLE v USING vec0(" + "emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8))", + NULL, NULL, NULL); + if (rc != SQLITE_OK) { sqlite3_close(db); return 0; } + + sqlite3_prepare_v2(db, + "INSERT INTO v(rowid, emb) VALUES (?, ?)", -1, &stmtInsert, NULL); + sqlite3_prepare_v2(db, + "DELETE FROM v WHERE rowid = ?", -1, &stmtDelete, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid, distance FROM v WHERE emb MATCH ? AND k = 3", + -1, &stmtKnn, NULL); + sqlite3_prepare_v2(db, + "SELECT rowid FROM v", -1, &stmtScan, NULL); + + if (!stmtInsert || !stmtDelete || !stmtKnn || !stmtScan) goto cleanup; + + size_t i = 0; + while (i + 2 <= size) { + uint8_t op = data[i++] % 4; + uint8_t rowid_byte = data[i++]; + int64_t rowid = (int64_t)(rowid_byte % 32) + 1; + + switch (op) { + case 0: { + /* INSERT: consume 32 bytes for 8 floats, or use what's left */ + float vec[8] = {0}; + for (int j = 0; j < 8 && i < size; j++, i++) { + vec[j] = (float)((int8_t)data[i]) / 10.0f; + } + sqlite3_reset(stmtInsert); + sqlite3_bind_int64(stmtInsert, 1, rowid); + sqlite3_bind_blob(stmtInsert, 2, vec, sizeof(vec), SQLITE_TRANSIENT); + sqlite3_step(stmtInsert); + break; + } + case 1: { + /* DELETE */ + sqlite3_reset(stmtDelete); + sqlite3_bind_int64(stmtDelete, 1, rowid); + sqlite3_step(stmtDelete); + break; + } + case 2: { + /* KNN query */ + float qvec[8] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + sqlite3_reset(stmtKnn); + sqlite3_bind_blob(stmtKnn, 1, qvec, sizeof(qvec), SQLITE_STATIC); + while (sqlite3_step(stmtKnn) == SQLITE_ROW) {} + break; + } + case 3: { + /* Full scan */ + sqlite3_reset(stmtScan); + while (sqlite3_step(stmtScan) == SQLITE_ROW) {} + break; + } + } + } + + /* Final operations -- must not crash regardless of prior state */ + sqlite3_exec(db, "SELECT * FROM v", NULL, NULL, NULL); + +cleanup: + sqlite3_finalize(stmtInsert); + sqlite3_finalize(stmtDelete); + sqlite3_finalize(stmtKnn); + sqlite3_finalize(stmtScan); + sqlite3_close(db); + return 0; +} diff --git a/tests/fuzz/diskann-prune-direct.c b/tests/fuzz/diskann-prune-direct.c new file mode 100644 index 0000000..7a440ad --- /dev/null +++ b/tests/fuzz/diskann-prune-direct.c @@ -0,0 +1,131 @@ +/** + * Fuzz target for DiskANN RobustPrune algorithm (diskann_prune_select). + * + * diskann_prune_select is exposed for testing and takes: + * - inter_distances: flattened NxN matrix of inter-candidate distances + * - p_distances: N distances from node p to each candidate + * - num_candidates, alpha, max_neighbors + * + * This is a pure function that doesn't need a database, so we can + * call it directly with fuzz-controlled inputs. This gives the fuzzer + * maximum speed (no SQLite overhead) to explore: + * + * - alpha boundary: alpha=0 (prunes nothing), alpha=very large (prunes all) + * - max_neighbors = 0, 1, num_candidates, > num_candidates + * - num_candidates = 0, 1, large + * - Distance matrices with: all zeros, all same, negative values, NaN, Inf + * - Non-symmetric distance matrices (should still work) + * - Memory: large num_candidates to stress malloc + * + * Key code paths: + * - diskann_prune_select alpha-pruning loop + * - Boundary: selectedCount reaches max_neighbors exactly + * - All candidates pruned before max_neighbors reached + */ +#include +#include +#include +#include +#include +#include +#include "sqlite-vec.h" +#include "sqlite3.h" +#include + +/* Declare the test-exposed function. + * diskann_prune_select is not static -- it's a public symbol. */ +extern int diskann_prune_select( + const float *inter_distances, const float *p_distances, + int num_candidates, float alpha, int max_neighbors, + int *outSelected, int *outCount); + +static uint8_t fuzz_byte(const uint8_t **data, size_t *size, uint8_t def) { + if (*size == 0) return def; + uint8_t b = **data; + (*data)++; + (*size)--; + return b; +} + +int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { + if (size < 8) return 0; + + /* Consume parameters from fuzz data */ + int num_candidates = fuzz_byte(&data, &size, 0) % 33; /* 0..32 */ + int max_neighbors = fuzz_byte(&data, &size, 0) % 17; /* 0..16 */ + + /* Alpha: pick from interesting values */ + uint8_t alpha_idx = fuzz_byte(&data, &size, 0) % 8; + float alpha_values[] = {0.0f, 0.5f, 1.0f, 1.2f, 1.5f, 2.0f, 10.0f, 100.0f}; + float alpha = alpha_values[alpha_idx]; + + if (num_candidates == 0) { + /* Test empty case */ + int outCount = -1; + int rc = diskann_prune_select(NULL, NULL, 0, alpha, max_neighbors, + NULL, &outCount); + assert(rc == 0 /* SQLITE_OK */); + assert(outCount == 0); + return 0; + } + + /* Allocate arrays */ + int n = num_candidates; + float *inter_distances = malloc(n * n * sizeof(float)); + float *p_distances = malloc(n * sizeof(float)); + int *outSelected = malloc(n * sizeof(int)); + if (!inter_distances || !p_distances || !outSelected) { + free(inter_distances); + free(p_distances); + free(outSelected); + return 0; + } + + /* Fill p_distances from fuzz data (sorted ascending for correct input) */ + for (int i = 0; i < n; i++) { + uint8_t raw = fuzz_byte(&data, &size, (uint8_t)(i * 10)); + p_distances[i] = (float)raw / 10.0f; + } + /* Sort p_distances ascending (prune_select expects sorted input) */ + for (int i = 1; i < n; i++) { + float tmp = p_distances[i]; + int j = i - 1; + while (j >= 0 && p_distances[j] > tmp) { + p_distances[j + 1] = p_distances[j]; + j--; + } + p_distances[j + 1] = tmp; + } + + /* Fill inter-distance matrix from fuzz data */ + for (int i = 0; i < n * n; i++) { + uint8_t raw = fuzz_byte(&data, &size, (uint8_t)(i % 256)); + inter_distances[i] = (float)raw / 10.0f; + } + /* Make diagonal zero */ + for (int i = 0; i < n; i++) { + inter_distances[i * n + i] = 0.0f; + } + + int outCount = -1; + int rc = diskann_prune_select(inter_distances, p_distances, + n, alpha, max_neighbors, + outSelected, &outCount); + /* Basic sanity: should not crash, count should be valid */ + assert(rc == 0); + assert(outCount >= 0); + assert(outCount <= max_neighbors || max_neighbors == 0); + assert(outCount <= n); + + /* Verify outSelected flags are consistent with outCount */ + int flagCount = 0; + for (int i = 0; i < n; i++) { + if (outSelected[i]) flagCount++; + } + assert(flagCount == outCount); + + free(inter_distances); + free(p_distances); + free(outSelected); + return 0; +} diff --git a/tests/fuzz/diskann.dict b/tests/fuzz/diskann.dict new file mode 100644 index 0000000..31d289d --- /dev/null +++ b/tests/fuzz/diskann.dict @@ -0,0 +1,10 @@ +"neighbor_quantizer" +"binary" +"int8" +"n_neighbors" +"search_list_size" +"search_list_size_search" +"search_list_size_insert" +"alpha" +"=" +"," diff --git a/tests/sqlite-vec-internal.h b/tests/sqlite-vec-internal.h index 67f1370..313add4 100644 --- a/tests/sqlite-vec-internal.h +++ b/tests/sqlite-vec-internal.h @@ -73,6 +73,7 @@ enum Vec0IndexType { VEC0_INDEX_TYPE_RESCORE = 2, #endif VEC0_INDEX_TYPE_IVF = 3, + VEC0_INDEX_TYPE_DISKANN = 4, }; enum Vec0RescoreQuantizerType { @@ -114,6 +115,20 @@ struct Vec0RescoreConfig { }; #endif +enum Vec0DiskannQuantizerType { + VEC0_DISKANN_QUANTIZER_BINARY = 1, + VEC0_DISKANN_QUANTIZER_INT8 = 2, +}; + +struct Vec0DiskannConfig { + enum Vec0DiskannQuantizerType quantizer_type; + int n_neighbors; + int search_list_size; + int search_list_size_search; + int search_list_size_insert; + float alpha; + int buffer_threshold; +}; struct VectorColumnDefinition { char *name; @@ -126,6 +141,7 @@ struct VectorColumnDefinition { struct Vec0RescoreConfig rescore; #endif struct Vec0IvfConfig ivf; + struct Vec0DiskannConfig diskann; }; int vec0_parse_vector_column(const char *source, int source_length, @@ -136,6 +152,48 @@ int vec0_parse_partition_key_definition(const char *source, int source_length, int *out_column_name_length, int *out_column_type); +size_t diskann_quantized_vector_byte_size( + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions); + +int diskann_validity_byte_size(int n_neighbors); +size_t diskann_neighbor_ids_byte_size(int n_neighbors); +size_t diskann_neighbor_qvecs_byte_size( + int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type, + size_t dimensions); +int diskann_node_init( + int n_neighbors, enum Vec0DiskannQuantizerType quantizer_type, + size_t dimensions, + unsigned char **outValidity, int *outValiditySize, + unsigned char **outNeighborIds, int *outNeighborIdsSize, + unsigned char **outNeighborQvecs, int *outNeighborQvecsSize); +int diskann_validity_get(const unsigned char *validity, int i); +void diskann_validity_set(unsigned char *validity, int i, int value); +int diskann_validity_count(const unsigned char *validity, int n_neighbors); +long long diskann_neighbor_id_get(const unsigned char *neighbor_ids, int i); +void diskann_neighbor_id_set(unsigned char *neighbor_ids, int i, long long rowid); +const unsigned char *diskann_neighbor_qvec_get( + const unsigned char *qvecs, int i, + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions); +void diskann_neighbor_qvec_set( + unsigned char *qvecs, int i, const unsigned char *src_qvec, + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions); +void diskann_node_set_neighbor( + unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i, + long long neighbor_rowid, const unsigned char *neighbor_qvec, + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions); +void diskann_node_clear_neighbor( + unsigned char *validity, unsigned char *neighbor_ids, unsigned char *qvecs, int i, + enum Vec0DiskannQuantizerType quantizer_type, size_t dimensions); +int diskann_quantize_vector( + const float *src, size_t dimensions, + enum Vec0DiskannQuantizerType quantizer_type, + unsigned char *out); + +int diskann_prune_select( + const float *inter_distances, const float *p_distances, + int num_candidates, float alpha, int max_neighbors, + int *outSelected, int *outCount); + #ifdef SQLITE_VEC_TEST float _test_distance_l2_sqr_float(const float *a, const float *b, size_t dims); float _test_distance_cosine_float(const float *a, const float *b, size_t dims); @@ -151,6 +209,33 @@ size_t _test_rescore_quantized_byte_size_int8(size_t dimensions); void ivf_quantize_int8(const float *src, int8_t *dst, int D); void ivf_quantize_binary(const float *src, uint8_t *dst, int D); #endif +// DiskANN candidate list (opaque struct, use accessors) +struct DiskannCandidateList { + void *items; // opaque + int count; + int capacity; +}; + +int _test_diskann_candidate_list_init(struct DiskannCandidateList *list, int capacity); +void _test_diskann_candidate_list_free(struct DiskannCandidateList *list); +int _test_diskann_candidate_list_insert(struct DiskannCandidateList *list, long long rowid, float distance); +int _test_diskann_candidate_list_next_unvisited(const struct DiskannCandidateList *list); +int _test_diskann_candidate_list_count(const struct DiskannCandidateList *list); +long long _test_diskann_candidate_list_rowid(const struct DiskannCandidateList *list, int i); +float _test_diskann_candidate_list_distance(const struct DiskannCandidateList *list, int i); +void _test_diskann_candidate_list_set_visited(struct DiskannCandidateList *list, int i); + +// DiskANN visited set (opaque struct, use accessors) +struct DiskannVisitedSet { + void *slots; // opaque + int capacity; + int count; +}; + +int _test_diskann_visited_set_init(struct DiskannVisitedSet *set, int capacity); +void _test_diskann_visited_set_free(struct DiskannVisitedSet *set); +int _test_diskann_visited_set_contains(const struct DiskannVisitedSet *set, long long rowid); +int _test_diskann_visited_set_insert(struct DiskannVisitedSet *set, long long rowid); #endif #endif /* SQLITE_VEC_INTERNAL_H */ diff --git a/tests/test-diskann.py b/tests/test-diskann.py new file mode 100644 index 0000000..4c049ce --- /dev/null +++ b/tests/test-diskann.py @@ -0,0 +1,1160 @@ +import sqlite3 +import struct +import pytest +from helpers import _f32, exec + + +def test_diskann_create_basic(db): + """Basic DiskANN table creation with binary quantizer should succeed.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + # Table should exist + tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where name like 't%' order by 1" + ).fetchall() + ] + assert "t" in tables + + +def test_diskann_create_int8_quantizer(db): + """DiskANN with int8 quantizer should succeed.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=int8) + ) + """) + tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where name like 't%' order by 1" + ).fetchall() + ] + assert "t" in tables + + +def test_diskann_create_with_options(db): + """DiskANN with custom n_neighbors and search_list_size.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann( + neighbor_quantizer=binary, + n_neighbors=48, + search_list_size=256 + ) + ) + """) + tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where name like 't%' order by 1" + ).fetchall() + ] + assert "t" in tables + + +def test_diskann_create_with_distance_metric(db): + """DiskANN combined with distance_metric should work.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] distance_metric=cosine INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where name like 't%' order by 1" + ).fetchall() + ] + assert "t" in tables + + +def test_diskann_create_error_missing_quantizer(db): + """Error when neighbor_quantizer is not specified.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann(n_neighbors=72) + ) + """) + assert "error" in result + + +def test_diskann_create_error_empty_parens(db): + """Error on empty parens.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann() + ) + """) + assert "error" in result + + +def test_diskann_create_error_unknown_quantizer(db): + """Error on unknown quantizer type.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann(neighbor_quantizer=unknown) + ) + """) + assert "error" in result + + +def test_diskann_create_error_bit_column(db): + """Error: DiskANN not supported on bit vector columns.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb bit[128] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + assert "error" in result + assert "bit" in result["message"].lower() or "DiskANN" in result["message"] + + +def test_diskann_create_error_binary_quantizer_odd_dims(db): + """Error: binary quantizer requires dimensions divisible by 8.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[13] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + assert "error" in result + assert "divisible" in result["message"].lower() + + +def test_diskann_create_error_bad_n_neighbors(db): + """Error: n_neighbors must be divisible by 8.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=13) + ) + """) + assert "error" in result + + +def test_diskann_shadow_tables_created(db): + """DiskANN table should create _vectors00 and _diskann_nodes00 shadow tables.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + tables = sorted([ + row[0] + for row in db.execute( + "select name from sqlite_master where type='table' and name like 't_%' order by 1" + ).fetchall() + ]) + assert "t_vectors00" in tables + assert "t_diskann_nodes00" in tables + # DiskANN columns should NOT have _vector_chunks00 + assert "t_vector_chunks00" not in tables + + +def test_diskann_medoid_in_info(db): + """_info table should contain diskann_medoid_00 key with NULL value.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + row = db.execute( + "SELECT key, value FROM t_info WHERE key = 'diskann_medoid_00'" + ).fetchone() + assert row is not None + assert row[0] == "diskann_medoid_00" + assert row[1] is None + + +def test_non_diskann_no_extra_tables(db): + """Non-DiskANN table must NOT create _vectors or _diskann_nodes tables.""" + db.execute("CREATE VIRTUAL TABLE t USING vec0(emb float[64])") + tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where type='table' and name like 't_%' order by 1" + ).fetchall() + ] + assert "t_vectors00" not in tables + assert "t_diskann_nodes00" not in tables + assert "t_vector_chunks00" in tables + + +def test_diskann_medoid_initial_null(db): + """Medoid should be NULL initially (empty graph).""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + row = db.execute( + "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'" + ).fetchone() + assert row[0] is None + + +def test_diskann_medoid_set_via_info(db): + """Setting medoid via _info table should be retrievable.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + # Manually set medoid to simulate first insert + db.execute("UPDATE t_info SET value = 42 WHERE key = 'diskann_medoid_00'") + row = db.execute( + "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'" + ).fetchone() + assert row[0] == 42 + + # Reset to NULL (empty graph) + db.execute("UPDATE t_info SET value = NULL WHERE key = 'diskann_medoid_00'") + row = db.execute( + "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'" + ).fetchone() + assert row[0] is None + + +def test_diskann_single_insert(db): + """Insert 1 vector. Verify _vectors00, _diskann_nodes00, and medoid.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + db.execute( + "INSERT INTO t(rowid, emb) VALUES (1, ?)", + [_f32([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + # Verify _vectors00 has 1 row + count = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] + assert count == 1 + + # Verify _diskann_nodes00 has 1 row + count = db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] + assert count == 1 + + # Verify medoid is set + medoid = db.execute( + "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'" + ).fetchone()[0] + assert medoid == 1 + + +def test_diskann_multiple_inserts(db): + """Insert multiple vectors. Verify counts and that nodes have neighbors.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + import random + random.seed(42) + for i in range(1, 21): + vec = [random.gauss(0, 1) for _ in range(8)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + # Verify counts + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 20 + assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 20 + + # Every node after the first should have at least 1 neighbor + rows = db.execute( + "SELECT rowid, neighbors_validity FROM t_diskann_nodes00" + ).fetchall() + nodes_with_neighbors = 0 + for row in rows: + validity = row[1] + has_neighbor = any(b != 0 for b in validity) + if has_neighbor: + nodes_with_neighbors += 1 + # At minimum, nodes 2-20 should have neighbors (node 1 gets neighbors via reverse edges) + assert nodes_with_neighbors >= 19 + + +def test_diskann_bidirectional_edges(db): + """Insert A then B. B should be in A's neighbors and A in B's.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + db.execute( + "INSERT INTO t(rowid, emb) VALUES (1, ?)", + [_f32([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + db.execute( + "INSERT INTO t(rowid, emb) VALUES (2, ?)", + [_f32([0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])], + ) + + # Check B(2) is in A(1)'s neighbor list + row_a = db.execute( + "SELECT neighbor_ids FROM t_diskann_nodes00 WHERE rowid = 1" + ).fetchone() + neighbor_ids_a = struct.unpack(f"{len(row_a[0])//8}q", row_a[0]) + assert 2 in neighbor_ids_a + + # Check A(1) is in B(2)'s neighbor list + row_b = db.execute( + "SELECT neighbor_ids FROM t_diskann_nodes00 WHERE rowid = 2" + ).fetchone() + neighbor_ids_b = struct.unpack(f"{len(row_b[0])//8}q", row_b[0]) + assert 1 in neighbor_ids_b + + +def test_diskann_delete_single(db): + """Insert 3, delete 1. Verify counts.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + for i in range(1, 4): + db.execute( + "INSERT INTO t(rowid, emb) VALUES (?, ?)", + [i, _f32([float(i)] * 8)], + ) + db.execute("DELETE FROM t WHERE rowid = 2") + + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 2 + assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 2 + + +def test_diskann_delete_no_stale_references(db): + """After delete, no node should reference the deleted rowid.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + import random + random.seed(123) + for i in range(1, 11): + vec = [random.gauss(0, 1) for _ in range(8)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + db.execute("DELETE FROM t WHERE rowid = 5") + + # Scan all remaining nodes and verify rowid 5 is not in any neighbor list + rows = db.execute( + "SELECT rowid, neighbors_validity, neighbor_ids FROM t_diskann_nodes00" + ).fetchall() + for row in rows: + validity = row[1] + neighbor_ids_blob = row[2] + n_neighbors = len(validity) * 8 + ids = struct.unpack(f"{n_neighbors}q", neighbor_ids_blob) + for i in range(n_neighbors): + byte_idx = i // 8 + bit_idx = i % 8 + if validity[byte_idx] & (1 << bit_idx): + assert ids[i] != 5, f"Node {row[0]} still references deleted rowid 5" + + +def test_diskann_delete_medoid(db): + """Delete the medoid. Verify a new non-NULL medoid is selected.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + for i in range(1, 4): + db.execute( + "INSERT INTO t(rowid, emb) VALUES (?, ?)", + [i, _f32([float(i)] * 8)], + ) + + medoid_before = db.execute( + "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'" + ).fetchone()[0] + assert medoid_before == 1 + + db.execute("DELETE FROM t WHERE rowid = 1") + + medoid_after = db.execute( + "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'" + ).fetchone()[0] + assert medoid_after is not None + assert medoid_after != 1 + + +def test_diskann_delete_all(db): + """Delete all vectors. Medoid should be NULL.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + for i in range(1, 4): + db.execute( + "INSERT INTO t(rowid, emb) VALUES (?, ?)", + [i, _f32([float(i)] * 8)], + ) + for i in range(1, 4): + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 0 + assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 0 + + medoid = db.execute( + "SELECT value FROM t_info WHERE key = 'diskann_medoid_00'" + ).fetchone()[0] + assert medoid is None + + +def test_diskann_insert_delete_insert_cycle(db): + """Insert, delete, insert again. No crashes.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1.0] * 8)]) + db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([2.0] * 8)]) + db.execute("DELETE FROM t WHERE rowid = 1") + db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([3.0] * 8)]) + + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 2 + assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 2 + + +def test_diskann_knn_basic(db): + """Basic KNN query should return results.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 0, 0, 0, 0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0.9, 0.1, 0, 0, 0, 0, 0, 0])]) + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=2", + [_f32([1, 0, 0, 0, 0, 0, 0, 0])], + ).fetchall() + assert len(rows) == 2 + # Closest should be rowid 1 (exact match) + assert rows[0][0] == 1 + assert rows[0][1] < 0.01 # ~0 distance + + +def test_diskann_knn_distances_sorted(db): + """Returned distances should be in ascending order.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16) + ) + """) + import random + random.seed(42) + for i in range(1, 51): + vec = [random.gauss(0, 1) for _ in range(8)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=10", + [_f32([0.0] * 8)], + ).fetchall() + assert len(rows) == 10 + distances = [r[1] for r in rows] + for i in range(len(distances) - 1): + assert distances[i] <= distances[i + 1], f"Distances not sorted at index {i}" + + +def test_diskann_knn_empty_table(db): + """KNN on empty table should return 0 results.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=5", + [_f32([1, 0, 0, 0, 0, 0, 0, 0])], + ).fetchall() + assert len(rows) == 0 + + +def test_diskann_knn_after_delete(db): + """KNN after delete should not return deleted rows.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 0, 0, 0, 0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0.5, 0.5, 0, 0, 0, 0, 0, 0])]) + db.execute("DELETE FROM t WHERE rowid = 1") + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=3", + [_f32([1, 0, 0, 0, 0, 0, 0, 0])], + ).fetchall() + rowids = [r[0] for r in rows] + assert 1 not in rowids + assert len(rows) == 2 + + +def test_diskann_no_index_still_works(db): + """Tables without INDEXED BY should still work identically.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[4] + ) + """) + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 2, 3, 4])]) + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=1", + [_f32([1, 2, 3, 4])], + ).fetchall() + assert len(rows) == 1 + assert rows[0][0] == 1 + + +def test_diskann_drop_table(db): + """DROP TABLE should clean up all shadow tables.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + db.execute("DROP TABLE t") + tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where name like 't%'" + ).fetchall() + ] + assert len(tables) == 0 + + +def test_diskann_create_split_search_list_size(db): + """DiskANN with separate search_list_size_search and search_list_size_insert.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann( + neighbor_quantizer=binary, + search_list_size_search=256, + search_list_size_insert=64 + ) + ) + """) + tables = [ + row[0] + for row in db.execute( + "select name from sqlite_master where name like 't%' order by 1" + ).fetchall() + ] + assert "t" in tables + + +def test_diskann_create_error_mixed_search_list_size(db): + """Error when mixing search_list_size with search_list_size_search.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[128] INDEXED BY diskann( + neighbor_quantizer=binary, + search_list_size=128, + search_list_size_search=256 + ) + ) + """) + assert "error" in result + + +def test_diskann_command_search_list_size(db): + """Runtime search_list_size override via command insert.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + import struct, random + random.seed(42) + for i in range(20): + vec = struct.pack("64f", *[random.random() for _ in range(64)]) + db.execute("INSERT INTO t(emb) VALUES (?)", [vec]) + + # Query with default search_list_size + query = struct.pack("64f", *[random.random() for _ in range(64)]) + results_before = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query] + ).fetchall() + assert len(results_before) == 5 + + # Override search_list_size_search at runtime + db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_search=256')") + + # Query should still work + results_after = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query] + ).fetchall() + assert len(results_after) == 5 + + # Override search_list_size_insert at runtime + db.execute("INSERT INTO t(rowid) VALUES ('search_list_size_insert=32')") + + # Inserts should still work + vec = struct.pack("64f", *[random.random() for _ in range(64)]) + db.execute("INSERT INTO t(emb) VALUES (?)", [vec]) + + # Override unified search_list_size + db.execute("INSERT INTO t(rowid) VALUES ('search_list_size=64')") + + results_final = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k = 5", [query] + ).fetchall() + assert len(results_final) == 5 + + +def test_diskann_command_search_list_size_error(db): + """Error on invalid search_list_size command value.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=0')") + assert "error" in result + result = exec(db, "INSERT INTO t(rowid) VALUES ('search_list_size=-1')") + assert "error" in result + + +# ====================================================================== +# Error cases: DiskANN + auxiliary/metadata/partition columns +# ====================================================================== + +def test_diskann_create_error_with_auxiliary_column(db): + """DiskANN tables should not support auxiliary columns.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary), + +extra text + ) + """) + assert "error" in result + assert "auxiliary" in result["message"].lower() or "Auxiliary" in result["message"] + + +def test_diskann_create_error_with_metadata_column(db): + """DiskANN tables should not support metadata columns.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary), + metadata_col integer metadata + ) + """) + assert "error" in result + assert "metadata" in result["message"].lower() or "Metadata" in result["message"] + + +def test_diskann_create_error_with_partition_key(db): + """DiskANN tables should not support partition key columns.""" + result = exec(db, """ + CREATE VIRTUAL TABLE t USING vec0( + emb float[64] INDEXED BY diskann(neighbor_quantizer=binary), + user_id text partition key + ) + """) + assert "error" in result + assert "partition" in result["message"].lower() or "Partition" in result["message"] + + +# ====================================================================== +# Insert edge cases +# ====================================================================== + +def test_diskann_insert_no_rowid(db): + """INSERT without explicit rowid (auto-generated) should work.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + db.execute("INSERT INTO t(emb) VALUES (?)", [_f32([1.0] * 8)]) + db.execute("INSERT INTO t(emb) VALUES (?)", [_f32([2.0] * 8)]) + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 2 + assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 2 + + +def test_diskann_insert_large_batch(db): + """INSERT 500+ vectors, verify all are queryable via KNN.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16) + ) + """) + import random + random.seed(99) + N = 500 + for i in range(1, N + 1): + vec = [random.gauss(0, 1) for _ in range(16)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == N + assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == N + + # KNN should return results + query = [random.gauss(0, 1) for _ in range(16)] + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=10", + [_f32(query)], + ).fetchall() + assert len(rows) == 10 + # Distances should be sorted + distances = [r[1] for r in rows] + for i in range(len(distances) - 1): + assert distances[i] <= distances[i + 1] + + +def test_diskann_insert_zero_vector(db): + """Insert an all-zero vector (edge case for binary quantizer).""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([0.0] * 8)]) + db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([1.0] * 8)]) + count = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] + assert count == 2 + + # Query with zero vector should find rowid 1 as closest + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=2", + [_f32([0.0] * 8)], + ).fetchall() + assert len(rows) == 2 + assert rows[0][0] == 1 + + +def test_diskann_insert_large_values(db): + """Insert vectors with very large float values.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary) + ) + """) + import sys + large = sys.float_info.max / 1e300 # Large but not overflowing + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([large] * 8)]) + db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([-large] * 8)]) + db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0.0] * 8)]) + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 3 + + +def test_diskann_insert_int8_quantizer_knn(db): + """Full insert + query cycle with int8 quantizer.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[16] INDEXED BY diskann(neighbor_quantizer=int8, n_neighbors=8) + ) + """) + import random + random.seed(77) + for i in range(1, 31): + vec = [random.gauss(0, 1) for _ in range(16)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 30 + + # KNN should work + query = [random.gauss(0, 1) for _ in range(16)] + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=5", + [_f32(query)], + ).fetchall() + assert len(rows) == 5 + distances = [r[1] for r in rows] + for i in range(len(distances) - 1): + assert distances[i] <= distances[i + 1] + + +# ====================================================================== +# Delete edge cases +# ====================================================================== + +def test_diskann_delete_nonexistent(db): + """DELETE of a nonexistent rowid should either be a no-op or return an error, not crash.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1.0] * 8)]) + # Deleting a nonexistent rowid may error but should not crash + result = exec(db, "DELETE FROM t WHERE rowid = 999") + # Whether it succeeds or errors, the existing row should still be there + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 1 + + +def test_diskann_delete_then_reinsert_same_rowid(db): + """Delete rowid 5, then reinsert rowid 5 with a new vector.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + for i in range(1, 6): + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * 8)]) + + db.execute("DELETE FROM t WHERE rowid = 5") + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 4 + + # Reinsert with new vector + db.execute("INSERT INTO t(rowid, emb) VALUES (5, ?)", [_f32([99.0] * 8)]) + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 5 + assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 5 + + +def test_diskann_delete_all_then_insert(db): + """Delete everything, then insert new vectors. Graph should rebuild.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + for i in range(1, 6): + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * 8)]) + + # Delete all + for i in range(1, 6): + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 0 + + medoid = db.execute("SELECT value FROM t_info WHERE key = 'diskann_medoid_00'").fetchone()[0] + assert medoid is None + + # Insert new vectors + for i in range(10, 15): + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * 8)]) + + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == 5 + assert db.execute("SELECT count(*) FROM t_diskann_nodes00").fetchone()[0] == 5 + + medoid = db.execute("SELECT value FROM t_info WHERE key = 'diskann_medoid_00'").fetchone()[0] + assert medoid is not None + + # KNN should work + rows = db.execute( + "SELECT rowid FROM t WHERE emb MATCH ? AND k=3", + [_f32([12.0] * 8)], + ).fetchall() + assert len(rows) == 3 + + +def test_diskann_delete_preserves_graph_connectivity(db): + """After deleting a node, remaining nodes should still be reachable via KNN.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + import random + random.seed(456) + for i in range(1, 21): + vec = [random.gauss(0, 1) for _ in range(8)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + # Delete 5 nodes + for i in [3, 7, 11, 15, 19]: + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + remaining = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] + assert remaining == 15 + + # Every remaining node should be reachable via KNN (appears somewhere in top-k) + all_rowids = [r[0] for r in db.execute("SELECT rowid FROM t_vectors00").fetchall()] + reachable = set() + for rid in all_rowids: + vec_blob = db.execute("SELECT vector FROM t_vectors00 WHERE rowid = ?", [rid]).fetchone()[0] + rows = db.execute( + "SELECT rowid FROM t WHERE emb MATCH ? AND k=5", + [vec_blob], + ).fetchall() + assert len(rows) >= 1 # At least some results + for r in rows: + reachable.add(r[0]) + # Most nodes should be reachable through the graph + assert len(reachable) >= len(all_rowids) * 0.8, \ + f"Only {len(reachable)}/{len(all_rowids)} nodes reachable" + + +# ====================================================================== +# Update scenarios +# ====================================================================== + +def test_diskann_update_vector(db): + """UPDATE a vector on DiskANN table may not be supported; verify it either works or errors cleanly.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 0, 0, 0, 0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0, 0, 1, 0, 0, 0, 0, 0])]) + + # UPDATE may not be fully supported for DiskANN yet; verify no crash + result = exec(db, "UPDATE t SET emb = ? WHERE rowid = 1", [_f32([0, 0.9, 0.1, 0, 0, 0, 0, 0])]) + if "error" not in result: + # If UPDATE succeeded, verify KNN reflects the new value + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=3", + [_f32([0, 1, 0, 0, 0, 0, 0, 0])], + ).fetchall() + assert len(rows) == 3 + # rowid 2 should still be closest (exact match) + assert rows[0][0] == 2 + + +# ====================================================================== +# KNN correctness after mutations +# ====================================================================== + +def test_diskann_knn_recall_after_inserts(db): + """Insert N vectors, verify top-1 recall is 100% for exact matches.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16) + ) + """) + import random + random.seed(200) + vectors = {} + for i in range(1, 51): + vec = [random.gauss(0, 1) for _ in range(8)] + vectors[i] = vec + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + # Top-1 for each vector should return itself + correct = 0 + for rid, vec in vectors.items(): + rows = db.execute( + "SELECT rowid FROM t WHERE emb MATCH ? AND k=1", + [_f32(vec)], + ).fetchall() + if rows and rows[0][0] == rid: + correct += 1 + + # With binary quantizer, approximate search may not always return exact match + # but should have high recall (at least 80%) + assert correct >= len(vectors) * 0.8, f"Top-1 recall too low: {correct}/{len(vectors)}" + + +def test_diskann_knn_k_larger_than_table(db): + """k=100 on table with 5 rows should return 5.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + for i in range(1, 6): + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32([float(i)] * 8)]) + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=100", + [_f32([3.0] * 8)], + ).fetchall() + assert len(rows) == 5 + + +def test_diskann_knn_cosine_metric(db): + """KNN with cosine distance metric.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] distance_metric=cosine INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + # Insert orthogonal-ish vectors + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1, 0, 0, 0, 0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, emb) VALUES (2, ?)", [_f32([0, 1, 0, 0, 0, 0, 0, 0])]) + db.execute("INSERT INTO t(rowid, emb) VALUES (3, ?)", [_f32([0.7, 0.7, 0, 0, 0, 0, 0, 0])]) + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=3", + [_f32([1, 0, 0, 0, 0, 0, 0, 0])], + ).fetchall() + assert len(rows) == 3 + # rowid 1 should be closest (exact match in direction) + assert rows[0][0] == 1 + # Distances should be sorted + distances = [r[1] for r in rows] + for i in range(len(distances) - 1): + assert distances[i] <= distances[i + 1] + + +def test_diskann_knn_after_heavy_churn(db): + """Interleave many inserts and deletes, then query.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16) + ) + """) + import random + random.seed(321) + + # Insert 50 vectors + for i in range(1, 51): + vec = [random.gauss(0, 1) for _ in range(8)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + # Delete even-numbered rows + for i in range(2, 51, 2): + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + # Insert more vectors with higher rowids + for i in range(51, 76): + vec = [random.gauss(0, 1) for _ in range(8)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + remaining = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] + assert remaining == 50 # 25 odd + 25 new + + # KNN should still work and return results + query = [random.gauss(0, 1) for _ in range(8)] + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=10", + [_f32(query)], + ).fetchall() + assert len(rows) == 10 + # Distances should be sorted + distances = [r[1] for r in rows] + for i in range(len(distances) - 1): + assert distances[i] <= distances[i + 1] + + +def test_diskann_knn_batch_recall(db): + """Insert 100+ vectors and verify reasonable recall.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[16] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16) + ) + """) + import random + random.seed(42) + N = 150 + vectors = {} + for i in range(1, N + 1): + vec = [random.gauss(0, 1) for _ in range(16)] + vectors[i] = vec + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + # Brute-force top-5 for a query and compare with DiskANN + query = [random.gauss(0, 1) for _ in range(16)] + + # Compute true distances + true_dists = [] + for rid, vec in vectors.items(): + d = sum((a - b) ** 2 for a, b in zip(query, vec)) + true_dists.append((d, rid)) + true_dists.sort() + true_top5 = set(r for _, r in true_dists[:5]) + + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=5", + [_f32(query)], + ).fetchall() + result_top5 = set(r[0] for r in rows) + assert len(rows) == 5 + + # At least 3 of top-5 should match (reasonable recall for approximate search) + overlap = len(true_top5 & result_top5) + assert overlap >= 3, f"Recall too low: only {overlap}/5 overlap" + + +# ====================================================================== +# Additional edge cases +# ====================================================================== + +def test_diskann_insert_wrong_dimensions(db): + """INSERT with wrong dimension vector should error.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + result = exec(db, "INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1.0] * 4)]) + assert "error" in result + + +def test_diskann_knn_wrong_query_dimensions(db): + """KNN MATCH with wrong dimension query should error.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=8) + ) + """) + db.execute("INSERT INTO t(rowid, emb) VALUES (1, ?)", [_f32([1.0] * 8)]) + + result = exec(db, "SELECT rowid FROM t WHERE emb MATCH ? AND k=1", [_f32([1.0] * 4)]) + assert "error" in result + + +def test_diskann_graph_connectivity_after_many_deletes(db): + """After many deletes, the graph should still be connected enough for search.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16) + ) + """) + import random + random.seed(789) + N = 40 + for i in range(1, N + 1): + vec = [random.gauss(0, 1) for _ in range(8)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + # Delete 30 of 40 nodes + to_delete = list(range(1, 31)) + for i in to_delete: + db.execute("DELETE FROM t WHERE rowid = ?", [i]) + + remaining = db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] + assert remaining == 10 + + # Search should still work and return results + query = [random.gauss(0, 1) for _ in range(8)] + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=10", + [_f32(query)], + ).fetchall() + # Should return some results (graph may be fragmented after heavy deletion) + assert len(rows) >= 1 + # Distances should be sorted + distances = [r[1] for r in rows] + for i in range(len(distances) - 1): + assert distances[i] <= distances[i + 1] + + +def test_diskann_large_batch_insert_500(db): + """Insert 500+ vectors and verify counts and KNN.""" + db.execute(""" + CREATE VIRTUAL TABLE t USING vec0( + emb float[8] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=16) + ) + """) + import random + random.seed(555) + N = 500 + for i in range(1, N + 1): + vec = [random.gauss(0, 1) for _ in range(8)] + db.execute("INSERT INTO t(rowid, emb) VALUES (?, ?)", [i, _f32(vec)]) + + assert db.execute("SELECT count(*) FROM t_vectors00").fetchone()[0] == N + + query = [random.gauss(0, 1) for _ in range(8)] + rows = db.execute( + "SELECT rowid, distance FROM t WHERE emb MATCH ? AND k=20", + [_f32(query)], + ).fetchall() + assert len(rows) == 20 + distances = [r[1] for r in rows] + for i in range(len(distances) - 1): + assert distances[i] <= distances[i + 1] diff --git a/tests/test-unit.c b/tests/test-unit.c index 27a469d..83cedd5 100644 --- a/tests/test-unit.c +++ b/tests/test-unit.c @@ -1187,6 +1187,7 @@ void test_ivf_quantize_binary() { } void test_ivf_config_parsing() { +void test_vec0_parse_vector_column_diskann() { printf("Starting %s...\n", __func__); struct VectorColumnDefinition col; int rc; @@ -1199,6 +1200,34 @@ void test_ivf_config_parsing() { assert(col.index_type == VEC0_INDEX_TYPE_RESCORE); assert(col.rescore.quantizer_type == VEC0_RESCORE_QUANTIZER_BIT); assert(col.rescore.oversample == 8); // default + // Existing syntax (no INDEXED BY) should have diskann.enabled == 0 + { + const char *input = "emb float[128]"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type != VEC0_INDEX_TYPE_DISKANN); + sqlite3_free(col.name); + } + + // With distance_metric but no INDEXED BY + { + const char *input = "emb float[128] distance_metric=cosine"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type != VEC0_INDEX_TYPE_DISKANN); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE); + sqlite3_free(col.name); + } + + // Basic binary quantizer + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_DISKANN); + assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY); + assert(col.diskann.n_neighbors == 72); // default + assert(col.diskann.search_list_size == 128); // default assert(col.dimensions == 128); sqlite3_free(col.name); } @@ -1370,6 +1399,681 @@ void test_ivf_config_parsing() { printf(" All ivf_config_parsing tests passed.\n"); } #endif /* SQLITE_VEC_ENABLE_IVF */ + // INT8 quantizer + { + const char *input = "v float[64] INDEXED BY diskann(neighbor_quantizer=int8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_DISKANN); + assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + // Custom n_neighbors + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=48)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_DISKANN); + assert(col.diskann.n_neighbors == 48); + sqlite3_free(col.name); + } + + // Custom search_list_size + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=256)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.diskann.search_list_size == 256); + sqlite3_free(col.name); + } + + // Combined with distance_metric (distance_metric first) + { + const char *input = "emb float[128] distance_metric=cosine INDEXED BY diskann(neighbor_quantizer=int8)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.distance_metric == VEC0_DISTANCE_METRIC_COSINE); + assert(col.index_type == VEC0_INDEX_TYPE_DISKANN); + assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_INT8); + sqlite3_free(col.name); + } + + // Error: missing neighbor_quantizer (required) + { + const char *input = "emb float[128] INDEXED BY diskann(n_neighbors=72)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: empty parens + { + const char *input = "emb float[128] INDEXED BY diskann()"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: unknown quantizer + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=unknown)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: bad n_neighbors (not divisible by 8) + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=13)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: n_neighbors too large + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, n_neighbors=512)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: missing BY + { + const char *input = "emb float[128] INDEXED diskann(neighbor_quantizer=binary)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: unknown algorithm + { + const char *input = "emb float[128] INDEXED BY hnsw(neighbor_quantizer=binary)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: unknown option key + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, foobar=baz)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Case insensitivity for keywords + { + const char *input = "emb float[128] indexed by DISKANN(NEIGHBOR_QUANTIZER=BINARY)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.index_type == VEC0_INDEX_TYPE_DISKANN); + assert(col.diskann.quantizer_type == VEC0_DISKANN_QUANTIZER_BINARY); + sqlite3_free(col.name); + } + + // Split search_list_size: search and insert + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size_search=256, search_list_size_insert=64)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.diskann.search_list_size == 128); // default (unified) + assert(col.diskann.search_list_size_search == 256); + assert(col.diskann.search_list_size_insert == 64); + sqlite3_free(col.name); + } + + // Split search_list_size: only search + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size_search=200)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_OK); + assert(col.diskann.search_list_size_search == 200); + assert(col.diskann.search_list_size_insert == 0); + sqlite3_free(col.name); + } + + // Error: cannot mix search_list_size with search_list_size_search + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=128, search_list_size_search=256)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + // Error: cannot mix search_list_size with search_list_size_insert + { + const char *input = "emb float[128] INDEXED BY diskann(neighbor_quantizer=binary, search_list_size=128, search_list_size_insert=64)"; + rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == SQLITE_ERROR); + } + + printf(" All vec0_parse_vector_column_diskann tests passed.\n"); +} + +void test_diskann_validity_bitmap() { + printf("Starting %s...\n", __func__); + + unsigned char validity[3]; // 24 bits + memset(validity, 0, sizeof(validity)); + + // All initially invalid + for (int i = 0; i < 24; i++) { + assert(diskann_validity_get(validity, i) == 0); + } + assert(diskann_validity_count(validity, 24) == 0); + + // Set bit 0 + diskann_validity_set(validity, 0, 1); + assert(diskann_validity_get(validity, 0) == 1); + assert(diskann_validity_count(validity, 24) == 1); + + // Set bit 7 (last bit of first byte) + diskann_validity_set(validity, 7, 1); + assert(diskann_validity_get(validity, 7) == 1); + assert(diskann_validity_count(validity, 24) == 2); + + // Set bit 8 (first bit of second byte) + diskann_validity_set(validity, 8, 1); + assert(diskann_validity_get(validity, 8) == 1); + assert(diskann_validity_count(validity, 24) == 3); + + // Set bit 23 (last bit) + diskann_validity_set(validity, 23, 1); + assert(diskann_validity_get(validity, 23) == 1); + assert(diskann_validity_count(validity, 24) == 4); + + // Clear bit 0 + diskann_validity_set(validity, 0, 0); + assert(diskann_validity_get(validity, 0) == 0); + assert(diskann_validity_count(validity, 24) == 3); + + // Other bits unaffected + assert(diskann_validity_get(validity, 7) == 1); + assert(diskann_validity_get(validity, 8) == 1); + + printf(" All diskann_validity_bitmap tests passed.\n"); +} + +void test_diskann_neighbor_ids() { + printf("Starting %s...\n", __func__); + + unsigned char ids[8 * 8]; // 8 slots * 8 bytes each + memset(ids, 0, sizeof(ids)); + + // Set and get slot 0 + diskann_neighbor_id_set(ids, 0, 42); + assert(diskann_neighbor_id_get(ids, 0) == 42); + + // Set and get middle slot + diskann_neighbor_id_set(ids, 3, 12345); + assert(diskann_neighbor_id_get(ids, 3) == 12345); + + // Set and get last slot + diskann_neighbor_id_set(ids, 7, 99999); + assert(diskann_neighbor_id_get(ids, 7) == 99999); + + // Slot 0 still correct + assert(diskann_neighbor_id_get(ids, 0) == 42); + + // Large value + diskann_neighbor_id_set(ids, 1, INT64_MAX); + assert(diskann_neighbor_id_get(ids, 1) == INT64_MAX); + + printf(" All diskann_neighbor_ids tests passed.\n"); +} + +void test_diskann_quantize_binary() { + printf("Starting %s...\n", __func__); + + // 8-dimensional vector: positive values -> 1, negative/zero -> 0 + float src[8] = {1.0f, -1.0f, 0.5f, 0.0f, -0.5f, 0.1f, -0.1f, 100.0f}; + unsigned char out[1]; // 8 bits = 1 byte + + int rc = diskann_quantize_vector(src, 8, VEC0_DISKANN_QUANTIZER_BINARY, out); + assert(rc == 0); + + // Expected bits (LSB first within each byte): + // bit 0: 1.0 > 0 -> 1 + // bit 1: -1.0 > 0 -> 0 + // bit 2: 0.5 > 0 -> 1 + // bit 3: 0.0 > 0 -> 0 (not strictly greater) + // bit 4: -0.5 > 0 -> 0 + // bit 5: 0.1 > 0 -> 1 + // bit 6: -0.1 > 0 -> 0 + // bit 7: 100.0 > 0 -> 1 + // Expected byte: 1 + 0 + 4 + 0 + 0 + 32 + 0 + 128 = 0b10100101 = 0xA5 + assert(out[0] == 0xA5); + + printf(" All diskann_quantize_binary tests passed.\n"); +} + +void test_diskann_node_init_sizes() { + printf("Starting %s...\n", __func__); + + unsigned char *validity, *ids, *qvecs; + int validitySize, idsSize, qvecsSize; + + // 72 neighbors, binary quantizer, 1024 dims + int rc = diskann_node_init(72, VEC0_DISKANN_QUANTIZER_BINARY, 1024, + &validity, &validitySize, &ids, &idsSize, &qvecs, &qvecsSize); + assert(rc == 0); + assert(validitySize == 9); // 72/8 + assert(idsSize == 576); // 72 * 8 + assert(qvecsSize == 9216); // 72 * (1024/8) + + // All validity bits should be 0 + assert(diskann_validity_count(validity, 72) == 0); + + sqlite3_free(validity); + sqlite3_free(ids); + sqlite3_free(qvecs); + + // 8 neighbors, int8 quantizer, 32 dims + rc = diskann_node_init(8, VEC0_DISKANN_QUANTIZER_INT8, 32, + &validity, &validitySize, &ids, &idsSize, &qvecs, &qvecsSize); + assert(rc == 0); + assert(validitySize == 1); // 8/8 + assert(idsSize == 64); // 8 * 8 + assert(qvecsSize == 256); // 8 * 32 + + sqlite3_free(validity); + sqlite3_free(ids); + sqlite3_free(qvecs); + + printf(" All diskann_node_init_sizes tests passed.\n"); +} + +void test_diskann_node_set_clear_neighbor() { + printf("Starting %s...\n", __func__); + + unsigned char *validity, *ids, *qvecs; + int validitySize, idsSize, qvecsSize; + + // 8 neighbors, binary quantizer, 16 dims (2 bytes per qvec) + int rc = diskann_node_init(8, VEC0_DISKANN_QUANTIZER_BINARY, 16, + &validity, &validitySize, &ids, &idsSize, &qvecs, &qvecsSize); + assert(rc == 0); + + // Create a test quantized vector (2 bytes) + unsigned char test_qvec[2] = {0xAB, 0xCD}; + + // Set neighbor at slot 3 + diskann_node_set_neighbor(validity, ids, qvecs, 3, + 42, test_qvec, VEC0_DISKANN_QUANTIZER_BINARY, 16); + + // Verify slot 3 is valid + assert(diskann_validity_get(validity, 3) == 1); + assert(diskann_validity_count(validity, 8) == 1); + + // Verify rowid + assert(diskann_neighbor_id_get(ids, 3) == 42); + + // Verify quantized vector + const unsigned char *read_qvec = diskann_neighbor_qvec_get( + qvecs, 3, VEC0_DISKANN_QUANTIZER_BINARY, 16); + assert(read_qvec[0] == 0xAB); + assert(read_qvec[1] == 0xCD); + + // Clear slot 3 + diskann_node_clear_neighbor(validity, ids, qvecs, 3, + VEC0_DISKANN_QUANTIZER_BINARY, 16); + assert(diskann_validity_get(validity, 3) == 0); + assert(diskann_neighbor_id_get(ids, 3) == 0); + assert(diskann_validity_count(validity, 8) == 0); + + sqlite3_free(validity); + sqlite3_free(ids); + sqlite3_free(qvecs); + + printf(" All diskann_node_set_clear_neighbor tests passed.\n"); +} + +void test_diskann_prune_select() { + printf("Starting %s...\n", __func__); + + // Scenario: 5 candidates, sorted by distance to p + // Candidates: A(0), B(1), C(2), D(3), E(4) + // p_distances (already sorted): A=1.0, B=2.0, C=3.0, D=4.0, E=5.0 + // + // Inter-candidate distances (5x5 matrix): + // A B C D E + // A 0.0 1.5 3.0 4.0 5.0 + // B 1.5 0.0 1.5 3.0 4.0 + // C 3.0 1.5 0.0 1.5 3.0 + // D 4.0 3.0 1.5 0.0 1.5 + // E 5.0 4.0 3.0 1.5 0.0 + + float p_distances[5] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + float inter[25] = { + 0.0f, 1.5f, 3.0f, 4.0f, 5.0f, + 1.5f, 0.0f, 1.5f, 3.0f, 4.0f, + 3.0f, 1.5f, 0.0f, 1.5f, 3.0f, + 4.0f, 3.0f, 1.5f, 0.0f, 1.5f, + 5.0f, 4.0f, 3.0f, 1.5f, 0.0f, + }; + int selected[5]; + int count; + + // alpha=1.0, R=3: greedy selection + // Round 1: Pick A (closest). Prune check: + // B: 1.0*1.5 <= 2.0? yes -> pruned + // C: 1.0*3.0 <= 3.0? yes -> pruned + // D: 1.0*4.0 <= 4.0? yes -> pruned + // E: 1.0*5.0 <= 5.0? yes -> pruned + // Result: only A selected + { + int rc = diskann_prune_select(inter, p_distances, 5, 1.0f, 3, selected, &count); + assert(rc == 0); + assert(count == 1); + assert(selected[0] == 1); // A + } + + // alpha=1.5, R=3: diversity-aware + // Round 1: Pick A. Prune check: + // B: 1.5*1.5=2.25 <= 2.0? no -> keep + // C: 1.5*3.0=4.5 <= 3.0? no -> keep + // D: 1.5*4.0=6.0 <= 4.0? no -> keep + // E: 1.5*5.0=7.5 <= 5.0? no -> keep + // Round 2: Pick B. Prune check: + // C: 1.5*1.5=2.25 <= 3.0? yes -> pruned + // D: 1.5*3.0=4.5 <= 4.0? no -> keep + // E: 1.5*4.0=6.0 <= 5.0? no -> keep + // Round 3: Pick D. Done, 3 selected. + { + int rc = diskann_prune_select(inter, p_distances, 5, 1.5f, 3, selected, &count); + assert(rc == 0); + assert(count == 3); + assert(selected[0] == 1); // A + assert(selected[1] == 1); // B + assert(selected[3] == 1); // D + assert(selected[2] == 0); // C pruned + assert(selected[4] == 0); // E not reached + } + + // R > num_candidates with very high alpha (no pruning): select all + { + int rc = diskann_prune_select(inter, p_distances, 5, 100.0f, 10, selected, &count); + assert(rc == 0); + assert(count == 5); + } + + // Empty candidate set + { + int rc = diskann_prune_select(NULL, NULL, 0, 1.2f, 3, selected, &count); + assert(rc == 0); + assert(count == 0); + } + + printf(" All diskann_prune_select tests passed.\n"); +} + +void test_diskann_quantized_vector_byte_size() { + printf("Starting %s...\n", __func__); + + // Binary quantizer: 1 bit per dimension, so 128 dims = 16 bytes + assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_BINARY, 128) == 16); + assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_BINARY, 8) == 1); + assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_BINARY, 1024) == 128); + + // INT8 quantizer: 1 byte per dimension + assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_INT8, 128) == 128); + assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_INT8, 1) == 1); + assert(diskann_quantized_vector_byte_size(VEC0_DISKANN_QUANTIZER_INT8, 768) == 768); + + printf(" All diskann_quantized_vector_byte_size tests passed.\n"); +} + +void test_diskann_config_defaults() { + printf("Starting %s...\n", __func__); + + // A freshly zero-initialized VectorColumnDefinition should have diskann.enabled == 0 + struct VectorColumnDefinition col; + memset(&col, 0, sizeof(col)); + assert(col.index_type != VEC0_INDEX_TYPE_DISKANN); + assert(col.diskann.n_neighbors == 0); + assert(col.diskann.search_list_size == 0); + + // Verify parsing a normal vector column still works and diskann is not enabled + { + const char *input = "embedding float[768]"; + int rc = vec0_parse_vector_column(input, (int)strlen(input), &col); + assert(rc == 0 /* SQLITE_OK */); + assert(col.index_type != VEC0_INDEX_TYPE_DISKANN); + sqlite3_free(col.name); + } + + printf(" All diskann_config_defaults tests passed.\n"); +} + +// ====================================================================== +// Additional DiskANN unit tests +// ====================================================================== + +void test_diskann_quantize_int8() { + printf("Starting %s...\n", __func__); + + // INT8 quantization uses fixed range [-1, 1]: + // step = 2.0 / 255.0 + // out[i] = (i8)((src[i] + 1.0) / step - 128.0) + float src[4] = {-1.0f, 0.0f, 0.5f, 1.0f}; + unsigned char out[4]; + + int rc = diskann_quantize_vector(src, 4, VEC0_DISKANN_QUANTIZER_INT8, out); + assert(rc == 0); + + int8_t *signed_out = (int8_t *)out; + // -1.0 -> (0/step) - 128 = -128 + assert(signed_out[0] == -128); + // 0.0 -> (1.0/step) - 128 ~= 127.5 - 128 ~= -0.5 -> (i8)(-0.5) = 0 + assert(signed_out[1] >= -2 && signed_out[1] <= 2); + // 0.5 -> (1.5/step) - 128 ~= 191.25 - 128 = 63.25 -> (i8) 63 + assert(signed_out[2] >= 60 && signed_out[2] <= 66); + // 1.0 -> should be close to 127 (may have float precision issues) + assert(signed_out[3] >= 126 && signed_out[3] <= 127); + + printf(" All diskann_quantize_int8 tests passed.\n"); +} + +void test_diskann_quantize_binary_16d() { + printf("Starting %s...\n", __func__); + + // 16-dimensional vector (2 bytes output) + float src[16] = { + 1.0f, -1.0f, 0.5f, -0.5f, // byte 0: bit0=1, bit1=0, bit2=1, bit3=0 + 0.1f, -0.1f, 0.0f, 100.0f, // byte 0: bit4=1, bit5=0, bit6=0, bit7=1 + -1.0f, 1.0f, 1.0f, 1.0f, // byte 1: bit0=0, bit1=1, bit2=1, bit3=1 + -1.0f, -1.0f, 1.0f, -1.0f // byte 1: bit4=0, bit5=0, bit6=1, bit7=0 + }; + unsigned char out[2]; + + int rc = diskann_quantize_vector(src, 16, VEC0_DISKANN_QUANTIZER_BINARY, out); + assert(rc == 0); + + // byte 0: bits 0,2,4,7 set -> 0b10010101 = 0x95 + assert(out[0] == 0x95); + // byte 1: bits 1,2,3,6 set -> 0b01001110 = 0x4E + assert(out[1] == 0x4E); + + printf(" All diskann_quantize_binary_16d tests passed.\n"); +} + +void test_diskann_quantize_binary_all_positive() { + printf("Starting %s...\n", __func__); + + float src[8] = {1.0f, 2.0f, 0.1f, 0.001f, 100.0f, 42.0f, 0.5f, 3.14f}; + unsigned char out[1]; + + int rc = diskann_quantize_vector(src, 8, VEC0_DISKANN_QUANTIZER_BINARY, out); + assert(rc == 0); + assert(out[0] == 0xFF); // All bits set + + printf(" All diskann_quantize_binary_all_positive tests passed.\n"); +} + +void test_diskann_quantize_binary_all_negative() { + printf("Starting %s...\n", __func__); + + float src[8] = {-1.0f, -2.0f, -0.1f, -0.001f, -100.0f, -42.0f, -0.5f, 0.0f}; + unsigned char out[1]; + + int rc = diskann_quantize_vector(src, 8, VEC0_DISKANN_QUANTIZER_BINARY, out); + assert(rc == 0); + assert(out[0] == 0x00); // No bits set (all <= 0) + + printf(" All diskann_quantize_binary_all_negative tests passed.\n"); +} + +void test_diskann_candidate_list_operations() { + printf("Starting %s...\n", __func__); + + struct DiskannCandidateList list; + int rc = _test_diskann_candidate_list_init(&list, 5); + assert(rc == 0); + + // Insert candidates in non-sorted order + _test_diskann_candidate_list_insert(&list, 10, 3.0f); + _test_diskann_candidate_list_insert(&list, 20, 1.0f); + _test_diskann_candidate_list_insert(&list, 30, 2.0f); + + assert(_test_diskann_candidate_list_count(&list) == 3); + // Should be sorted by distance + assert(_test_diskann_candidate_list_rowid(&list, 0) == 20); // dist 1.0 + assert(_test_diskann_candidate_list_rowid(&list, 1) == 30); // dist 2.0 + assert(_test_diskann_candidate_list_rowid(&list, 2) == 10); // dist 3.0 + + assert(_test_diskann_candidate_list_distance(&list, 0) == 1.0f); + assert(_test_diskann_candidate_list_distance(&list, 1) == 2.0f); + assert(_test_diskann_candidate_list_distance(&list, 2) == 3.0f); + + // Deduplication: inserting same rowid with better distance should update + _test_diskann_candidate_list_insert(&list, 10, 0.5f); + assert(_test_diskann_candidate_list_count(&list) == 3); // Same count + assert(_test_diskann_candidate_list_rowid(&list, 0) == 10); // Now first + assert(_test_diskann_candidate_list_distance(&list, 0) == 0.5f); + + // Next unvisited: should be index 0 + int idx = _test_diskann_candidate_list_next_unvisited(&list); + assert(idx == 0); + + // Mark visited + _test_diskann_candidate_list_set_visited(&list, 0); + idx = _test_diskann_candidate_list_next_unvisited(&list); + assert(idx == 1); // Skip visited + + // Fill to capacity (5) and try inserting a worse candidate + _test_diskann_candidate_list_insert(&list, 40, 4.0f); + _test_diskann_candidate_list_insert(&list, 50, 5.0f); + assert(_test_diskann_candidate_list_count(&list) == 5); + + // Insert worse than worst -> should be discarded + int inserted = _test_diskann_candidate_list_insert(&list, 60, 10.0f); + assert(inserted == 0); + assert(_test_diskann_candidate_list_count(&list) == 5); + + // Insert better than worst -> should replace worst + inserted = _test_diskann_candidate_list_insert(&list, 60, 3.5f); + assert(inserted == 1); + assert(_test_diskann_candidate_list_count(&list) == 5); + + _test_diskann_candidate_list_free(&list); + + printf(" All diskann_candidate_list_operations tests passed.\n"); +} + +void test_diskann_visited_set_operations() { + printf("Starting %s...\n", __func__); + + struct DiskannVisitedSet set; + int rc = _test_diskann_visited_set_init(&set, 32); + assert(rc == 0); + + // Empty set + assert(_test_diskann_visited_set_contains(&set, 1) == 0); + assert(_test_diskann_visited_set_contains(&set, 100) == 0); + + // Insert and check + int inserted = _test_diskann_visited_set_insert(&set, 42); + assert(inserted == 1); + assert(_test_diskann_visited_set_contains(&set, 42) == 1); + assert(_test_diskann_visited_set_contains(&set, 43) == 0); + + // Double insert returns 0 + inserted = _test_diskann_visited_set_insert(&set, 42); + assert(inserted == 0); + + // Insert several + _test_diskann_visited_set_insert(&set, 1); + _test_diskann_visited_set_insert(&set, 2); + _test_diskann_visited_set_insert(&set, 100); + _test_diskann_visited_set_insert(&set, 999); + assert(_test_diskann_visited_set_contains(&set, 1) == 1); + assert(_test_diskann_visited_set_contains(&set, 2) == 1); + assert(_test_diskann_visited_set_contains(&set, 100) == 1); + assert(_test_diskann_visited_set_contains(&set, 999) == 1); + assert(_test_diskann_visited_set_contains(&set, 3) == 0); + + // Sentinel value (rowid 0) should not be insertable + assert(_test_diskann_visited_set_contains(&set, 0) == 0); + inserted = _test_diskann_visited_set_insert(&set, 0); + assert(inserted == 0); + + _test_diskann_visited_set_free(&set); + + printf(" All diskann_visited_set_operations tests passed.\n"); +} + +void test_diskann_prune_select_single_candidate() { + printf("Starting %s...\n", __func__); + + float p_distances[1] = {5.0f}; + float inter[1] = {0.0f}; + int selected[1]; + int count; + + int rc = diskann_prune_select(inter, p_distances, 1, 1.0f, 3, selected, &count); + assert(rc == 0); + assert(count == 1); + assert(selected[0] == 1); + + printf(" All diskann_prune_select_single_candidate tests passed.\n"); +} + +void test_diskann_prune_select_all_identical_distances() { + printf("Starting %s...\n", __func__); + + float p_distances[4] = {2.0f, 2.0f, 2.0f, 2.0f}; + // All inter-distances are equal too + float inter[16] = { + 0.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 0.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 0.0f, + }; + int selected[4]; + int count; + + // alpha=1.0: pick first, then check if alpha * inter[0][j] <= p_dist[j] + // 1.0 * 1.0 <= 2.0? yes, so all are pruned after picking the first + int rc = diskann_prune_select(inter, p_distances, 4, 1.0f, 4, selected, &count); + assert(rc == 0); + assert(count >= 1); // At least one selected + + printf(" All diskann_prune_select_all_identical_distances tests passed.\n"); +} + +void test_diskann_prune_select_max_neighbors_1() { + printf("Starting %s...\n", __func__); + + float p_distances[3] = {1.0f, 2.0f, 3.0f}; + float inter[9] = { + 0.0f, 5.0f, 5.0f, + 5.0f, 0.0f, 5.0f, + 5.0f, 5.0f, 0.0f, + }; + int selected[3]; + int count; + + // R=1: should select exactly 1 + int rc = diskann_prune_select(inter, p_distances, 3, 1.0f, 1, selected, &count); + assert(rc == 0); + assert(count == 1); + assert(selected[0] == 1); // First (closest) is selected + + printf(" All diskann_prune_select_max_neighbors_1 tests passed.\n"); +} int main() { printf("Starting unit tests...\n"); @@ -1402,5 +2106,23 @@ int main() { test_ivf_quantize_binary(); test_ivf_config_parsing(); #endif + test_vec0_parse_vector_column_diskann(); + test_diskann_validity_bitmap(); + test_diskann_neighbor_ids(); + test_diskann_quantize_binary(); + test_diskann_node_init_sizes(); + test_diskann_node_set_clear_neighbor(); + test_diskann_prune_select(); + test_diskann_quantized_vector_byte_size(); + test_diskann_config_defaults(); + test_diskann_quantize_int8(); + test_diskann_quantize_binary_16d(); + test_diskann_quantize_binary_all_positive(); + test_diskann_quantize_binary_all_negative(); + test_diskann_candidate_list_operations(); + test_diskann_visited_set_operations(); + test_diskann_prune_select_single_candidate(); + test_diskann_prune_select_all_identical_distances(); + test_diskann_prune_select_max_neighbors_1(); printf("All unit tests passed.\n"); }