sqlite-vec/benchmarks-ann/datasets/nyt-1024/build-base.py

164 lines
5.2 KiB
Python
Raw Normal View History

# /// script
# requires-python = ">=3.12"
# dependencies = [
# "sentence-transformers",
# "torch<=2.7",
# "tqdm",
# ]
# ///
import argparse
import os
import sqlite3
from array import array
from itertools import batched

from sentence_transformers import SentenceTransformer
from tqdm import tqdm
def main():
parser = argparse.ArgumentParser(
description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors",
)
parser.add_argument(
"--contents-db", "-c", default=None,
help="Path to contents.db (source of headlines and IDs)",
)
parser.add_argument(
"--model", "-m", default="mixedbread-ai/mxbai-embed-large-v1",
help="HuggingFace model ID (default: mixedbread-ai/mxbai-embed-large-v1)",
)
parser.add_argument(
"--queries-file", "-q", default="queries.txt",
help="Path to the queries file (default: queries.txt)",
)
parser.add_argument(
"--output", "-o", required=True,
help="Path to the output base.db",
)
parser.add_argument(
"--batch-size", "-b", type=int, default=256,
help="Batch size for embedding (default: 256)",
)
parser.add_argument(
"--k", "-k", type=int, default=100,
help="Number of nearest neighbors (default: 100)",
)
parser.add_argument(
"--limit", "-l", type=int, default=0,
help="Limit number of headlines to embed (0 = all)",
)
parser.add_argument(
"--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0",
help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)",
)
parser.add_argument(
"--skip-neighbors", action="store_true",
help="Skip the brute-force KNN neighbor computation",
)
args = parser.parse_args()
import os
vec_path = os.path.expanduser(args.vec_path)
print(f"Loading model {args.model}...")
model = SentenceTransformer(args.model)
# Read headlines from contents.db
src = sqlite3.connect(args.contents_db)
limit_clause = f" LIMIT {args.limit}" if args.limit > 0 else ""
headlines = src.execute(
f"SELECT id, headline FROM contents ORDER BY id{limit_clause}"
).fetchall()
src.close()
print(f"Loaded {len(headlines)} headlines from {args.contents_db}")
# Read queries
with open(args.queries_file) as f:
queries = [line.strip() for line in f if line.strip()]
print(f"Loaded {len(queries)} queries from {args.queries_file}")
# Create output database
db = sqlite3.connect(args.output)
db.enable_load_extension(True)
db.load_extension(vec_path)
db.enable_load_extension(False)
db.execute("CREATE TABLE IF NOT EXISTS train(id INTEGER PRIMARY KEY, vector BLOB)")
db.execute("CREATE TABLE IF NOT EXISTS query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
db.execute(
"CREATE TABLE IF NOT EXISTS neighbors("
" query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
" UNIQUE(query_vector_id, rank))"
)
# Step 1: Embed headlines -> train table
print("Embedding headlines...")
for batch in tqdm(
batched(headlines, args.batch_size),
total=(len(headlines) + args.batch_size - 1) // args.batch_size,
):
ids = [r[0] for r in batch]
texts = [r[1] for r in batch]
embeddings = model.encode(texts, normalize_embeddings=True)
params = [
(int(rid), array("f", emb.tolist()).tobytes())
for rid, emb in zip(ids, embeddings)
]
db.executemany("INSERT INTO train VALUES (?, ?)", params)
db.commit()
del headlines
n = db.execute("SELECT count(*) FROM train").fetchone()[0]
print(f"Embedded {n} headlines")
# Step 2: Embed queries -> query_vectors table
print("Embedding queries...")
query_embeddings = model.encode(queries, normalize_embeddings=True)
query_params = []
for i, emb in enumerate(query_embeddings, 1):
blob = array("f", emb.tolist()).tobytes()
query_params.append((i, blob))
db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params)
db.commit()
print(f"Embedded {len(queries)} queries")
if args.skip_neighbors:
db.close()
print(f"Done (skipped neighbors). Wrote {args.output}")
return
# Step 3: Brute-force KNN via sqlite-vec -> neighbors table
n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0]
print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...")
for query_id, query_blob in tqdm(
db.execute("SELECT id, vector FROM query_vectors").fetchall()
):
results = db.execute(
"""
SELECT
train.id,
vec_distance_cosine(train.vector, ?) AS distance
FROM train
WHERE distance IS NOT NULL
ORDER BY distance ASC
LIMIT ?
""",
(query_blob, args.k),
).fetchall()
params = [
(query_id, rank, str(rid))
for rank, (rid, _dist) in enumerate(results)
]
db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params)
db.commit()
db.close()
print(f"Done. Wrote {args.output}")
if __name__ == "__main__":
    # Script entry point: runs only when executed directly, not on import.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())