From dbbb4b98f76a86d08bfa1c5ee14629d6ee2a6fad Mon Sep 17 00:00:00 2001 From: Alex Garcia Date: Sun, 29 Mar 2026 19:47:12 -0700 Subject: [PATCH] Add comprehensive ANN benchmarking suite Extend benchmarks-ann/ with results database (SQLite with per-query detail and continuous writes), dataset subfolder organization, --subset-size and --warmup options. Supports systematic comparison across flat, rescore, IVF, and DiskANN index types. --- .gitignore | 3 + benchmarks-ann/.gitignore | 6 + benchmarks-ann/Makefile | 28 +- benchmarks-ann/README.md | 114 ++- benchmarks-ann/bench.py | 946 +++++++++++++----- benchmarks-ann/datasets/cohere10m/Makefile | 27 + .../datasets/cohere10m/build_base_db.py | 134 +++ .../{seed => datasets/cohere1m}/.gitignore | 0 .../{seed => datasets/cohere1m}/Makefile | 0 .../cohere1m}/build_base_db.py | 0 benchmarks-ann/datasets/nyt-1024/Makefile | 30 + .../datasets/nyt-1024/build-base.py | 163 +++ benchmarks-ann/datasets/nyt-1024/queries.txt | 100 ++ benchmarks-ann/datasets/nyt-384/Makefile | 29 + benchmarks-ann/datasets/nyt-384/queries.txt | 100 ++ benchmarks-ann/datasets/nyt-768/Makefile | 37 + .../datasets/nyt-768/build-contents.py | 64 ++ .../datasets/nyt-768/distill-model.py | 13 + benchmarks-ann/datasets/nyt-768/queries.txt | 100 ++ benchmarks-ann/datasets/nyt/.gitignore | 1 + benchmarks-ann/datasets/nyt/Makefile | 30 + benchmarks-ann/datasets/nyt/build-base.py | 165 +++ benchmarks-ann/datasets/nyt/build-contents.py | 52 + benchmarks-ann/datasets/nyt/queries.txt | 100 ++ benchmarks-ann/faiss_kmeans.py | 101 ++ benchmarks-ann/results_schema.sql | 76 ++ 26 files changed, 2127 insertions(+), 292 deletions(-) create mode 100644 benchmarks-ann/datasets/cohere10m/Makefile create mode 100644 benchmarks-ann/datasets/cohere10m/build_base_db.py rename benchmarks-ann/{seed => datasets/cohere1m}/.gitignore (100%) rename benchmarks-ann/{seed => datasets/cohere1m}/Makefile (100%) rename benchmarks-ann/{seed => datasets/cohere1m}/build_base_db.py (100%) 
create mode 100644 benchmarks-ann/datasets/nyt-1024/Makefile create mode 100644 benchmarks-ann/datasets/nyt-1024/build-base.py create mode 100644 benchmarks-ann/datasets/nyt-1024/queries.txt create mode 100644 benchmarks-ann/datasets/nyt-384/Makefile create mode 100644 benchmarks-ann/datasets/nyt-384/queries.txt create mode 100644 benchmarks-ann/datasets/nyt-768/Makefile create mode 100644 benchmarks-ann/datasets/nyt-768/build-contents.py create mode 100644 benchmarks-ann/datasets/nyt-768/distill-model.py create mode 100644 benchmarks-ann/datasets/nyt-768/queries.txt create mode 100644 benchmarks-ann/datasets/nyt/.gitignore create mode 100644 benchmarks-ann/datasets/nyt/Makefile create mode 100644 benchmarks-ann/datasets/nyt/build-base.py create mode 100644 benchmarks-ann/datasets/nyt/build-contents.py create mode 100644 benchmarks-ann/datasets/nyt/queries.txt create mode 100644 benchmarks-ann/faiss_kmeans.py create mode 100644 benchmarks-ann/results_schema.sql diff --git a/.gitignore b/.gitignore index 0268d5d..ef549f4 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,6 @@ poetry.lock memstat.c memstat.* + + +.DS_Store \ No newline at end of file diff --git a/benchmarks-ann/.gitignore b/benchmarks-ann/.gitignore index c418b76..95707b9 100644 --- a/benchmarks-ann/.gitignore +++ b/benchmarks-ann/.gitignore @@ -1,2 +1,8 @@ *.db +*.db-shm +*.db-wal +*.parquet runs/ + +viewer/ +searcher/ \ No newline at end of file diff --git a/benchmarks-ann/Makefile b/benchmarks-ann/Makefile index ddceb65..a631478 100644 --- a/benchmarks-ann/Makefile +++ b/benchmarks-ann/Makefile @@ -1,5 +1,5 @@ BENCH = python bench.py -BASE_DB = seed/base.db +BASE_DB = cohere1m/base.db EXT = ../dist/vec0 # --- Baseline (brute-force) configs --- @@ -33,7 +33,7 @@ ALL_CONFIGS = $(BASELINES) $(RESCORE_CONFIGS) $(IVF_CONFIGS) $(DISKANN_CONFIGS) # --- Data preparation --- seed: - $(MAKE) -C seed + $(MAKE) -C cohere1m ground-truth: seed python ground_truth.py --subset-size 10000 @@ -42,43 +42,43 @@ 
ground-truth: seed # --- Quick smoke test --- bench-smoke: seed - $(BENCH) --subset-size 5000 -k 10 -n 20 -o runs/smoke \ + $(BENCH) --subset-size 5000 -k 10 -n 20 --dataset cohere1m -o runs \ "brute-float:type=baseline,variant=float" \ "ivf-quick:type=ivf,nlist=16,nprobe=4" \ "diskann-quick:type=diskann,R=48,L=64,quantizer=binary" bench-rescore: seed - $(BENCH) --subset-size 10000 -k 10 -o runs/rescore \ + $(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs \ $(RESCORE_CONFIGS) # --- Standard sizes --- bench-10k: seed - $(BENCH) --subset-size 10000 -k 10 -o runs/10k $(ALL_CONFIGS) + $(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS) bench-50k: seed - $(BENCH) --subset-size 50000 -k 10 -o runs/50k $(ALL_CONFIGS) + $(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS) bench-100k: seed - $(BENCH) --subset-size 100000 -k 10 -o runs/100k $(ALL_CONFIGS) + $(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(ALL_CONFIGS) bench-all: bench-10k bench-50k bench-100k # --- IVF across sizes --- bench-ivf: seed - $(BENCH) --subset-size 10000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS) - $(BENCH) --subset-size 50000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS) - $(BENCH) --subset-size 100000 -k 10 -o runs/ivf $(BASELINES) $(IVF_CONFIGS) + $(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS) + $(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS) + $(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(IVF_CONFIGS) # --- DiskANN across sizes --- bench-diskann: seed - $(BENCH) --subset-size 10000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS) - $(BENCH) --subset-size 50000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS) - $(BENCH) --subset-size 100000 -k 10 -o runs/diskann $(BASELINES) $(DISKANN_CONFIGS) + $(BENCH) --subset-size 10000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS) + 
$(BENCH) --subset-size 50000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS) + $(BENCH) --subset-size 100000 -k 10 --dataset cohere1m -o runs $(BASELINES) $(DISKANN_CONFIGS) # --- Report --- report: - @echo "Use: sqlite3 runs//results.db 'SELECT * FROM bench_results ORDER BY recall DESC'" + @echo "Use: sqlite3 runs/cohere1m//results.db 'SELECT run_id, config_name, status, recall FROM runs JOIN run_results USING(run_id)'" # --- Cleanup --- clean: diff --git a/benchmarks-ann/README.md b/benchmarks-ann/README.md index 1f7fd5c..88f1c74 100644 --- a/benchmarks-ann/README.md +++ b/benchmarks-ann/README.md @@ -1,81 +1,111 @@ # KNN Benchmarks for sqlite-vec Benchmarking infrastructure for vec0 KNN configurations. Includes brute-force -baselines (float, int8, bit); index-specific branches add their own types -via the `INDEX_REGISTRY` in `bench.py`. +baselines (float, int8, bit), rescore, IVF, and DiskANN index types. + +## Datasets + +Each dataset is a subdirectory containing a `Makefile` and `build_base_db.py` +that produce a `base.db`. The benchmark runner auto-discovers any subdirectory +with a `base.db` file. + +``` +cohere1m/ # Cohere 768d cosine, 1M vectors + Makefile # downloads parquets from Zilliz, builds base.db + build_base_db.py + base.db # (generated) + +cohere10m/ # Cohere 768d cosine, 10M vectors (10 train shards) + Makefile # make -j12 download to fetch all shards in parallel + build_base_db.py + base.db # (generated) +``` + +Every `base.db` has the same schema: + +| Table | Columns | Description | +|-------|---------|-------------| +| `train` | `id INTEGER PRIMARY KEY, vector BLOB` | Indexed vectors (f32 blobs) | +| `query_vectors` | `id INTEGER PRIMARY KEY, vector BLOB` | Query vectors for KNN evaluation | +| `neighbors` | `query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT` | Ground-truth nearest neighbors | + +To add a new dataset, create a directory with a `Makefile` that builds `base.db` +with the tables above. 
It will be available via `--dataset ` automatically. + +### Building datasets + +```bash +# Cohere 1M +cd cohere1m && make download && make && cd .. + +# Cohere 10M (parallel download recommended — 10 train shards + test + neighbors) +cd cohere10m && make -j12 download && make && cd .. +``` ## Prerequisites -- Built `dist/vec0` extension (run `make` from repo root) +- Built `dist/vec0` extension (run `make loadable` from repo root) - Python 3.10+ -- `uv` (for seed data prep): `pip install uv` +- `uv` ## Quick start ```bash -# 1. Download dataset and build seed DB (~3 GB download, ~5 min) -make seed +# 1. Build a dataset +cd cohere1m && make && cd .. -# 2. Run a quick smoke test (5k vectors, ~1 min) +# 2. Quick smoke test (5k vectors) make bench-smoke -# 3. Run full benchmark at 10k +# 3. Full benchmark at 10k make bench-10k ``` ## Usage -### Direct invocation - ```bash -python bench.py --subset-size 10000 \ +uv run python bench.py --subset-size 10000 -k 10 -n 50 --dataset cohere1m \ "brute-float:type=baseline,variant=float" \ - "brute-int8:type=baseline,variant=int8" \ - "brute-bit:type=baseline,variant=bit" + "rescore-bit-os8:type=rescore,quantizer=bit,oversample=8" ``` ### Config format `name:type=,key=val,key=val` -| Index type | Keys | Branch | -|-----------|------|--------| -| `baseline` | `variant` (float/int8/bit), `oversample` | this branch | - -Index branches register additional types in `INDEX_REGISTRY`. See the -docstring in `bench.py` for the extension API. 
+| Index type | Keys | +|-----------|------| +| `baseline` | `variant` (float/int8/bit), `oversample` | +| `rescore` | `quantizer` (bit/int8), `oversample` | +| `ivf` | `nlist`, `nprobe` | +| `diskann` | `R`, `L`, `quantizer` (binary/int8), `buffer_threshold` | ### Make targets | Target | Description | |--------|-------------| -| `make seed` | Download COHERE 1M dataset | -| `make ground-truth` | Pre-compute ground truth for 10k/50k/100k | -| `make bench-smoke` | Quick 5k baseline test | +| `make seed` | Download and build default dataset | +| `make bench-smoke` | Quick 5k test (3 configs) | | `make bench-10k` | All configs at 10k vectors | | `make bench-50k` | All configs at 50k vectors | | `make bench-100k` | All configs at 100k vectors | | `make bench-all` | 10k + 50k + 100k | +| `make bench-ivf` | Baselines + IVF across 10k/50k/100k | +| `make bench-diskann` | Baselines + DiskANN across 10k/50k/100k | + +## Results DB + +Each run writes to `runs///results.db` (SQLite, WAL mode). +Progress is written continuously — query from another terminal to monitor: + +```bash +sqlite3 runs/cohere1m/10000/results.db "SELECT run_id, config_name, status FROM runs" +``` + +See `results_schema.sql` for the full schema (tables: `runs`, `run_results`, +`insert_batches`, `queries`). ## Adding an index type -In your index branch, add an entry to `INDEX_REGISTRY` in `bench.py` and -append your configs to `ALL_CONFIGS` in the `Makefile`. See the existing -`baseline` entry and the comments in both files for the pattern. - -## Results - -Results are stored in `runs//results.db` using the schema in `schema.sql`. - -```bash -sqlite3 runs/10k/results.db " - SELECT config_name, recall, mean_ms, qps - FROM bench_results - ORDER BY recall DESC -" -``` - -## Dataset - -[Zilliz COHERE Medium 1M](https://zilliz.com/learn/datasets-for-vector-database-benchmarks): -768 dimensions, cosine distance, 1M train vectors + 10k query vectors with precomputed neighbors. 
+Add an entry to `INDEX_REGISTRY` in `bench.py` and append configs to +`ALL_CONFIGS` in the `Makefile`. See existing entries for the pattern. diff --git a/benchmarks-ann/bench.py b/benchmarks-ann/bench.py index 520db77..a4cbbe4 100644 --- a/benchmarks-ann/bench.py +++ b/benchmarks-ann/bench.py @@ -6,7 +6,7 @@ across different vec0 configurations. Config format: name:type=,key=val,key=val - Available types: none, vec0-flat, rescore, ivf, diskann + Available types: none, vec0-flat, quantized, rescore, ivf, diskann Usage: python bench.py --subset-size 10000 \ @@ -15,7 +15,7 @@ Usage: "flat-int8:type=vec0-flat,variant=int8" """ import argparse -from datetime import datetime, timezone +import json import os import sqlite3 import statistics @@ -23,9 +23,36 @@ import time _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) EXT_PATH = os.path.join(_SCRIPT_DIR, "..", "dist", "vec0") -BASE_DB = os.path.join(_SCRIPT_DIR, "seed", "base.db") INSERT_BATCH_SIZE = 1000 +_DATASETS_DIR = os.path.join(_SCRIPT_DIR, "datasets") + +DATASETS = { + "cohere1m": {"base_db": os.path.join(_DATASETS_DIR, "cohere1m", "base.db"), "dimensions": 768}, + "cohere10m": {"base_db": os.path.join(_DATASETS_DIR, "cohere10m", "base.db"), "dimensions": 768}, + "nyt": {"base_db": os.path.join(_DATASETS_DIR, "nyt", "base.db"), "dimensions": 256}, + "nyt-768": {"base_db": os.path.join(_DATASETS_DIR, "nyt-768", "base.db"), "dimensions": 768}, + "nyt-1024": {"base_db": os.path.join(_DATASETS_DIR, "nyt-1024", "base.db"), "dimensions": 1024}, + "nyt-384": {"base_db": os.path.join(_DATASETS_DIR, "nyt-384", "base.db"), "dimensions": 384}, +} + + +# ============================================================================ +# Timing helpers +# ============================================================================ + + +def now_ns(): + return time.time_ns() + + +def ns_to_s(ns): + return ns / 1_000_000_000 + + +def ns_to_ms(ns): + return ns / 1_000_000 + # 
============================================================================ # Index registry — extension point for ANN index branches @@ -36,7 +63,9 @@ INSERT_BATCH_SIZE = 1000 # "create_table_sql": fn(params) -> SQL string # "insert_sql": fn(params) -> SQL string (or None for default) # "post_insert_hook": fn(conn, params) -> train_time_s (or None) +# "train_sql": fn(params) -> SQL string (or None if no training) # "run_query": fn(conn, params, query, k) -> [(id, distance), ...] (or None for default MATCH) +# "query_sql": fn(params) -> SQL string (or None for default MATCH) # "describe": fn(params) -> str (one-line description) # # To add a new index type, add an entry here. Example (in your branch): @@ -59,6 +88,7 @@ INDEX_REGISTRY = {} def _none_create_table_sql(params): + # none uses raw tables — no dimension in DDL variant = params["variant"] if variant == "int8": return ( @@ -138,7 +168,7 @@ def _none_run_query(conn, params, query, k): return conn.execute( "SELECT id, vec_distance_cosine(:query, embedding) as distance " - "FROM vec_items ORDER BY 2 LIMIT :k", + "FROM vec_items WHERE distance IS NOT NULL ORDER BY 2 LIMIT :k", {"query": query, "k": k}, ).fetchall() @@ -155,7 +185,9 @@ INDEX_REGISTRY["none"] = { "create_table_sql": _none_create_table_sql, "insert_sql": _none_insert_sql, "post_insert_hook": None, + "train_sql": None, "run_query": _none_run_query, + "query_sql": None, "describe": _none_describe, } @@ -166,17 +198,18 @@ INDEX_REGISTRY["none"] = { def _vec0flat_create_table_sql(params): + D = params.get("_dimensions", 768) variant = params["variant"] extra = "" if variant == "int8": - extra = ", embedding_int8 int8[768]" + extra = f", embedding_int8 int8[{D}]" elif variant == "bit": - extra = ", embedding_bq bit[768]" + extra = f", embedding_bq bit[{D}]" return ( f"CREATE VIRTUAL TABLE vec_items USING vec0(" f" chunk_size=256," f" id integer primary key," - f" embedding float[768] distance_metric=cosine" + f" embedding float[{D}] 
distance_metric=cosine" f" {extra})" ) @@ -228,6 +261,32 @@ def _vec0flat_run_query(conn, params, query, k): return None # use default MATCH +def _vec0flat_query_sql(params): + variant = params["variant"] + oversample = params.get("oversample", 8) + if variant == "int8": + return ( + "WITH coarse AS (" + " SELECT id, embedding FROM vec_items" + " WHERE embedding_int8 MATCH vec_quantize_int8(:query, 'unit')" + f" LIMIT :k * {oversample}" + ") " + "SELECT id, vec_distance_cosine(embedding, :query) as distance " + "FROM coarse ORDER BY 2 LIMIT :k" + ) + elif variant == "bit": + return ( + "WITH coarse AS (" + " SELECT id, embedding FROM vec_items" + " WHERE embedding_bq MATCH vec_quantize_binary(:query)" + f" LIMIT :k * {oversample}" + ") " + "SELECT id, vec_distance_cosine(embedding, :query) as distance " + "FROM coarse ORDER BY 2 LIMIT :k" + ) + return None + + def _vec0flat_describe(params): v = params["variant"] if v in ("int8", "bit"): @@ -240,24 +299,115 @@ INDEX_REGISTRY["vec0-flat"] = { "create_table_sql": _vec0flat_create_table_sql, "insert_sql": _vec0flat_insert_sql, "post_insert_hook": None, + "train_sql": None, "run_query": _vec0flat_run_query, + "query_sql": _vec0flat_query_sql, "describe": _vec0flat_describe, } +# ============================================================================ +# Quantized-only implementation (no rescoring) +# ============================================================================ + + +def _quantized_create_table_sql(params): + D = params.get("_dimensions", 768) + quantizer = params["quantizer"] + if quantizer == "int8": + col = f"embedding int8[{D}]" + elif quantizer == "bit": + col = f"embedding bit[{D}]" + else: + raise ValueError(f"Unknown quantizer: {quantizer}") + return ( + f"CREATE VIRTUAL TABLE vec_items USING vec0(" + f" chunk_size=256," + f" id integer primary key," + f" {col})" + ) + + +def _quantized_insert_sql(params): + quantizer = params["quantizer"] + if quantizer == "int8": + return ( + "INSERT INTO 
vec_items(id, embedding) " + "SELECT id, vec_quantize_int8(vector, 'unit') " + "FROM base.train WHERE id >= :lo AND id < :hi" + ) + elif quantizer == "bit": + return ( + "INSERT INTO vec_items(id, embedding) " + "SELECT id, vec_quantize_binary(vector) " + "FROM base.train WHERE id >= :lo AND id < :hi" + ) + return None + + +def _quantized_run_query(conn, params, query, k): + """Search quantized column only — no rescoring.""" + quantizer = params["quantizer"] + if quantizer == "int8": + return conn.execute( + "SELECT id, distance FROM vec_items " + "WHERE embedding MATCH vec_quantize_int8(:query, 'unit') AND k = :k", + {"query": query, "k": k}, + ).fetchall() + elif quantizer == "bit": + return conn.execute( + "SELECT id, distance FROM vec_items " + "WHERE embedding MATCH vec_quantize_binary(:query) AND k = :k", + {"query": query, "k": k}, + ).fetchall() + return None + + +def _quantized_query_sql(params): + quantizer = params["quantizer"] + if quantizer == "int8": + return ( + "SELECT id, distance FROM vec_items " + "WHERE embedding MATCH vec_quantize_int8(:query, 'unit') AND k = :k" + ) + elif quantizer == "bit": + return ( + "SELECT id, distance FROM vec_items " + "WHERE embedding MATCH vec_quantize_binary(:query) AND k = :k" + ) + return None + + +def _quantized_describe(params): + return f"quantized {params['quantizer']}" + + +INDEX_REGISTRY["quantized"] = { + "defaults": {"quantizer": "bit"}, + "create_table_sql": _quantized_create_table_sql, + "insert_sql": _quantized_insert_sql, + "post_insert_hook": None, + "train_sql": None, + "run_query": _quantized_run_query, + "query_sql": _quantized_query_sql, + "describe": _quantized_describe, +} + + # ============================================================================ # Rescore implementation # ============================================================================ def _rescore_create_table_sql(params): + D = params.get("_dimensions", 768) quantizer = params.get("quantizer", "bit") oversample = 
params.get("oversample", 8) return ( f"CREATE VIRTUAL TABLE vec_items USING vec0(" f" chunk_size=256," f" id integer primary key," - f" embedding float[768] distance_metric=cosine" + f" embedding float[{D}] distance_metric=cosine" f" indexed by rescore(quantizer={quantizer}, oversample={oversample}))" ) @@ -273,7 +423,9 @@ INDEX_REGISTRY["rescore"] = { "create_table_sql": _rescore_create_table_sql, "insert_sql": None, "post_insert_hook": None, + "train_sql": None, "run_query": None, # default MATCH query works — rescore is automatic + "query_sql": None, "describe": _rescore_describe, } @@ -284,20 +436,25 @@ INDEX_REGISTRY["rescore"] = { def _ivf_create_table_sql(params): + D = params.get("_dimensions", 768) + quantizer = params.get("quantizer", "none") + oversample = params.get("oversample", 1) + parts = [f"nlist={params['nlist']}", f"nprobe={params['nprobe']}"] + if quantizer != "none": + parts.append(f"quantizer={quantizer}") + if oversample > 1: + parts.append(f"oversample={oversample}") + ivf_args = ", ".join(parts) return ( f"CREATE VIRTUAL TABLE vec_items USING vec0(" - f" id integer primary key," - f" embedding float[768] distance_metric=cosine" - f" indexed by ivf(" - f" nlist={params['nlist']}," - f" nprobe={params['nprobe']}" - f" )" - f")" + f"id integer primary key, " + f"embedding float[{D}] distance_metric=cosine " + f"indexed by ivf({ivf_args}))" ) def _ivf_post_insert_hook(conn, params): - print(" Training k-means centroids...", flush=True) + print(" Training k-means centroids (built-in)...", flush=True) t0 = time.perf_counter() conn.execute("INSERT INTO vec_items(id) VALUES ('compute-centroids')") conn.commit() @@ -306,16 +463,118 @@ def _ivf_post_insert_hook(conn, params): return elapsed +def _ivf_faiss_kmeans_hook(conn, params): + """Run FAISS k-means externally, then load centroids via set-centroid commands. 
+ + Called BEFORE any inserts — centroids are loaded first so vectors get + assigned to partitions on insert (no assign-vectors step needed). + """ + import subprocess + import tempfile + + nlist = params["nlist"] + ntrain = params.get("train_sample", 0) or params.get("faiss_kmeans", 10000) + niter = params.get("faiss_niter", 20) + base_db = params.get("_base_db") # injected by build_index + + print(f" Training k-means via FAISS ({nlist} clusters, {ntrain} vectors, {niter} iters)...", + flush=True) + + centroids_db_path = tempfile.mktemp(suffix=".db") + t0 = time.perf_counter() + + result = subprocess.run( + [ + "uv", "run", "--with", "faiss-cpu", "--with", "numpy", + "python", os.path.join(_SCRIPT_DIR, "faiss_kmeans.py"), + "--base-db", base_db, + "--ntrain", str(ntrain), + "--nclusters", str(nlist), + "--niter", str(niter), + "-o", centroids_db_path, + ], + capture_output=True, text=True, + ) + if result.returncode != 0: + print(f" FAISS stderr: {result.stderr}", flush=True) + raise RuntimeError(f"faiss_kmeans.py failed: {result.stderr}") + + faiss_elapsed = time.perf_counter() - t0 + print(f" FAISS k-means done in {faiss_elapsed:.1f}s", flush=True) + + # Load centroids into vec0 via set-centroid commands + print(f" Loading {nlist} centroids into vec0...", flush=True) + cdb = sqlite3.connect(centroids_db_path) + centroids = cdb.execute( + "SELECT centroid_id, centroid FROM centroids ORDER BY centroid_id" + ).fetchall() + meta = dict(cdb.execute("SELECT key, value FROM meta").fetchall()) + cdb.close() + os.remove(centroids_db_path) + + for cid, blob in centroids: + conn.execute( + "INSERT INTO vec_items(id, embedding) VALUES (?, ?)", + (f"set-centroid:{cid}", blob), + ) + conn.commit() + + elapsed = time.perf_counter() - t0 + print(f" Centroids loaded in {elapsed:.1f}s total", flush=True) + + # Stash meta for results tracking + params["_faiss_meta"] = { + "ntrain": meta.get("ntrain"), + "nclusters": meta.get("nclusters"), + "niter": meta.get("niter"), + 
"faiss_elapsed_s": meta.get("elapsed_s"), + "total_elapsed_s": round(elapsed, 3), + "trainer": "faiss", + } + + return elapsed + + +def _ivf_pre_query_hook(conn, params): + """Override nprobe at runtime via command dispatch.""" + nprobe = params.get("nprobe") + if nprobe: + conn.execute( + "INSERT INTO vec_items(id) VALUES (?)", + (f"nprobe={nprobe}",), + ) + conn.commit() + print(f" Set nprobe={nprobe}") + + def _ivf_describe(params): - return f"ivf nlist={params['nlist']:<4} nprobe={params['nprobe']}" + ts = params.get("train_sample", 0) + q = params.get("quantizer", "none") + os_val = params.get("oversample", 1) + fk = params.get("faiss_kmeans", 0) + desc = f"ivf nlist={params['nlist']:<4} nprobe={params['nprobe']}" + if q != "none": + desc += f" q={q}" + if os_val > 1: + desc += f" os={os_val}" + if fk: + desc += f" faiss" + if ts: + desc += f" ts={ts}" + return desc INDEX_REGISTRY["ivf"] = { - "defaults": {"nlist": 128, "nprobe": 16}, + "defaults": {"nlist": 128, "nprobe": 16, "train_sample": 0, + "quantizer": "none", "oversample": 1, + "faiss_kmeans": 0, "faiss_niter": 20}, "create_table_sql": _ivf_create_table_sql, "insert_sql": None, "post_insert_hook": _ivf_post_insert_hook, + "pre_query_hook": _ivf_pre_query_hook, + "train_sql": lambda _: "INSERT INTO vec_items(id) VALUES ('compute-centroids')", "run_query": None, + "query_sql": None, "describe": _ivf_describe, } @@ -326,24 +585,35 @@ INDEX_REGISTRY["ivf"] = { def _diskann_create_table_sql(params): + D = params.get("_dimensions", 768) + parts = [ + f"neighbor_quantizer={params['quantizer']}", + f"n_neighbors={params['R']}", + ] + L_insert = params.get("L_insert", 0) + L_search = params.get("L_search", 0) + if L_insert or L_search: + li = L_insert or params["L"] + ls = L_search or params["L"] + parts.append(f"search_list_size_insert={li}") + parts.append(f"search_list_size_search={ls}") + else: + parts.append(f"search_list_size={params['L']}") bt = params["buffer_threshold"] - extra = f", 
buffer_threshold={bt}" if bt > 0 else "" + if bt > 0: + parts.append(f"buffer_threshold={bt}") + diskann_args = ", ".join(parts) return ( f"CREATE VIRTUAL TABLE vec_items USING vec0(" - f" id integer primary key," - f" embedding float[768] distance_metric=cosine" - f" INDEXED BY diskann(" - f" neighbor_quantizer={params['quantizer']}," - f" n_neighbors={params['R']}," - f" search_list_size={params['L']}" - f" {extra}" - f" )" - f")" + f"id integer primary key, " + f"embedding float[{D}] distance_metric=cosine " + f"indexed by diskann({diskann_args}))" ) def _diskann_pre_query_hook(conn, params): - L_search = params.get("L_search") + """Override search_list_size_search at runtime via command dispatch.""" + L_search = params.get("L_search", 0) if L_search: conn.execute( "INSERT INTO vec_items(id) VALUES (?)", @@ -354,20 +624,27 @@ def _diskann_pre_query_hook(conn, params): def _diskann_describe(params): - desc = f"diskann q={params['quantizer']:<6} R={params['R']:<3} L={params['L']}" - L_search = params.get("L_search") - if L_search: - desc += f" L_search={L_search}" - return desc + L_insert = params.get("L_insert", 0) + L_search = params.get("L_search", 0) + if L_insert or L_search: + li = L_insert or params["L"] + ls = L_search or params["L"] + l_str = f"Li={li} Ls={ls}" + else: + l_str = f"L={params['L']}" + return f"diskann q={params['quantizer']:<6} R={params['R']:<3} {l_str}" INDEX_REGISTRY["diskann"] = { - "defaults": {"R": 72, "L": 128, "quantizer": "binary", "buffer_threshold": 0}, + "defaults": {"R": 72, "L": 128, "L_insert": 0, "L_search": 0, + "quantizer": "binary", "buffer_threshold": 0}, "create_table_sql": _diskann_create_table_sql, "insert_sql": None, "post_insert_hook": None, "pre_query_hook": _diskann_pre_query_hook, + "train_sql": None, "run_query": None, + "query_sql": None, "describe": _diskann_describe, } @@ -377,8 +654,9 @@ INDEX_REGISTRY["diskann"] = { # ============================================================================ INT_KEYS = { 
- "R", "L", "L_search", "buffer_threshold", "nlist", "nprobe", "oversample", - "n_trees", "search_k", + "R", "L", "L_insert", "L_search", "buffer_threshold", + "nlist", "nprobe", "oversample", "n_trees", "search_k", + "train_sample", "faiss_kmeans", "faiss_niter", } @@ -414,6 +692,12 @@ def parse_config(spec): return name, params +def params_to_json(params): + """Serialize params to JSON, excluding internal keys.""" + return json.dumps({k: v for k, v in sorted(params.items()) + if not k.startswith("_") and k != "index_type"}) + + # ============================================================================ # Shared helpers # ============================================================================ @@ -428,31 +712,59 @@ def load_query_vectors(base_db_path, n): return [(r[0], r[1]) for r in rows] -def insert_loop(conn, sql, subset_size, label=""): - t0 = time.perf_counter() - for lo in range(0, subset_size, INSERT_BATCH_SIZE): +def insert_loop(conn, sql, subset_size, label="", results_db=None, run_id=None, + start_from=0): + loop_start_ns = now_ns() + for lo in range(start_from, subset_size, INSERT_BATCH_SIZE): hi = min(lo + INSERT_BATCH_SIZE, subset_size) + batch_start_ns = now_ns() conn.execute(sql, {"lo": lo, "hi": hi}) conn.commit() + batch_end_ns = now_ns() done = hi + + if results_db is not None and run_id is not None: + elapsed_total_ns = batch_end_ns - loop_start_ns + elapsed_total_s = ns_to_s(elapsed_total_ns) + rate = done / elapsed_total_s if elapsed_total_s > 0 else 0 + results_db.execute( + "INSERT INTO insert_batches " + "(run_id, batch_lo, batch_hi, rows_in_batch, " + " started_ns, ended_ns, duration_ns, " + " cumulative_rows, rate_rows_per_s) " + "VALUES (?,?,?,?,?,?,?,?,?)", + ( + run_id, lo, hi, hi - lo, + batch_start_ns, batch_end_ns, + batch_end_ns - batch_start_ns, + done, round(rate, 1), + ), + ) + + if results_db is not None and run_id is not None: + results_db.commit() + if done % 5000 == 0 or done == subset_size: - elapsed = 
time.perf_counter() - t0 - rate = done / elapsed if elapsed > 0 else 0 + elapsed_total_ns = batch_end_ns - loop_start_ns + elapsed_total_s = ns_to_s(elapsed_total_ns) + rate = done / elapsed_total_s if elapsed_total_s > 0 else 0 print( f" [{label}] {done:>8}/{subset_size} " - f"{elapsed:.1f}s {rate:.0f} rows/s", + f"{elapsed_total_s:.1f}s {rate:.0f} rows/s", flush=True, ) - return time.perf_counter() - t0 + + return time.perf_counter() # not used for timing anymore, kept for compat -def create_bench_db(db_path, ext_path, base_db): +def create_bench_db(db_path, ext_path, base_db, page_size=4096): if os.path.exists(db_path): os.remove(db_path) conn = sqlite3.connect(db_path) conn.enable_load_extension(True) conn.load_extension(ext_path) - conn.execute("PRAGMA page_size=8192") + if page_size != 4096: + conn.execute(f"PRAGMA page_size={page_size}") conn.execute(f"ATTACH DATABASE '{base_db}' AS base") return conn @@ -475,49 +787,212 @@ DEFAULT_INSERT_SQL = ( "SELECT id, vector FROM base.train WHERE id >= :lo AND id < :hi" ) +DEFAULT_QUERY_SQL = ( + "SELECT id, distance FROM vec_items " + "WHERE embedding MATCH :query AND k = :k" +) + + +# ============================================================================ +# Results DB helpers +# ============================================================================ + +_RESULTS_SCHEMA_PATH = os.path.join(_SCRIPT_DIR, "results_schema.sql") + + +def open_results_db(out_dir, dataset, subset_size, results_db_name="results.db"): + """Open/create the results DB in WAL mode.""" + sub_dir = os.path.join(out_dir, dataset, str(subset_size)) + os.makedirs(sub_dir, exist_ok=True) + db_path = os.path.join(sub_dir, results_db_name) + db = sqlite3.connect(db_path, timeout=60) + db.execute("PRAGMA journal_mode=WAL") + db.execute("PRAGMA busy_timeout=60000") + # Migrate existing DBs: add phase column before running schema + cols = {r[1] for r in db.execute("PRAGMA table_info(runs)").fetchall()} + if cols and "phase" not in cols: + 
db.execute("ALTER TABLE runs ADD COLUMN phase TEXT NOT NULL DEFAULT 'both'") + db.commit() + with open(_RESULTS_SCHEMA_PATH) as f: + db.executescript(f.read()) + return db, sub_dir + + +def create_run(results_db, config_name, index_type, params, dataset, + subset_size, k, n_queries, phase="both"): + """Insert a new run row and return the run_id.""" + cur = results_db.execute( + "INSERT INTO runs " + "(config_name, index_type, params, dataset, subset_size, " + " k, n_queries, phase, status, created_at_ns) " + "VALUES (?,?,?,?,?,?,?,?,?,?)", + ( + config_name, index_type, params_to_json(params), dataset, + subset_size, k, n_queries, phase, "pending", now_ns(), + ), + ) + results_db.commit() + return cur.lastrowid + + +def update_run_status(results_db, run_id, status): + results_db.execute( + "UPDATE runs SET status=? WHERE run_id=?", (status, run_id) + ) + results_db.commit() + # ============================================================================ # Build # ============================================================================ -def build_index(base_db, ext_path, name, params, subset_size, out_dir): - db_path = os.path.join(out_dir, f"{name}.{subset_size}.db") - conn = create_bench_db(db_path, ext_path, base_db) +def build_index(base_db, ext_path, name, params, subset_size, sub_dir, + results_db=None, run_id=None, k=None): + db_path = os.path.join(sub_dir, f"{name}.{subset_size}.db") + params["_base_db"] = base_db # expose to hooks (e.g. 
FAISS k-means) + page_size = int(params.get("page_size", 4096)) + conn = create_bench_db(db_path, ext_path, base_db, page_size=page_size) reg = INDEX_REGISTRY[params["index_type"]] - conn.execute(reg["create_table_sql"](params)) + create_sql = reg["create_table_sql"](params) + conn.execute(create_sql) label = params["index_type"] print(f" Inserting {subset_size} vectors...") sql_fn = reg.get("insert_sql") - sql = sql_fn(params) if sql_fn else None - if sql is None: - sql = DEFAULT_INSERT_SQL + insert_sql = sql_fn(params) if sql_fn else None + if insert_sql is None: + insert_sql = DEFAULT_INSERT_SQL - insert_time = insert_loop(conn, sql, subset_size, label) + train_sql_fn = reg.get("train_sql") + train_sql = train_sql_fn(params) if train_sql_fn else None - train_time = 0.0 + query_sql_fn = reg.get("query_sql") + query_sql = query_sql_fn(params) if query_sql_fn else None + if query_sql is None: + query_sql = DEFAULT_QUERY_SQL + + # -- Insert + Training phases -- + train_sample = params.get("train_sample", 0) hook = reg.get("post_insert_hook") - if hook: - train_time = hook(conn, params) + faiss_kmeans = params.get("faiss_kmeans", 0) + + train_started_ns = None + train_ended_ns = None + train_duration_ns = None + train_time_s = 0.0 + + if faiss_kmeans: + # FAISS mode: train on base.db first, load centroids, then insert all + if results_db and run_id: + update_run_status(results_db, run_id, "training") + train_started_ns = now_ns() + train_time_s = _ivf_faiss_kmeans_hook(conn, params) + train_ended_ns = now_ns() + train_duration_ns = train_ended_ns - train_started_ns + + # Now insert all vectors (they get assigned on insert) + if results_db and run_id: + update_run_status(results_db, run_id, "inserting") + insert_started_ns = now_ns() + insert_loop(conn, insert_sql, subset_size, label, + results_db=results_db, run_id=run_id) + insert_ended_ns = now_ns() + insert_duration_ns = insert_ended_ns - insert_started_ns + + elif train_sample and hook and train_sample < 
subset_size: + # Built-in k-means: insert sample, train, insert rest + if results_db and run_id: + update_run_status(results_db, run_id, "inserting") + insert_started_ns = now_ns() + + print(f" Inserting {train_sample} vectors (training sample)...") + insert_loop(conn, insert_sql, train_sample, label, + results_db=results_db, run_id=run_id) + insert_paused_ns = now_ns() + + # -- Training on sample -- + if results_db and run_id: + update_run_status(results_db, run_id, "training") + train_started_ns = now_ns() + train_time_s = hook(conn, params) + train_ended_ns = now_ns() + train_duration_ns = train_ended_ns - train_started_ns + + # -- Insert remaining vectors -- + if results_db and run_id: + update_run_status(results_db, run_id, "inserting") + print(f" Inserting remaining {subset_size - train_sample} vectors...") + insert_loop(conn, insert_sql, subset_size, label, + results_db=results_db, run_id=run_id, + start_from=train_sample) + insert_ended_ns = now_ns() + + # Insert time = total wall time minus training time + insert_duration_ns = (insert_paused_ns - insert_started_ns) + \ + (insert_ended_ns - train_ended_ns) + else: + # Standard flow: insert all, then train + if results_db and run_id: + update_run_status(results_db, run_id, "inserting") + insert_started_ns = now_ns() + + insert_loop(conn, insert_sql, subset_size, label, + results_db=results_db, run_id=run_id) + insert_ended_ns = now_ns() + insert_duration_ns = insert_ended_ns - insert_started_ns + + if hook: + if results_db and run_id: + update_run_status(results_db, run_id, "training") + train_started_ns = now_ns() + train_time_s = hook(conn, params) + train_ended_ns = now_ns() + train_duration_ns = train_ended_ns - train_started_ns row_count = conn.execute("SELECT count(*) FROM vec_items").fetchone()[0] conn.close() - file_size_mb = os.path.getsize(db_path) / (1024 * 1024) + file_size_bytes = os.path.getsize(db_path) + + build_duration_ns = insert_duration_ns + (train_duration_ns or 0) + insert_time_s = 
ns_to_s(insert_duration_ns) + + # If FAISS was used for training, record its meta as train_sql + faiss_meta = params.get("_faiss_meta") + if faiss_meta: + train_sql = json.dumps(faiss_meta) + + # Write run_results (build portion) + if results_db and run_id: + results_db.execute( + "INSERT INTO run_results " + "(run_id, insert_started_ns, insert_ended_ns, insert_duration_ns, " + " train_started_ns, train_ended_ns, train_duration_ns, " + " build_duration_ns, db_file_size_bytes, db_file_path, " + " create_sql, insert_sql, train_sql, query_sql, k) " + "VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", + ( + run_id, insert_started_ns, insert_ended_ns, insert_duration_ns, + train_started_ns, train_ended_ns, train_duration_ns, + build_duration_ns, file_size_bytes, db_path, + create_sql, insert_sql, train_sql, query_sql, k, + ), + ) + results_db.commit() return { "db_path": db_path, - "insert_time_s": round(insert_time, 3), - "train_time_s": round(train_time, 3), - "total_time_s": round(insert_time + train_time, 3), - "insert_per_vec_ms": round((insert_time / row_count) * 1000, 2) + "insert_time_s": round(insert_time_s, 3), + "train_time_s": round(train_time_s, 3), + "total_time_s": round(insert_time_s + train_time_s, 3), + "insert_per_vec_ms": round((insert_time_s / row_count) * 1000, 2) if row_count else 0, "rows": row_count, - "file_size_mb": round(file_size_mb, 2), + "file_size_mb": round(file_size_bytes / (1024 * 1024), 2), } @@ -535,7 +1010,7 @@ def _default_match_query(conn, query, k): def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50, - pre_query_hook=None): + results_db=None, run_id=None, pre_query_hook=None, warmup=0): conn = sqlite3.connect(db_path) conn.enable_load_extension(True) conn.load_extension(ext_path) @@ -549,10 +1024,25 @@ def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50, reg = INDEX_REGISTRY[params["index_type"]] query_fn = reg.get("run_query") + # Warmup: run random queries to populate OS page cache + if 
warmup > 0: + import random + warmup_vecs = [qv for _, qv in query_vectors] + print(f" Warming up with {warmup} queries...", flush=True) + for _ in range(warmup): + wq = random.choice(warmup_vecs) + if query_fn: + query_fn(conn, params, wq, k) + else: + _default_match_query(conn, wq, k) + + if results_db and run_id: + update_run_status(results_db, run_id, "querying") + times_ms = [] recalls = [] - for qid, query in query_vectors: - t0 = time.perf_counter() + for i, (qid, query) in enumerate(query_vectors): + started_ns = now_ns() results = None if query_fn: @@ -560,9 +1050,13 @@ def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50, if results is None: results = _default_match_query(conn, query, k) - elapsed_ms = (time.perf_counter() - t0) * 1000 - times_ms.append(elapsed_ms) - result_ids = set(r[0] for r in results) + ended_ns = now_ns() + duration_ms = ns_to_ms(ended_ns - started_ns) + times_ms.append(duration_ms) + + result_ids_list = [r[0] for r in results] + result_distances_list = [r[1] for r in results] + result_ids = set(result_ids_list) # Ground truth: use pre-computed neighbors table for full dataset, # otherwise brute-force over the subset @@ -580,91 +1074,62 @@ def measure_knn(db_path, ext_path, base_db, params, subset_size, k=10, n=50, ")", {"query": query, "k": k, "n": subset_size}, ).fetchall() - gt_ids = set(r[0] for r in gt_rows) + gt_ids_list = [r[0] for r in gt_rows] + gt_ids = set(gt_ids_list) if gt_ids: - recalls.append(len(result_ids & gt_ids) / len(gt_ids)) + q_recall = len(result_ids & gt_ids) / len(gt_ids) else: - recalls.append(0.0) + q_recall = 0.0 + recalls.append(q_recall) + + if results_db and run_id: + results_db.execute( + "INSERT INTO queries " + "(run_id, k, query_vector_id, started_ns, ended_ns, duration_ms, " + " result_ids, result_distances, ground_truth_ids, recall) " + "VALUES (?,?,?,?,?,?,?,?,?,?)", + ( + run_id, k, qid, started_ns, ended_ns, round(duration_ms, 4), + json.dumps(result_ids_list), + 
json.dumps(result_distances_list), + json.dumps(gt_ids_list), + round(q_recall, 6), + ), + ) + results_db.commit() conn.close() + mean_ms = round(statistics.mean(times_ms), 2) + median_ms = round(statistics.median(times_ms), 2) + p99_ms = (round(sorted(times_ms)[int(len(times_ms) * 0.99)], 2) + if len(times_ms) > 1 + else round(times_ms[0], 2)) + total_ms = round(sum(times_ms), 2) + recall = round(statistics.mean(recalls), 4) + qps = round(len(times_ms) / (total_ms / 1000), 1) if total_ms > 0 else 0 + + # Update run_results with query aggregates + if results_db and run_id: + results_db.execute( + "UPDATE run_results SET " + "query_mean_ms=?, query_median_ms=?, query_p99_ms=?, " + "query_total_ms=?, qps=?, recall=? " + "WHERE run_id=?", + (mean_ms, median_ms, p99_ms, total_ms, qps, recall, run_id), + ) + update_run_status(results_db, run_id, "done") + return { - "mean_ms": round(statistics.mean(times_ms), 2), - "median_ms": round(statistics.median(times_ms), 2), - "p99_ms": round(sorted(times_ms)[int(len(times_ms) * 0.99)], 2) - if len(times_ms) > 1 - else round(times_ms[0], 2), - "total_ms": round(sum(times_ms), 2), - "recall": round(statistics.mean(recalls), 4), + "mean_ms": mean_ms, + "median_ms": median_ms, + "p99_ms": p99_ms, + "total_ms": total_ms, + "recall": recall, } -# ============================================================================ -# Results persistence -# ============================================================================ - - -def open_results_db(results_path): - db = sqlite3.connect(results_path) - db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read()) - # Migrate existing DBs that predate the runs table - cols = {r[1] for r in db.execute("PRAGMA table_info(runs)").fetchall()} - if "phase" not in cols: - db.execute("ALTER TABLE runs ADD COLUMN phase TEXT NOT NULL DEFAULT 'both'") - db.commit() - return db - - -def create_run(db, config_name, index_type, subset_size, phase, k=None, n=None): - cur = db.execute( - 
"INSERT INTO runs (config_name, index_type, subset_size, phase, status, k, n) " - "VALUES (?, ?, ?, ?, 'pending', ?, ?)", - (config_name, index_type, subset_size, phase, k, n), - ) - db.commit() - return cur.lastrowid - - -def update_run(db, run_id, **kwargs): - sets = ", ".join(f"{k} = ?" for k in kwargs) - vals = list(kwargs.values()) + [run_id] - db.execute(f"UPDATE runs SET {sets} WHERE run_id = ?", vals) - db.commit() - - -def save_results(results_path, rows): - db = sqlite3.connect(results_path) - db.executescript(open(os.path.join(_SCRIPT_DIR, "schema.sql")).read()) - for r in rows: - db.execute( - "INSERT OR REPLACE INTO build_results " - "(config_name, index_type, subset_size, db_path, " - " insert_time_s, train_time_s, total_time_s, rows, file_size_mb) " - "VALUES (?,?,?,?,?,?,?,?,?)", - ( - r["name"], r["index_type"], r["n_vectors"], r["db_path"], - r["insert_time_s"], r["train_time_s"], r["total_time_s"], - r["rows"], r["file_size_mb"], - ), - ) - db.execute( - "INSERT OR REPLACE INTO bench_results " - "(config_name, index_type, subset_size, k, n, " - " mean_ms, median_ms, p99_ms, total_ms, qps, recall, db_path) " - "VALUES (?,?,?,?,?,?,?,?,?,?,?,?)", - ( - r["name"], r["index_type"], r["n_vectors"], r["k"], r["n_queries"], - r["mean_ms"], r["median_ms"], r["p99_ms"], r["total_ms"], - round(r["n_queries"] / (r["total_ms"] / 1000), 1) - if r["total_ms"] > 0 else 0, - r["recall"], r["db_path"], - ), - ) - db.commit() - db.close() - - # ============================================================================ # Reporting # ============================================================================ @@ -699,22 +1164,38 @@ def main(): epilog=__doc__, ) parser.add_argument("configs", nargs="+", help="config specs (name:type=X,key=val,...)") - parser.add_argument("--subset-size", type=int, required=True) + parser.add_argument("--subset-size", type=int, default=None, + help="number of vectors to use (default: all)") parser.add_argument("-k", type=int, 
default=10, help="KNN k (default 10)") parser.add_argument("-n", type=int, default=50, help="number of queries (default 50)") parser.add_argument("--phase", choices=["build", "query", "both"], default="both", help="build=build only, query=query existing index, both=default") - parser.add_argument("--base-db", default=BASE_DB) + parser.add_argument("--dataset", default="cohere1m", + choices=list(DATASETS.keys()), + help="dataset name (default: cohere1m)") parser.add_argument("--ext", default=EXT_PATH) - parser.add_argument("-o", "--out-dir", default="runs") - parser.add_argument("--results-db", default=None, - help="path to results DB (default: /results.db)") + parser.add_argument("-o", "--out-dir", default=os.path.join(_SCRIPT_DIR, "runs")) + parser.add_argument("--warmup", type=int, default=0, + help="run N random warmup queries before measuring (default: 0)") + parser.add_argument("--results-db-name", default="results.db", + help="results DB filename (default: results.db)") args = parser.parse_args() - os.makedirs(args.out_dir, exist_ok=True) - results_db_path = args.results_db or os.path.join(args.out_dir, "results.db") + dataset_cfg = DATASETS[args.dataset] + base_db = dataset_cfg["base_db"] + dimensions = dataset_cfg["dimensions"] + + if args.subset_size is None: + _tmp = sqlite3.connect(base_db) + args.subset_size = _tmp.execute("SELECT COUNT(*) FROM train").fetchone()[0] + _tmp.close() + print(f"Using full dataset: {args.subset_size} vectors") + + results_db, sub_dir = open_results_db(args.out_dir, args.dataset, args.subset_size, + results_db_name=args.results_db_name) configs = [parse_config(c) for c in args.configs] - results_db = open_results_db(results_db_path) + for _, params in configs: + params["_dimensions"] = dimensions all_results = [] for i, (name, params) in enumerate(configs, 1): @@ -722,31 +1203,30 @@ def main(): desc = reg["describe"](params) print(f"\n[{i}/{len(configs)}] {name} ({desc.strip()}) [phase={args.phase}]") - db_path = 
os.path.join(args.out_dir, f"{name}.{args.subset_size}.db") + db_path = os.path.join(sub_dir, f"{name}.{args.subset_size}.db") if args.phase == "build": - run_id = create_run(results_db, name, params["index_type"], - args.subset_size, "build") - update_run(results_db, run_id, status="inserting") + run_id = create_run( + results_db, name, params["index_type"], params, + args.dataset, args.subset_size, args.k, args.n, phase="build", + ) - build = build_index( - args.base_db, args.ext, name, params, args.subset_size, args.out_dir - ) - train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else "" - print( - f" Build: {build['insert_time_s']}s insert{train_str} " - f"{build['file_size_mb']} MB" - ) - update_run(results_db, run_id, - status="built", - db_path=build["db_path"], - insert_time_s=build["insert_time_s"], - train_time_s=build["train_time_s"], - total_build_time_s=build["total_time_s"], - rows=build["rows"], - file_size_mb=build["file_size_mb"], - finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")) - print(f" Index DB: {build['db_path']}") + try: + build = build_index( + base_db, args.ext, name, params, args.subset_size, sub_dir, + results_db=results_db, run_id=run_id, k=args.k, + ) + train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else "" + print( + f" Build: {build['insert_time_s']}s insert{train_str} " + f"{build['file_size_mb']} MB" + ) + update_run_status(results_db, run_id, "built") + print(f" Index DB: {build['db_path']}") + except Exception as e: + update_run_status(results_db, run_id, "error") + print(f" ERROR: {e}") + raise elif args.phase == "query": if not os.path.exists(db_path): @@ -755,30 +1235,35 @@ def main(): f"Build it first with: --phase build" ) - run_id = create_run(results_db, name, params["index_type"], - args.subset_size, "query", k=args.k, n=args.n) - update_run(results_db, run_id, status="querying") - - pre_hook = reg.get("pre_query_hook") - print(f" Measuring 
KNN (k={args.k}, n={args.n})...") - knn = measure_knn( - db_path, args.ext, args.base_db, - params, args.subset_size, k=args.k, n=args.n, - pre_query_hook=pre_hook, + run_id = create_run( + results_db, name, params["index_type"], params, + args.dataset, args.subset_size, args.k, args.n, phase="query", ) - print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}") - qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0 - update_run(results_db, run_id, - status="done", - db_path=db_path, - mean_ms=knn["mean_ms"], - median_ms=knn["median_ms"], - p99_ms=knn["p99_ms"], - total_query_ms=knn["total_ms"], - qps=qps, - recall=knn["recall"], - finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")) + try: + # Create a run_results row so measure_knn can UPDATE it + file_size_bytes = os.path.getsize(db_path) + results_db.execute( + "INSERT INTO run_results " + "(run_id, db_file_size_bytes, db_file_path, k) " + "VALUES (?,?,?,?)", + (run_id, file_size_bytes, db_path, args.k), + ) + results_db.commit() + + pre_hook = reg.get("pre_query_hook") + print(f" Measuring KNN (k={args.k}, n={args.n})...") + knn = measure_knn( + db_path, args.ext, base_db, + params, args.subset_size, k=args.k, n=args.n, + results_db=results_db, run_id=run_id, + pre_query_hook=pre_hook, warmup=args.warmup, + ) + print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}") + except Exception as e: + update_run_status(results_db, run_id, "error") + print(f" ERROR: {e}") + raise file_size_mb = os.path.getsize(db_path) / (1024 * 1024) all_results.append({ @@ -803,43 +1288,35 @@ def main(): }) else: # both - run_id = create_run(results_db, name, params["index_type"], - args.subset_size, "both", k=args.k, n=args.n) - update_run(results_db, run_id, status="inserting") - - build = build_index( - args.base_db, args.ext, name, params, args.subset_size, args.out_dir + run_id = create_run( + results_db, name, params["index_type"], params, + args.dataset, 
args.subset_size, args.k, args.n, phase="both", ) - train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else "" - print( - f" Build: {build['insert_time_s']}s insert{train_str} " - f"{build['file_size_mb']} MB" - ) - update_run(results_db, run_id, status="querying", - db_path=build["db_path"], - insert_time_s=build["insert_time_s"], - train_time_s=build["train_time_s"], - total_build_time_s=build["total_time_s"], - rows=build["rows"], - file_size_mb=build["file_size_mb"]) - print(f" Measuring KNN (k={args.k}, n={args.n})...") - knn = measure_knn( - build["db_path"], args.ext, args.base_db, - params, args.subset_size, k=args.k, n=args.n, - ) - print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}") + try: + build = build_index( + base_db, args.ext, name, params, args.subset_size, sub_dir, + results_db=results_db, run_id=run_id, k=args.k, + ) + train_str = f" + {build['train_time_s']}s train" if build["train_time_s"] > 0 else "" + print( + f" Build: {build['insert_time_s']}s insert{train_str} " + f"{build['file_size_mb']} MB" + ) - qps = round(args.n / (knn["total_ms"] / 1000), 1) if knn["total_ms"] > 0 else 0 - update_run(results_db, run_id, - status="done", - mean_ms=knn["mean_ms"], - median_ms=knn["median_ms"], - p99_ms=knn["p99_ms"], - total_query_ms=knn["total_ms"], - qps=qps, - recall=knn["recall"], - finished_at=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")) + pre_hook = reg.get("pre_query_hook") + print(f" Measuring KNN (k={args.k}, n={args.n})...") + knn = measure_knn( + build["db_path"], args.ext, base_db, + params, args.subset_size, k=args.k, n=args.n, + results_db=results_db, run_id=run_id, + pre_query_hook=pre_hook, warmup=args.warmup, + ) + print(f" KNN: mean={knn['mean_ms']}ms recall@{args.k}={knn['recall']}") + except Exception as e: + update_run_status(results_db, run_id, "error") + print(f" ERROR: {e}") + raise all_results.append({ "name": name, @@ -862,14 +1339,11 @@ def main(): "recall": 
knn["recall"], }) - results_db.close() - if all_results: print_report(all_results) - save_results(results_db_path, all_results) - print(f"\nResults saved to {results_db_path}") - elif args.phase == "build": - print(f"\nBuild complete. Results tracked in {results_db_path}") + + print(f"\nResults DB: {os.path.join(sub_dir, 'results.db')}") + results_db.close() if __name__ == "__main__": diff --git a/benchmarks-ann/datasets/cohere10m/Makefile b/benchmarks-ann/datasets/cohere10m/Makefile new file mode 100644 index 0000000..322b21c --- /dev/null +++ b/benchmarks-ann/datasets/cohere10m/Makefile @@ -0,0 +1,27 @@ +BASE_URL = https://assets.zilliz.com/benchmark/cohere_large_10m + +TRAIN_PARQUETS = $(shell printf 'train-%02d-of-10.parquet ' 0 1 2 3 4 5 6 7 8 9) +OTHER_PARQUETS = test.parquet neighbors.parquet +PARQUETS = $(TRAIN_PARQUETS) $(OTHER_PARQUETS) + +.PHONY: all download clean + +all: base.db + +# Use: make -j12 download +download: $(PARQUETS) + +train-%-of-10.parquet: + curl -L -o $@ $(BASE_URL)/$@ + +test.parquet: + curl -L -o $@ $(BASE_URL)/test.parquet + +neighbors.parquet: + curl -L -o $@ $(BASE_URL)/neighbors.parquet + +base.db: $(PARQUETS) build_base_db.py + uv run --with pandas --with pyarrow python build_base_db.py + +clean: + rm -f base.db diff --git a/benchmarks-ann/datasets/cohere10m/build_base_db.py b/benchmarks-ann/datasets/cohere10m/build_base_db.py new file mode 100644 index 0000000..ceaeb22 --- /dev/null +++ b/benchmarks-ann/datasets/cohere10m/build_base_db.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +"""Build base.db from downloaded parquet files (10M dataset, 10 train shards). + +Reads train-00-of-10.parquet .. train-09-of-10.parquet, test.parquet, +neighbors.parquet and creates a SQLite database with tables: + train, query_vectors, neighbors. 
+ +Usage: + uv run --with pandas --with pyarrow python build_base_db.py +""" +import json +import os +import sqlite3 +import struct +import sys +import time + +import pandas as pd + +TRAIN_SHARDS = 10 + + +def float_list_to_blob(floats): + """Pack a list of floats into a little-endian f32 blob.""" + return struct.pack(f"<{len(floats)}f", *floats) + + +def main(): + script_dir = os.path.dirname(os.path.abspath(__file__)) + db_path = os.path.join(script_dir, "base.db") + + train_paths = [ + os.path.join(script_dir, f"train-{i:02d}-of-{TRAIN_SHARDS}.parquet") + for i in range(TRAIN_SHARDS) + ] + test_path = os.path.join(script_dir, "test.parquet") + neighbors_path = os.path.join(script_dir, "neighbors.parquet") + + for p in train_paths + [test_path, neighbors_path]: + if not os.path.exists(p): + print(f"ERROR: {p} not found. Run 'make download' first.") + sys.exit(1) + + if os.path.exists(db_path): + os.remove(db_path) + + conn = sqlite3.connect(db_path) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA page_size=4096") + + # --- query_vectors (from test.parquet) --- + print("Loading test.parquet (query vectors)...") + t0 = time.perf_counter() + df_test = pd.read_parquet(test_path) + conn.execute( + "CREATE TABLE query_vectors (id INTEGER PRIMARY KEY, vector BLOB)" + ) + rows = [] + for _, row in df_test.iterrows(): + rows.append((int(row["id"]), float_list_to_blob(row["emb"]))) + conn.executemany("INSERT INTO query_vectors (id, vector) VALUES (?, ?)", rows) + conn.commit() + print(f" {len(rows)} query vectors in {time.perf_counter() - t0:.1f}s") + + # --- neighbors (from neighbors.parquet) --- + print("Loading neighbors.parquet...") + t0 = time.perf_counter() + df_neighbors = pd.read_parquet(neighbors_path) + conn.execute( + "CREATE TABLE neighbors (" + " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT," + " UNIQUE(query_vector_id, rank))" + ) + rows = [] + for _, row in df_neighbors.iterrows(): + qid = int(row["id"]) + nids = 
row["neighbors_id"] + if isinstance(nids, str): + nids = json.loads(nids) + for rank, nid in enumerate(nids): + rows.append((qid, rank, str(int(nid)))) + conn.executemany( + "INSERT INTO neighbors (query_vector_id, rank, neighbors_id) VALUES (?, ?, ?)", + rows, + ) + conn.commit() + print(f" {len(rows)} neighbor rows in {time.perf_counter() - t0:.1f}s") + + # --- train (from 10 shard parquets) --- + print(f"Loading {TRAIN_SHARDS} train shards (10M vectors, this will take a while)...") + conn.execute( + "CREATE TABLE train (id INTEGER PRIMARY KEY, vector BLOB)" + ) + + global_t0 = time.perf_counter() + total_inserted = 0 + batch_size = 10000 + + for shard_idx, train_path in enumerate(train_paths): + print(f" Shard {shard_idx + 1}/{TRAIN_SHARDS}: {os.path.basename(train_path)}") + t0 = time.perf_counter() + df = pd.read_parquet(train_path) + shard_len = len(df) + + for start in range(0, shard_len, batch_size): + chunk = df.iloc[start : start + batch_size] + rows = [] + for _, row in chunk.iterrows(): + rows.append((int(row["id"]), float_list_to_blob(row["emb"]))) + conn.executemany("INSERT INTO train (id, vector) VALUES (?, ?)", rows) + conn.commit() + + total_inserted += len(rows) + if total_inserted % 100000 < batch_size: + elapsed = time.perf_counter() - global_t0 + rate = total_inserted / elapsed if elapsed > 0 else 0 + print( + f" {total_inserted:>10} {elapsed:.0f}s {rate:.0f} rows/s", + flush=True, + ) + + shard_elapsed = time.perf_counter() - t0 + print(f" shard done: {shard_len} rows in {shard_elapsed:.1f}s") + + elapsed = time.perf_counter() - global_t0 + print(f" {total_inserted} train vectors in {elapsed:.1f}s") + + conn.close() + size_mb = os.path.getsize(db_path) / (1024 * 1024) + print(f"\nDone: {db_path} ({size_mb:.0f} MB)") + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/seed/.gitignore b/benchmarks-ann/datasets/cohere1m/.gitignore similarity index 100% rename from benchmarks-ann/seed/.gitignore rename to 
benchmarks-ann/datasets/cohere1m/.gitignore diff --git a/benchmarks-ann/seed/Makefile b/benchmarks-ann/datasets/cohere1m/Makefile similarity index 100% rename from benchmarks-ann/seed/Makefile rename to benchmarks-ann/datasets/cohere1m/Makefile diff --git a/benchmarks-ann/seed/build_base_db.py b/benchmarks-ann/datasets/cohere1m/build_base_db.py similarity index 100% rename from benchmarks-ann/seed/build_base_db.py rename to benchmarks-ann/datasets/cohere1m/build_base_db.py diff --git a/benchmarks-ann/datasets/nyt-1024/Makefile b/benchmarks-ann/datasets/nyt-1024/Makefile new file mode 100644 index 0000000..0547409 --- /dev/null +++ b/benchmarks-ann/datasets/nyt-1024/Makefile @@ -0,0 +1,30 @@ +MODEL ?= mixedbread-ai/mxbai-embed-large-v1 +K ?= 100 +BATCH_SIZE ?= 256 +DATA_DIR ?= ../nyt/data + +all: base.db + +# Reuse data from ../nyt +$(DATA_DIR): + $(MAKE) -C ../nyt data + +contents.db: $(DATA_DIR) + uv run ../nyt-768/build-contents.py --data-dir $(DATA_DIR) -o $@ + +base.db: contents.db queries.txt + uv run build-base.py \ + --contents-db contents.db \ + --model $(MODEL) \ + --queries-file queries.txt \ + --batch-size $(BATCH_SIZE) \ + --k $(K) \ + -o $@ + +queries.txt: + cp ../nyt/queries.txt $@ + +clean: + rm -f base.db contents.db + +.PHONY: all clean diff --git a/benchmarks-ann/datasets/nyt-1024/build-base.py b/benchmarks-ann/datasets/nyt-1024/build-base.py new file mode 100644 index 0000000..a0a6b22 --- /dev/null +++ b/benchmarks-ann/datasets/nyt-1024/build-base.py @@ -0,0 +1,163 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "sentence-transformers", +# "torch<=2.7", +# "tqdm", +# ] +# /// + +import argparse +import sqlite3 +from array import array +from itertools import batched + +from sentence_transformers import SentenceTransformer +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser( + description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors", + ) + parser.add_argument( + 
"--contents-db", "-c", default=None, + help="Path to contents.db (source of headlines and IDs)", + ) + parser.add_argument( + "--model", "-m", default="mixedbread-ai/mxbai-embed-large-v1", + help="HuggingFace model ID (default: mixedbread-ai/mxbai-embed-large-v1)", + ) + parser.add_argument( + "--queries-file", "-q", default="queries.txt", + help="Path to the queries file (default: queries.txt)", + ) + parser.add_argument( + "--output", "-o", required=True, + help="Path to the output base.db", + ) + parser.add_argument( + "--batch-size", "-b", type=int, default=256, + help="Batch size for embedding (default: 256)", + ) + parser.add_argument( + "--k", "-k", type=int, default=100, + help="Number of nearest neighbors (default: 100)", + ) + parser.add_argument( + "--limit", "-l", type=int, default=0, + help="Limit number of headlines to embed (0 = all)", + ) + parser.add_argument( + "--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0", + help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)", + ) + parser.add_argument( + "--skip-neighbors", action="store_true", + help="Skip the brute-force KNN neighbor computation", + ) + args = parser.parse_args() + + import os + vec_path = os.path.expanduser(args.vec_path) + + print(f"Loading model {args.model}...") + model = SentenceTransformer(args.model) + + # Read headlines from contents.db + src = sqlite3.connect(args.contents_db) + limit_clause = f" LIMIT {args.limit}" if args.limit > 0 else "" + headlines = src.execute( + f"SELECT id, headline FROM contents ORDER BY id{limit_clause}" + ).fetchall() + src.close() + print(f"Loaded {len(headlines)} headlines from {args.contents_db}") + + # Read queries + with open(args.queries_file) as f: + queries = [line.strip() for line in f if line.strip()] + print(f"Loaded {len(queries)} queries from {args.queries_file}") + + # Create output database + db = sqlite3.connect(args.output) + db.enable_load_extension(True) + db.load_extension(vec_path) + 
db.enable_load_extension(False) + + db.execute("CREATE TABLE IF NOT EXISTS train(id INTEGER PRIMARY KEY, vector BLOB)") + db.execute("CREATE TABLE IF NOT EXISTS query_vectors(id INTEGER PRIMARY KEY, vector BLOB)") + db.execute( + "CREATE TABLE IF NOT EXISTS neighbors(" + " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT," + " UNIQUE(query_vector_id, rank))" + ) + + # Step 1: Embed headlines -> train table + print("Embedding headlines...") + for batch in tqdm( + batched(headlines, args.batch_size), + total=(len(headlines) + args.batch_size - 1) // args.batch_size, + ): + ids = [r[0] for r in batch] + texts = [r[1] for r in batch] + embeddings = model.encode(texts, normalize_embeddings=True) + + params = [ + (int(rid), array("f", emb.tolist()).tobytes()) + for rid, emb in zip(ids, embeddings) + ] + db.executemany("INSERT INTO train VALUES (?, ?)", params) + db.commit() + + del headlines + n = db.execute("SELECT count(*) FROM train").fetchone()[0] + print(f"Embedded {n} headlines") + + # Step 2: Embed queries -> query_vectors table + print("Embedding queries...") + query_embeddings = model.encode(queries, normalize_embeddings=True) + query_params = [] + for i, emb in enumerate(query_embeddings, 1): + blob = array("f", emb.tolist()).tobytes() + query_params.append((i, blob)) + db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params) + db.commit() + print(f"Embedded {len(queries)} queries") + + if args.skip_neighbors: + db.close() + print(f"Done (skipped neighbors). Wrote {args.output}") + return + + # Step 3: Brute-force KNN via sqlite-vec -> neighbors table + n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0] + print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...") + for query_id, query_blob in tqdm( + db.execute("SELECT id, vector FROM query_vectors").fetchall() + ): + results = db.execute( + """ + SELECT + train.id, + vec_distance_cosine(train.vector, ?) 
AS distance + FROM train + WHERE distance IS NOT NULL + ORDER BY distance ASC + LIMIT ? + """, + (query_blob, args.k), + ).fetchall() + + params = [ + (query_id, rank, str(rid)) + for rank, (rid, _dist) in enumerate(results) + ] + db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params) + + db.commit() + db.close() + print(f"Done. Wrote {args.output}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/datasets/nyt-1024/queries.txt b/benchmarks-ann/datasets/nyt-1024/queries.txt new file mode 100644 index 0000000..9e98f84 --- /dev/null +++ b/benchmarks-ann/datasets/nyt-1024/queries.txt @@ -0,0 +1,100 @@ +latest news on climate change policy +presidential election results and analysis +stock market crash causes +coronavirus vaccine development updates +artificial intelligence breakthrough in healthcare +supreme court ruling on abortion rights +tech companies layoff announcements +earthquake damages in California +cybersecurity breach at major corporation +space exploration mission to Mars +immigration reform legislation debate +renewable energy investment trends +healthcare costs rising across America +protests against police brutality +wildfires destroy homes in the West +Olympic games highlights and records +celebrity scandal rocks Hollywood +breakthrough cancer treatment discovered +housing market bubble concerns +federal reserve interest rate decision +school shooting tragedy response +diplomatic tensions between superpowers +drone strike kills terrorist leader +social media platform faces regulation +archaeological discovery reveals ancient civilization +unemployment rate hits record low +autonomous vehicles testing expansion +streaming service launches original content +opioid crisis intervention programs +trade war tariffs impact economy +infrastructure bill passes Congress +data privacy concerns grow +minimum wage increase proposal +college admissions scandal exposed +NFL player protest during anthem +cryptocurrency regulation debate 
+pandemic lockdown restrictions eased +mass shooting gun control debate +tax reform legislation impact +ransomware attack cripples pipeline +climate activists stage demonstration +sports team wins championship +banking system collapse fears +pharmaceutical company fraud charges +genetic engineering ethical concerns +border wall funding controversy +impeachment proceedings begin +nuclear weapons treaty violation +artificial meat alternative launch +student loan debt forgiveness +venture capital funding decline +facial recognition ban proposed +election interference investigation +pandemic preparedness failures +police reform measures announced +wildfire prevention strategies +ocean pollution crisis worsens +manufacturing jobs returning +pension fund shortfall concerns +antitrust investigation launched +voting rights protection act +mental health awareness campaign +homeless population increasing +space debris collision risk +drug cartel violence escalates +renewable energy jobs growth +infrastructure deterioration report +vaccine mandate legal challenge +cryptocurrency market volatility +autonomous drone delivery service +deep fake technology dangers +Arctic ice melting accelerates +income inequality gap widens +election fraud claims disputed +corporate merger blocked +medical breakthrough extends life +transportation strike disrupts city +racial justice protests spread +carbon emissions reduction goals +financial crisis warning signs +cyberbullying prevention efforts +asteroid near miss with Earth +gene therapy approval granted +labor union organizing drive +surveillance technology expansion +education funding cuts proposed +disaster relief efforts underway +housing affordability crisis +clean water access shortage +artificial intelligence job displacement +trade agreement negotiations +prison reform initiative launched +species extinction accelerates +political corruption scandal +terrorism threat level raised +food safety contamination outbreak +ai model release 
+affordability interest rates +peanut allergies in newborns +breaking bad walter white \ No newline at end of file diff --git a/benchmarks-ann/datasets/nyt-384/Makefile b/benchmarks-ann/datasets/nyt-384/Makefile new file mode 100644 index 0000000..76296a1 --- /dev/null +++ b/benchmarks-ann/datasets/nyt-384/Makefile @@ -0,0 +1,29 @@ +MODEL ?= mixedbread-ai/mxbai-embed-xsmall-v1 +K ?= 100 +BATCH_SIZE ?= 512 +DATA_DIR ?= ../nyt/data + +all: base.db + +$(DATA_DIR): + $(MAKE) -C ../nyt data + +contents.db: $(DATA_DIR) + uv run ../nyt-768/build-contents.py --data-dir $(DATA_DIR) -o $@ + +base.db: contents.db queries.txt + uv run ../nyt-1024/build-base.py \ --contents-db contents.db \ --model $(MODEL) \ --queries-file queries.txt \ --batch-size $(BATCH_SIZE) \ --k $(K) \ -o $@ + +queries.txt: + cp ../nyt/queries.txt $@ + +clean: + rm -f base.db contents.db + +.PHONY: all clean diff --git a/benchmarks-ann/datasets/nyt-384/queries.txt b/benchmarks-ann/datasets/nyt-384/queries.txt new file mode 100644 index 0000000..9e98f84 --- /dev/null +++ b/benchmarks-ann/datasets/nyt-384/queries.txt @@ -0,0 +1,100 @@ +latest news on climate change policy +presidential election results and analysis +stock market crash causes +coronavirus vaccine development updates +artificial intelligence breakthrough in healthcare +supreme court ruling on abortion rights +tech companies layoff announcements +earthquake damages in California +cybersecurity breach at major corporation +space exploration mission to Mars +immigration reform legislation debate +renewable energy investment trends +healthcare costs rising across America +protests against police brutality +wildfires destroy homes in the West +Olympic games highlights and records +celebrity scandal rocks Hollywood +breakthrough cancer treatment discovered +housing market bubble concerns +federal reserve interest rate decision +school shooting tragedy response +diplomatic tensions between superpowers +drone strike kills terrorist
leader +social media platform faces regulation +archaeological discovery reveals ancient civilization +unemployment rate hits record low +autonomous vehicles testing expansion +streaming service launches original content +opioid crisis intervention programs +trade war tariffs impact economy +infrastructure bill passes Congress +data privacy concerns grow +minimum wage increase proposal +college admissions scandal exposed +NFL player protest during anthem +cryptocurrency regulation debate +pandemic lockdown restrictions eased +mass shooting gun control debate +tax reform legislation impact +ransomware attack cripples pipeline +climate activists stage demonstration +sports team wins championship +banking system collapse fears +pharmaceutical company fraud charges +genetic engineering ethical concerns +border wall funding controversy +impeachment proceedings begin +nuclear weapons treaty violation +artificial meat alternative launch +student loan debt forgiveness +venture capital funding decline +facial recognition ban proposed +election interference investigation +pandemic preparedness failures +police reform measures announced +wildfire prevention strategies +ocean pollution crisis worsens +manufacturing jobs returning +pension fund shortfall concerns +antitrust investigation launched +voting rights protection act +mental health awareness campaign +homeless population increasing +space debris collision risk +drug cartel violence escalates +renewable energy jobs growth +infrastructure deterioration report +vaccine mandate legal challenge +cryptocurrency market volatility +autonomous drone delivery service +deep fake technology dangers +Arctic ice melting accelerates +income inequality gap widens +election fraud claims disputed +corporate merger blocked +medical breakthrough extends life +transportation strike disrupts city +racial justice protests spread +carbon emissions reduction goals +financial crisis warning signs +cyberbullying prevention efforts +asteroid near 
miss with Earth +gene therapy approval granted +labor union organizing drive +surveillance technology expansion +education funding cuts proposed +disaster relief efforts underway +housing affordability crisis +clean water access shortage +artificial intelligence job displacement +trade agreement negotiations +prison reform initiative launched +species extinction accelerates +political corruption scandal +terrorism threat level raised +food safety contamination outbreak +ai model release +affordability interest rates +peanut allergies in newborns +breaking bad walter white \ No newline at end of file diff --git a/benchmarks-ann/datasets/nyt-768/Makefile b/benchmarks-ann/datasets/nyt-768/Makefile new file mode 100644 index 0000000..93bb72a --- /dev/null +++ b/benchmarks-ann/datasets/nyt-768/Makefile @@ -0,0 +1,37 @@ +MODEL ?= bge-base-en-v1.5-768 +K ?= 100 +BATCH_SIZE ?= 512 +DATA_DIR ?= ../nyt/data + +all: base.db + +# Reuse data from ../nyt +$(DATA_DIR): + $(MAKE) -C ../nyt data + +# Distill model (separate step, may take a while) +$(MODEL): + uv run distill-model.py + +contents.db: $(DATA_DIR) + uv run build-contents.py --data-dir $(DATA_DIR) -o $@ + +base.db: contents.db queries.txt $(MODEL) + uv run ../nyt/build-base.py \ --contents-db contents.db \ --model $(MODEL) \ --queries-file queries.txt \ --batch-size $(BATCH_SIZE) \ --k $(K) \ -o $@ + +queries.txt: + cp ../nyt/queries.txt $@ + +clean: + rm -f base.db contents.db + +clean-all: clean + rm -rf $(MODEL) + +.PHONY: all clean clean-all diff --git a/benchmarks-ann/datasets/nyt-768/build-contents.py b/benchmarks-ann/datasets/nyt-768/build-contents.py new file mode 100644 index 0000000..fc829d8 --- /dev/null +++ b/benchmarks-ann/datasets/nyt-768/build-contents.py @@ -0,0 +1,64 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "duckdb", +# ] +# /// + +import argparse +import sqlite3 +import duckdb + + +def main(): + parser = argparse.ArgumentParser( + description="Load NYT headline
CSVs into a SQLite contents database (most recent 1M, deduplicated)", + ) + parser.add_argument( + "--data-dir", "-d", default="../nyt/data", + help="Directory containing NYT CSV files (default: ../nyt/data)", + ) + parser.add_argument( + "--limit", "-l", type=int, default=1_000_000, + help="Maximum number of headlines to keep (default: 1000000)", + ) + parser.add_argument( + "--output", "-o", required=True, + help="Path to the output SQLite database", + ) + args = parser.parse_args() + + glob_pattern = f"{args.data_dir}/new_york_times_stories_*.csv" + + con = duckdb.connect() + rows = con.execute( + f""" + WITH deduped AS ( + SELECT + headline, + max(pub_date) AS pub_date + FROM read_csv('{glob_pattern}', auto_detect=true, union_by_name=true) + WHERE headline IS NOT NULL AND trim(headline) != '' + GROUP BY headline + ) + SELECT + row_number() OVER (ORDER BY pub_date DESC) AS id, + headline + FROM deduped + ORDER BY pub_date DESC + LIMIT {args.limit} + """ + ).fetchall() + con.close() + + db = sqlite3.connect(args.output) + db.execute("CREATE TABLE contents(id INTEGER PRIMARY KEY, headline TEXT)") + db.executemany("INSERT INTO contents VALUES (?, ?)", rows) + db.commit() + db.close() + + print(f"Wrote {len(rows)} headlines to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/datasets/nyt-768/distill-model.py b/benchmarks-ann/datasets/nyt-768/distill-model.py new file mode 100644 index 0000000..3adca4a --- /dev/null +++ b/benchmarks-ann/datasets/nyt-768/distill-model.py @@ -0,0 +1,13 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "model2vec[distill]", +# "torch<=2.7", +# ] +# /// + +from model2vec.distill import distill + +model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=768) +model.save_pretrained("bge-base-en-v1.5-768") +print("Saved distilled model to bge-base-en-v1.5-768/") diff --git a/benchmarks-ann/datasets/nyt-768/queries.txt b/benchmarks-ann/datasets/nyt-768/queries.txt new file mode 
100644 index 0000000..9e98f84 --- /dev/null +++ b/benchmarks-ann/datasets/nyt-768/queries.txt @@ -0,0 +1,100 @@ +latest news on climate change policy +presidential election results and analysis +stock market crash causes +coronavirus vaccine development updates +artificial intelligence breakthrough in healthcare +supreme court ruling on abortion rights +tech companies layoff announcements +earthquake damages in California +cybersecurity breach at major corporation +space exploration mission to Mars +immigration reform legislation debate +renewable energy investment trends +healthcare costs rising across America +protests against police brutality +wildfires destroy homes in the West +Olympic games highlights and records +celebrity scandal rocks Hollywood +breakthrough cancer treatment discovered +housing market bubble concerns +federal reserve interest rate decision +school shooting tragedy response +diplomatic tensions between superpowers +drone strike kills terrorist leader +social media platform faces regulation +archaeological discovery reveals ancient civilization +unemployment rate hits record low +autonomous vehicles testing expansion +streaming service launches original content +opioid crisis intervention programs +trade war tariffs impact economy +infrastructure bill passes Congress +data privacy concerns grow +minimum wage increase proposal +college admissions scandal exposed +NFL player protest during anthem +cryptocurrency regulation debate +pandemic lockdown restrictions eased +mass shooting gun control debate +tax reform legislation impact +ransomware attack cripples pipeline +climate activists stage demonstration +sports team wins championship +banking system collapse fears +pharmaceutical company fraud charges +genetic engineering ethical concerns +border wall funding controversy +impeachment proceedings begin +nuclear weapons treaty violation +artificial meat alternative launch +student loan debt forgiveness +venture capital funding decline +facial 
recognition ban proposed +election interference investigation +pandemic preparedness failures +police reform measures announced +wildfire prevention strategies +ocean pollution crisis worsens +manufacturing jobs returning +pension fund shortfall concerns +antitrust investigation launched +voting rights protection act +mental health awareness campaign +homeless population increasing +space debris collision risk +drug cartel violence escalates +renewable energy jobs growth +infrastructure deterioration report +vaccine mandate legal challenge +cryptocurrency market volatility +autonomous drone delivery service +deep fake technology dangers +Arctic ice melting accelerates +income inequality gap widens +election fraud claims disputed +corporate merger blocked +medical breakthrough extends life +transportation strike disrupts city +racial justice protests spread +carbon emissions reduction goals +financial crisis warning signs +cyberbullying prevention efforts +asteroid near miss with Earth +gene therapy approval granted +labor union organizing drive +surveillance technology expansion +education funding cuts proposed +disaster relief efforts underway +housing affordability crisis +clean water access shortage +artificial intelligence job displacement +trade agreement negotiations +prison reform initiative launched +species extinction accelerates +political corruption scandal +terrorism threat level raised +food safety contamination outbreak +ai model release +affordability interest rates +peanut allergies in newborns +breaking bad walter white \ No newline at end of file diff --git a/benchmarks-ann/datasets/nyt/.gitignore b/benchmarks-ann/datasets/nyt/.gitignore new file mode 100644 index 0000000..adbb97d --- /dev/null +++ b/benchmarks-ann/datasets/nyt/.gitignore @@ -0,0 +1 @@ +data/ \ No newline at end of file diff --git a/benchmarks-ann/datasets/nyt/Makefile b/benchmarks-ann/datasets/nyt/Makefile new file mode 100644 index 0000000..dfaa6e9 --- /dev/null +++
b/benchmarks-ann/datasets/nyt/Makefile @@ -0,0 +1,30 @@ +MODEL ?= minishlab/potion-base-8M +K ?= 100 +BATCH_SIZE ?= 512 +DATA_DIR ?= data + +all: base.db contents.db + +# Download NYT headlines CSVs from Kaggle (requires `kaggle` CLI + API token) +$(DATA_DIR): + kaggle datasets download -d johnbandy/new-york-times-headlines -p $(DATA_DIR) --unzip + +contents.db: $(DATA_DIR) + uv run build-contents.py --data-dir $(DATA_DIR) -o $@ + +base.db: contents.db queries.txt + uv run build-base.py \ + --contents-db contents.db \ + --model $(MODEL) \ + --queries-file queries.txt \ + --batch-size $(BATCH_SIZE) \ + --k $(K) \ + -o $@ + +clean: + rm -f base.db contents.db + +clean-all: clean + rm -rf $(DATA_DIR) + +.PHONY: all clean clean-all diff --git a/benchmarks-ann/datasets/nyt/build-base.py b/benchmarks-ann/datasets/nyt/build-base.py new file mode 100644 index 0000000..db00aa2 --- /dev/null +++ b/benchmarks-ann/datasets/nyt/build-base.py @@ -0,0 +1,165 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "model2vec", +# "torch<=2.7", +# "tqdm", +# ] +# /// + +import argparse +import sqlite3 +from array import array +from itertools import batched + +from model2vec import StaticModel +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser( + description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors", + ) + parser.add_argument( + "--contents-db", "-c", default=None, + help="Path to contents.db (source of headlines and IDs)", + ) + parser.add_argument( + "--model", "-m", default="minishlab/potion-base-8M", + help="HuggingFace model ID or local path (default: minishlab/potion-base-8M)", + ) + parser.add_argument( + "--queries-file", "-q", default="queries.txt", + help="Path to the queries file (default: queries.txt)", + ) + parser.add_argument( + "--output", "-o", required=True, + help="Path to the output base.db", + ) + parser.add_argument( + "--batch-size", "-b", type=int, default=512, + help="Batch size for 
embedding (default: 512)", + ) + parser.add_argument( + "--k", "-k", type=int, default=100, + help="Number of nearest neighbors (default: 100)", + ) + parser.add_argument( + "--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0", + help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)", + ) + parser.add_argument( + "--rebuild-neighbors", action="store_true", + help="Only rebuild the neighbors table (skip embedding steps)", + ) + args = parser.parse_args() + + import os + vec_path = os.path.expanduser(args.vec_path) + + if args.rebuild_neighbors: + # Skip embedding, just open existing DB and rebuild neighbors + db = sqlite3.connect(args.output) + db.enable_load_extension(True) + db.load_extension(vec_path) + db.enable_load_extension(False) + db.execute("DROP TABLE IF EXISTS neighbors") + db.execute( + "CREATE TABLE neighbors(" + " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT," + " UNIQUE(query_vector_id, rank))" + ) + print(f"Rebuilding neighbors in {args.output}...") + else: + print(f"Loading model {args.model}...") + model = StaticModel.from_pretrained(args.model) + + # Read headlines from contents.db + src = sqlite3.connect(args.contents_db) + headlines = src.execute("SELECT id, headline FROM contents ORDER BY id").fetchall() + src.close() + print(f"Loaded {len(headlines)} headlines from {args.contents_db}") + + # Read queries + with open(args.queries_file) as f: + queries = [line.strip() for line in f if line.strip()] + print(f"Loaded {len(queries)} queries from {args.queries_file}") + + # Create output database + db = sqlite3.connect(args.output) + db.enable_load_extension(True) + db.load_extension(vec_path) + db.enable_load_extension(False) + + db.execute("CREATE TABLE train(id INTEGER PRIMARY KEY, vector BLOB)") + db.execute("CREATE TABLE query_vectors(id INTEGER PRIMARY KEY, vector BLOB)") + db.execute( + "CREATE TABLE neighbors(" + " query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT," + " 
UNIQUE(query_vector_id, rank))" + ) + + # Step 1: Embed headlines -> train table + print("Embedding headlines...") + for batch in tqdm( + batched(headlines, args.batch_size), + total=(len(headlines) + args.batch_size - 1) // args.batch_size, + ): + ids = [r[0] for r in batch] + texts = [r[1] for r in batch] + embeddings = model.encode(texts) + + params = [ + (int(rid), array("f", emb.tolist()).tobytes()) + for rid, emb in zip(ids, embeddings) + ] + db.executemany("INSERT INTO train VALUES (?, ?)", params) + db.commit() + + del headlines + n = db.execute("SELECT count(*) FROM train").fetchone()[0] + print(f"Embedded {n} headlines") + + # Step 2: Embed queries -> query_vectors table + print("Embedding queries...") + query_embeddings = model.encode(queries) + query_params = [] + for i, emb in enumerate(query_embeddings, 1): + blob = array("f", emb.tolist()).tobytes() + query_params.append((i, blob)) + db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params) + db.commit() + print(f"Embedded {len(queries)} queries") + + # Step 3: Brute-force KNN via sqlite-vec -> neighbors table + n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0] + print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...") + for query_id, query_blob in tqdm( + db.execute("SELECT id, vector FROM query_vectors").fetchall() + ): + results = db.execute( + """ + SELECT + train.id, + vec_distance_cosine(train.vector, ?) AS distance + FROM train + WHERE distance IS NOT NULL + ORDER BY distance ASC + LIMIT ? + """, + (query_blob, args.k), + ).fetchall() + + params = [ + (query_id, rank, str(rid)) + for rank, (rid, _dist) in enumerate(results) + ] + db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params) + + db.commit() + db.close() + print(f"Done. 
Wrote {args.output}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/datasets/nyt/build-contents.py b/benchmarks-ann/datasets/nyt/build-contents.py new file mode 100644 index 0000000..7e99cb9 --- /dev/null +++ b/benchmarks-ann/datasets/nyt/build-contents.py @@ -0,0 +1,52 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "duckdb", +# ] +# /// + +import argparse +import os +import sqlite3 +import duckdb + + +def main(): + parser = argparse.ArgumentParser( + description="Load NYT headline CSVs into a SQLite contents database via DuckDB", + ) + parser.add_argument( + "--data-dir", "-d", default="data", + help="Directory containing NYT CSV files (default: data)", + ) + parser.add_argument( + "--output", "-o", required=True, + help="Path to the output SQLite database", + ) + args = parser.parse_args() + + glob_pattern = os.path.join(args.data_dir, "new_york_times_stories_*.csv") + + con = duckdb.connect() + rows = con.execute( + f""" + SELECT + row_number() OVER () AS id, + headline + FROM read_csv('{glob_pattern}', auto_detect=true, union_by_name=true) + WHERE headline IS NOT NULL AND headline != '' + """ + ).fetchall() + con.close() + + db = sqlite3.connect(args.output) + db.execute("CREATE TABLE contents(id INTEGER PRIMARY KEY, headline TEXT)") + db.executemany("INSERT INTO contents VALUES (?, ?)", rows) + db.commit() + db.close() + + print(f"Wrote {len(rows)} headlines to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/datasets/nyt/queries.txt b/benchmarks-ann/datasets/nyt/queries.txt new file mode 100644 index 0000000..9e98f84 --- /dev/null +++ b/benchmarks-ann/datasets/nyt/queries.txt @@ -0,0 +1,100 @@ +latest news on climate change policy +presidential election results and analysis +stock market crash causes +coronavirus vaccine development updates +artificial intelligence breakthrough in healthcare +supreme court ruling on abortion rights +tech companies layoff announcements 
+earthquake damages in California +cybersecurity breach at major corporation +space exploration mission to Mars +immigration reform legislation debate +renewable energy investment trends +healthcare costs rising across America +protests against police brutality +wildfires destroy homes in the West +Olympic games highlights and records +celebrity scandal rocks Hollywood +breakthrough cancer treatment discovered +housing market bubble concerns +federal reserve interest rate decision +school shooting tragedy response +diplomatic tensions between superpowers +drone strike kills terrorist leader +social media platform faces regulation +archaeological discovery reveals ancient civilization +unemployment rate hits record low +autonomous vehicles testing expansion +streaming service launches original content +opioid crisis intervention programs +trade war tariffs impact economy +infrastructure bill passes Congress +data privacy concerns grow +minimum wage increase proposal +college admissions scandal exposed +NFL player protest during anthem +cryptocurrency regulation debate +pandemic lockdown restrictions eased +mass shooting gun control debate +tax reform legislation impact +ransomware attack cripples pipeline +climate activists stage demonstration +sports team wins championship +banking system collapse fears +pharmaceutical company fraud charges +genetic engineering ethical concerns +border wall funding controversy +impeachment proceedings begin +nuclear weapons treaty violation +artificial meat alternative launch +student loan debt forgiveness +venture capital funding decline +facial recognition ban proposed +election interference investigation +pandemic preparedness failures +police reform measures announced +wildfire prevention strategies +ocean pollution crisis worsens +manufacturing jobs returning +pension fund shortfall concerns +antitrust investigation launched +voting rights protection act +mental health awareness campaign +homeless population increasing +space 
debris collision risk +drug cartel violence escalates +renewable energy jobs growth +infrastructure deterioration report +vaccine mandate legal challenge +cryptocurrency market volatility +autonomous drone delivery service +deep fake technology dangers +Arctic ice melting accelerates +income inequality gap widens +election fraud claims disputed +corporate merger blocked +medical breakthrough extends life +transportation strike disrupts city +racial justice protests spread +carbon emissions reduction goals +financial crisis warning signs +cyberbullying prevention efforts +asteroid near miss with Earth +gene therapy approval granted +labor union organizing drive +surveillance technology expansion +education funding cuts proposed +disaster relief efforts underway +housing affordability crisis +clean water access shortage +artificial intelligence job displacement +trade agreement negotiations +prison reform initiative launched +species extinction accelerates +political corruption scandal +terrorism threat level raised +food safety contamination outbreak +ai model release +affordability interest rates +peanut allergies in newborns +breaking bad walter white \ No newline at end of file diff --git a/benchmarks-ann/faiss_kmeans.py b/benchmarks-ann/faiss_kmeans.py new file mode 100644 index 0000000..9765a7b --- /dev/null +++ b/benchmarks-ann/faiss_kmeans.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +"""Compute k-means centroids using FAISS and save to a centroids DB. + +Reads the first N vectors from a base.db, runs FAISS k-means, and writes +the centroids to an output SQLite DB as float32 blobs.
+ +Usage: + python faiss_kmeans.py --base-db datasets/cohere10m/base.db --ntrain 100000 \ + --nclusters 8192 -o centroids.db + +Output schema: + CREATE TABLE centroids ( + centroid_id INTEGER PRIMARY KEY, + centroid BLOB NOT NULL -- float32[D] + ); + CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT); + -- ntrain, nclusters, dimensions, elapsed_s +""" +import argparse +import os +import sqlite3 +import struct +import time + +import faiss +import numpy as np + + +def main(): + parser = argparse.ArgumentParser(description="FAISS k-means centroid computation") + parser.add_argument("--base-db", required=True, help="path to base.db with train table") + parser.add_argument("--ntrain", type=int, required=True, help="number of vectors to train on") + parser.add_argument("--nclusters", type=int, required=True, help="number of clusters (nlist)") + parser.add_argument("--niter", type=int, default=20, help="k-means iterations (default 20)") + parser.add_argument("--seed", type=int, default=42, help="random seed") + parser.add_argument("-o", "--output", required=True, help="output centroids DB path") + args = parser.parse_args() + + # Load vectors + print(f"Loading {args.ntrain} vectors from {args.base_db}...") + conn = sqlite3.connect(args.base_db) + rows = conn.execute( + "SELECT vector FROM train ORDER BY id LIMIT ?", (args.ntrain,) + ).fetchall() + conn.close() + + # Parse float32 blobs to numpy + first_blob = rows[0][0] + D = len(first_blob) // 4 # float32 + print(f" Dimensions: {D}, loaded {len(rows)} vectors") + + vectors = np.zeros((len(rows), D), dtype=np.float32) + for i, (blob,) in enumerate(rows): + vectors[i] = np.frombuffer(blob, dtype=np.float32) + + # Normalize for cosine distance (FAISS k-means on L2 of unit vectors ≈ cosine) + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + norms[norms == 0] = 1 + vectors /= norms + + # Run FAISS k-means + print(f"Running k-means: {args.nclusters} clusters, {args.niter} iterations...") + t0 = time.perf_counter() + 
kmeans = faiss.Kmeans( + D, args.nclusters, + niter=args.niter, + seed=args.seed, + verbose=True, + gpu=False, + ) + kmeans.train(vectors) + elapsed = time.perf_counter() - t0 + print(f" Done in {elapsed:.1f}s") + + centroids = kmeans.centroids # (nclusters, D) float32 + + # Write output DB + if os.path.exists(args.output): + os.remove(args.output) + out = sqlite3.connect(args.output) + out.execute("CREATE TABLE centroids (centroid_id INTEGER PRIMARY KEY, centroid BLOB NOT NULL)") + out.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + + for i in range(args.nclusters): + blob = centroids[i].tobytes() + out.execute("INSERT INTO centroids (centroid_id, centroid) VALUES (?, ?)", (i, blob)) + + out.execute("INSERT INTO meta VALUES ('ntrain', ?)", (str(args.ntrain),)) + out.execute("INSERT INTO meta VALUES ('nclusters', ?)", (str(args.nclusters),)) + out.execute("INSERT INTO meta VALUES ('dimensions', ?)", (str(D),)) + out.execute("INSERT INTO meta VALUES ('niter', ?)", (str(args.niter),)) + out.execute("INSERT INTO meta VALUES ('elapsed_s', ?)", (str(round(elapsed, 3)),)) + out.execute("INSERT INTO meta VALUES ('seed', ?)", (str(args.seed),)) + out.commit() + out.close() + + print(f"Wrote {args.nclusters} centroids to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks-ann/results_schema.sql b/benchmarks-ann/results_schema.sql new file mode 100644 index 0000000..7918709 --- /dev/null +++ b/benchmarks-ann/results_schema.sql @@ -0,0 +1,76 @@ +-- Comprehensive results schema for vec0 KNN benchmark runs. 
+-- Created in WAL mode: PRAGMA journal_mode=WAL + +CREATE TABLE IF NOT EXISTS runs ( + run_id INTEGER PRIMARY KEY AUTOINCREMENT, + config_name TEXT NOT NULL, + index_type TEXT NOT NULL, + params TEXT NOT NULL, -- JSON: {"R":48,"L":128,"quantizer":"binary"} + dataset TEXT NOT NULL, -- "cohere1m" + subset_size INTEGER NOT NULL, + k INTEGER NOT NULL, + n_queries INTEGER NOT NULL, + phase TEXT NOT NULL DEFAULT 'both', + -- 'build', 'query', or 'both' + status TEXT NOT NULL DEFAULT 'pending', + -- pending → inserting → training → querying → done | built | error + created_at_ns INTEGER NOT NULL -- time.time_ns() +); + +CREATE TABLE IF NOT EXISTS run_results ( + run_id INTEGER PRIMARY KEY REFERENCES runs(run_id), + insert_started_ns INTEGER, + insert_ended_ns INTEGER, + insert_duration_ns INTEGER, + train_started_ns INTEGER, -- NULL if no training + train_ended_ns INTEGER, + train_duration_ns INTEGER, + build_duration_ns INTEGER, -- insert + train + db_file_size_bytes INTEGER, + db_file_path TEXT, + create_sql TEXT, -- CREATE VIRTUAL TABLE ... + insert_sql TEXT, -- INSERT INTO vec_items ... + train_sql TEXT, -- NULL if no training step + query_sql TEXT, -- SELECT ... WHERE embedding MATCH ... 
+ k INTEGER, -- denormalized from runs for easy filtering + query_mean_ms REAL, -- denormalized aggregates + query_median_ms REAL, + query_p99_ms REAL, + query_total_ms REAL, + qps REAL, + recall REAL +); + +CREATE TABLE IF NOT EXISTS insert_batches ( + batch_id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id INTEGER NOT NULL REFERENCES runs(run_id), + batch_lo INTEGER NOT NULL, -- start index (inclusive) + batch_hi INTEGER NOT NULL, -- end index (exclusive) + rows_in_batch INTEGER NOT NULL, + started_ns INTEGER NOT NULL, + ended_ns INTEGER NOT NULL, + duration_ns INTEGER NOT NULL, + cumulative_rows INTEGER NOT NULL, -- total rows inserted so far + rate_rows_per_s REAL NOT NULL -- cumulative rate +); + +CREATE TABLE IF NOT EXISTS queries ( + query_id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id INTEGER NOT NULL REFERENCES runs(run_id), + k INTEGER NOT NULL, + query_vector_id INTEGER NOT NULL, + started_ns INTEGER NOT NULL, + ended_ns INTEGER NOT NULL, + duration_ms REAL NOT NULL, + result_ids TEXT NOT NULL, -- JSON array + result_distances TEXT NOT NULL, -- JSON array + ground_truth_ids TEXT NOT NULL, -- JSON array + recall REAL NOT NULL, + UNIQUE(run_id, k, query_vector_id) +); + +CREATE INDEX IF NOT EXISTS idx_runs_config ON runs(config_name); +CREATE INDEX IF NOT EXISTS idx_runs_type ON runs(index_type); +CREATE INDEX IF NOT EXISTS idx_runs_status ON runs(status); +CREATE INDEX IF NOT EXISTS idx_batches_run ON insert_batches(run_id); +CREATE INDEX IF NOT EXISTS idx_queries_run ON queries(run_id);