sqlite-vec/benchmarks-ann/ground_truth.py
Alex Garcia 0de765f457
Add ANN search support for vec0 virtual table (#273)
Add approximate nearest neighbor infrastructure to vec0: shared distance
dispatch (vec0_distance_full), flat index type with parser, NEON-optimized
cosine/Hamming for float32/int8, amalgamation script, and benchmark suite
(benchmarks-ann/) with ground-truth generation and profiling tools. Remove
unused vec_npy_each/vec_static_blobs code, fix missing stdint.h include.
2026-03-31 01:03:32 -07:00

168 lines
5.5 KiB
Python

#!/usr/bin/env python3
"""Compute per-subset ground truth for ANN benchmarks.

For subset sizes < 1M, builds a temporary vec0 float table with the first N
vectors and runs brute-force KNN to get correct ground truth per subset.
For 1M (the full dataset), converts the existing `neighbors` table.

Output: ground_truth.{subset_size}.db with table:
    ground_truth(query_vector_id, rank, neighbor_id, distance)

Usage:
    python ground_truth.py --subset-size 50000
    python ground_truth.py --subset-size 1000000
"""
import argparse
import os
import sqlite3
import time
# Absolute directory holding this script; the default paths below are built
# from it so the script does not depend on the current working directory.
_SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]

# Default loadable-extension path: <script dir>/../dist/vec0.
EXT_PATH = os.path.join(_SCRIPT_DIR, os.pardir, "dist", "vec0")

# Default source database: <script dir>/seed/base.db.
BASE_DB = os.path.join(_SCRIPT_DIR, "seed", "base.db")

# Subset sizes at or above this value are treated as "the full dataset".
FULL_DATASET_SIZE = 10**6
def gen_ground_truth_subset(base_db, ext_path, subset_size, n_queries, k, out_path):
    """Build ground truth by brute-force KNN over the first `subset_size` vectors.

    Creates a fresh SQLite database at ``out_path`` (any existing file is
    removed first), loads the extension at ``ext_path``, copies the rows of
    ``train`` in ``base_db`` whose ``id`` is below ``subset_size`` into a
    temporary ``vec0`` table (768-dimensional float vectors, cosine distance),
    and runs one ``MATCH ... AND k = :k`` query per query vector.  Each
    returned row is stored in ``ground_truth`` with ``rank`` set to its
    zero-based position in the result set.

    Args:
        base_db: Path to the database providing ``train`` and ``query_vectors``.
        ext_path: Path of the loadable extension passed to ``load_extension``.
        subset_size: Exclusive upper bound on ``train.id`` for copied vectors.
        n_queries: Maximum number of query vectors used (lowest ids first).
        k: Number of neighbours requested per query.
        out_path: Destination database file.

    Raises:
        sqlite3.Error: If the extension cannot be loaded or any SQL statement
            fails.  The connection is closed before the error propagates.
    """
    if os.path.exists(out_path):
        os.remove(out_path)

    conn = sqlite3.connect(out_path)
    try:
        conn.enable_load_extension(True)
        conn.load_extension(ext_path)

        conn.execute(
            "CREATE TABLE ground_truth ("
            " query_vector_id INTEGER NOT NULL,"
            " rank INTEGER NOT NULL,"
            " neighbor_id INTEGER NOT NULL,"
            " distance REAL NOT NULL,"
            " PRIMARY KEY (query_vector_id, rank)"
            ")"
        )

        # Bind the path as a parameter rather than interpolating it into the
        # SQL text, so paths containing quote characters attach correctly.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))

        print(f" Building temp vec0 table with {subset_size} vectors...")
        conn.execute(
            "CREATE VIRTUAL TABLE tmp_vec USING vec0("
            " id integer primary key,"
            " embedding float[768] distance_metric=cosine"
            ")"
        )

        t0 = time.perf_counter()
        conn.execute(
            "INSERT INTO tmp_vec(id, embedding) "
            "SELECT id, vector FROM base.train WHERE id < :n",
            {"n": subset_size},
        )
        conn.commit()
        build_time = time.perf_counter() - t0
        print(f" Temp table built in {build_time:.1f}s")

        query_vectors = conn.execute(
            "SELECT id, vector FROM base.query_vectors ORDER BY id LIMIT :n",
            {"n": n_queries},
        ).fetchall()
        total_queries = len(query_vectors)

        print(f" Running brute-force KNN for {total_queries} queries, k={k}...")
        knn_sql = (
            "SELECT id, distance FROM tmp_vec "
            "WHERE embedding MATCH :query AND k = :k"
        )
        insert_sql = (
            "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id, distance) "
            "VALUES (?, ?, ?, ?)"
        )

        t0 = time.perf_counter()
        for i, (qid, qvec) in enumerate(query_vectors):
            results = conn.execute(knn_sql, {"query": qvec, "k": k}).fetchall()
            # One batched insert per query instead of one statement per row.
            conn.executemany(
                insert_sql,
                [(qid, rank, nid, dist) for rank, (nid, dist) in enumerate(results)],
            )

            # Report progress on the first query and then every tenth one.
            if (i + 1) % 10 == 0 or i == 0:
                elapsed = time.perf_counter() - t0
                eta = (elapsed / (i + 1)) * (total_queries - i - 1)
                print(
                    f" {i+1}/{total_queries} queries "
                    f"elapsed={elapsed:.1f}s eta={eta:.1f}s",
                    flush=True,
                )
        conn.commit()

        # Drop the scratch table so only ground_truth remains in the output.
        conn.execute("DROP TABLE tmp_vec")
        conn.execute("DETACH DATABASE base")
        conn.commit()
        elapsed = time.perf_counter() - t0

        total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0]
    finally:
        # Always release the connection, including when a step above fails.
        conn.close()

    print(f" Ground truth: {total_rows} rows in {elapsed:.1f}s -> {out_path}")
