mirror of
https://github.com/asg017/sqlite-vec.git
synced 2026-04-25 08:46:49 +02:00
Extend benchmarks-ann/ with results database (SQLite with per-query detail and continuous writes), dataset subfolder organization, --subset-size and --warmup options. Supports systematic comparison across flat, rescore, IVF, and DiskANN index types.
165 lines
5.5 KiB
Python
165 lines
5.5 KiB
Python
# /// script
|
|
# requires-python = ">=3.12"
|
|
# dependencies = [
|
|
# "model2vec",
|
|
# "torch<=2.7",
|
|
# "tqdm",
|
|
# ]
|
|
# ///
|
|
|
|
import argparse
|
|
import sqlite3
|
|
from array import array
|
|
from itertools import batched
|
|
|
|
from model2vec import StaticModel
|
|
from tqdm import tqdm
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description="Build base.db with train vectors, query vectors, and brute-force KNN neighbors",
|
|
)
|
|
parser.add_argument(
|
|
"--contents-db", "-c", default=None,
|
|
help="Path to contents.db (source of headlines and IDs)",
|
|
)
|
|
parser.add_argument(
|
|
"--model", "-m", default="minishlab/potion-base-8M",
|
|
help="HuggingFace model ID or local path (default: minishlab/potion-base-8M)",
|
|
)
|
|
parser.add_argument(
|
|
"--queries-file", "-q", default="queries.txt",
|
|
help="Path to the queries file (default: queries.txt)",
|
|
)
|
|
parser.add_argument(
|
|
"--output", "-o", required=True,
|
|
help="Path to the output base.db",
|
|
)
|
|
parser.add_argument(
|
|
"--batch-size", "-b", type=int, default=512,
|
|
help="Batch size for embedding (default: 512)",
|
|
)
|
|
parser.add_argument(
|
|
"--k", "-k", type=int, default=100,
|
|
help="Number of nearest neighbors (default: 100)",
|
|
)
|
|
parser.add_argument(
|
|
"--vec-path", "-v", default="~/projects/sqlite-vec/dist/vec0",
|
|
help="Path to sqlite-vec extension (default: ~/projects/sqlite-vec/dist/vec0)",
|
|
)
|
|
parser.add_argument(
|
|
"--rebuild-neighbors", action="store_true",
|
|
help="Only rebuild the neighbors table (skip embedding steps)",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
import os
|
|
vec_path = os.path.expanduser(args.vec_path)
|
|
|
|
if args.rebuild_neighbors:
|
|
# Skip embedding, just open existing DB and rebuild neighbors
|
|
db = sqlite3.connect(args.output)
|
|
db.enable_load_extension(True)
|
|
db.load_extension(vec_path)
|
|
db.enable_load_extension(False)
|
|
db.execute("DROP TABLE IF EXISTS neighbors")
|
|
db.execute(
|
|
"CREATE TABLE neighbors("
|
|
" query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
|
|
" UNIQUE(query_vector_id, rank))"
|
|
)
|
|
print(f"Rebuilding neighbors in {args.output}...")
|
|
else:
|
|
print(f"Loading model {args.model}...")
|
|
model = StaticModel.from_pretrained(args.model)
|
|
|
|
# Read headlines from contents.db
|
|
src = sqlite3.connect(args.contents_db)
|
|
headlines = src.execute("SELECT id, headline FROM contents ORDER BY id").fetchall()
|
|
src.close()
|
|
print(f"Loaded {len(headlines)} headlines from {args.contents_db}")
|
|
|
|
# Read queries
|
|
with open(args.queries_file) as f:
|
|
queries = [line.strip() for line in f if line.strip()]
|
|
print(f"Loaded {len(queries)} queries from {args.queries_file}")
|
|
|
|
# Create output database
|
|
db = sqlite3.connect(args.output)
|
|
db.enable_load_extension(True)
|
|
db.load_extension(vec_path)
|
|
db.enable_load_extension(False)
|
|
|
|
db.execute("CREATE TABLE train(id INTEGER PRIMARY KEY, vector BLOB)")
|
|
db.execute("CREATE TABLE query_vectors(id INTEGER PRIMARY KEY, vector BLOB)")
|
|
db.execute(
|
|
"CREATE TABLE neighbors("
|
|
" query_vector_id INTEGER, rank INTEGER, neighbors_id TEXT,"
|
|
" UNIQUE(query_vector_id, rank))"
|
|
)
|
|
|
|
# Step 1: Embed headlines -> train table
|
|
print("Embedding headlines...")
|
|
for batch in tqdm(
|
|
batched(headlines, args.batch_size),
|
|
total=(len(headlines) + args.batch_size - 1) // args.batch_size,
|
|
):
|
|
ids = [r[0] for r in batch]
|
|
texts = [r[1] for r in batch]
|
|
embeddings = model.encode(texts)
|
|
|
|
params = [
|
|
(int(rid), array("f", emb.tolist()).tobytes())
|
|
for rid, emb in zip(ids, embeddings)
|
|
]
|
|
db.executemany("INSERT INTO train VALUES (?, ?)", params)
|
|
db.commit()
|
|
|
|
del headlines
|
|
n = db.execute("SELECT count(*) FROM train").fetchone()[0]
|
|
print(f"Embedded {n} headlines")
|
|
|
|
# Step 2: Embed queries -> query_vectors table
|
|
print("Embedding queries...")
|
|
query_embeddings = model.encode(queries)
|
|
query_params = []
|
|
for i, emb in enumerate(query_embeddings, 1):
|
|
blob = array("f", emb.tolist()).tobytes()
|
|
query_params.append((i, blob))
|
|
db.executemany("INSERT INTO query_vectors VALUES (?, ?)", query_params)
|
|
db.commit()
|
|
print(f"Embedded {len(queries)} queries")
|
|
|
|
# Step 3: Brute-force KNN via sqlite-vec -> neighbors table
|
|
n_queries = db.execute("SELECT count(*) FROM query_vectors").fetchone()[0]
|
|
print(f"Computing {args.k}-NN for {n_queries} queries via sqlite-vec...")
|
|
for query_id, query_blob in tqdm(
|
|
db.execute("SELECT id, vector FROM query_vectors").fetchall()
|
|
):
|
|
results = db.execute(
|
|
"""
|
|
SELECT
|
|
train.id,
|
|
vec_distance_cosine(train.vector, ?) AS distance
|
|
FROM train
|
|
WHERE distance IS NOT NULL
|
|
ORDER BY distance ASC
|
|
LIMIT ?
|
|
""",
|
|
(query_blob, args.k),
|
|
).fetchall()
|
|
|
|
params = [
|
|
(query_id, rank, str(rid))
|
|
for rank, (rid, _dist) in enumerate(results)
|
|
]
|
|
db.executemany("INSERT INTO neighbors VALUES (?, ?, ?)", params)
|
|
|
|
db.commit()
|
|
db.close()
|
|
print(f"Done. Wrote {args.output}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|