def gen_ground_truth_full(base_db, n_queries, k, out_path):
    """Convert the existing neighbors table for the full 1M dataset.

    Creates a fresh SQLite database at ``out_path`` (any existing file is
    removed first) and copies into its ``ground_truth`` table every row of
    ``neighbors`` in ``base_db`` whose ``query_vector_id`` is below
    ``n_queries`` and whose ``rank`` is below ``k``.  ``neighbors_id`` is cast
    to INTEGER on the way in.  No distance is copied, so the ``distance``
    column is left NULL.

    Args:
        base_db: Path to the database providing the ``neighbors`` table.
        n_queries: Exclusive upper bound on ``query_vector_id``.
        k: Exclusive upper bound on ``rank``.
        out_path: Destination database file.

    Raises:
        sqlite3.Error: If attaching ``base_db`` or any SQL statement fails.
            The connection is closed before the error propagates.
    """
    if os.path.exists(out_path):
        os.remove(out_path)

    conn = sqlite3.connect(out_path)
    try:
        # Bind the path as a parameter rather than interpolating it into the
        # SQL text, so paths containing quote characters attach correctly.
        conn.execute("ATTACH DATABASE ? AS base", (base_db,))
        conn.execute(
            "CREATE TABLE ground_truth ("
            " query_vector_id INTEGER NOT NULL,"
            " rank INTEGER NOT NULL,"
            " neighbor_id INTEGER NOT NULL,"
            " distance REAL,"
            " PRIMARY KEY (query_vector_id, rank)"
            ")"
        )
        conn.execute(
            "INSERT INTO ground_truth(query_vector_id, rank, neighbor_id) "
            "SELECT query_vector_id, rank, CAST(neighbors_id AS INTEGER) "
            "FROM base.neighbors "
            "WHERE query_vector_id < :n AND rank < :k",
            {"n": n_queries, "k": k},
        )
        conn.commit()

        total_rows = conn.execute("SELECT count(*) FROM ground_truth").fetchone()[0]
        conn.execute("DETACH DATABASE base")
    finally:
        # Always release the connection, including when a step above fails.
        conn.close()

    print(f" Ground truth (full): {total_rows} rows -> {out_path}")
def main():
    """Parse the command line and write ``ground_truth.{subset_size}.db``.

    The output directory is created if missing.  Subset sizes below
    ``FULL_DATASET_SIZE`` are computed by brute-force KNN; anything at or
    above it is converted from the existing ``neighbors`` table.
    """
    cli = argparse.ArgumentParser(description="Generate per-subset ground truth")
    cli.add_argument(
        "--subset-size",
        type=int,
        required=True,
        help="number of vectors in subset",
    )
    cli.add_argument("-n", type=int, default=100, help="number of query vectors")
    cli.add_argument("-k", type=int, default=100, help="max k for ground truth")
    cli.add_argument("--base-db", default=BASE_DB)
    cli.add_argument("--ext", default=EXT_PATH)
    cli.add_argument(
        "-o",
        "--out-dir",
        default=os.path.join(_SCRIPT_DIR, "seed"),
        help="output directory for ground_truth.{N}.db",
    )
    opts = cli.parse_args()

    os.makedirs(opts.out_dir, exist_ok=True)
    db_name = f"ground_truth.{opts.subset_size}.db"
    destination = os.path.join(opts.out_dir, db_name)

    # Proper subsets need their own brute-force pass over the first N vectors.
    if opts.subset_size < FULL_DATASET_SIZE:
        gen_ground_truth_subset(
            opts.base_db, opts.ext, opts.subset_size, opts.n, opts.k, destination
        )
        return

    # Full dataset: reuse the neighbours already stored in the base database.
    gen_ground_truth_full(opts.base_db, opts.n, opts.k, destination)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()